Merge pull request #9233 from eclipse/ag_junit5

Upgrade dl4j to junit 5
master
Adam Gibson 2021-03-19 21:26:43 +09:00 committed by GitHub
commit c505a11ed6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1022 changed files with 28893 additions and 38578 deletions

View File

@ -31,7 +31,7 @@ jobs:
protoc --version protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd .. cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64: windows-x86_64:
runs-on: windows-2019 runs-on: windows-2019
@ -44,7 +44,7 @@ jobs:
run: | run: |
set "PATH=C:\msys64\usr\bin;%PATH%" set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@ -60,5 +60,5 @@ jobs:
run: | run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test

View File

@ -31,7 +31,7 @@ jobs:
protoc --version protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd .. cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64: windows-x86_64:
runs-on: windows-2019 runs-on: windows-2019
@ -44,7 +44,7 @@ jobs:
run: | run: |
set "PATH=C:\msys64\usr\bin;%PATH%" set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@ -60,5 +60,5 @@ jobs:
run: | run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test

View File

@ -1,11 +1,3 @@
on:
workflow_dispatch:
jobs:
# Wait for up to a minute for previous run to complete, abort if not done by then
pre-ci:
run
on: on:
workflow_dispatch: workflow_dispatch:
jobs: jobs:
@ -42,5 +34,5 @@ jobs:
cmake --version cmake --version
protoc --version protoc --version
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pdl4j-integration-tests -Pnd4j-tests-cpu clean test mvn -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pnd4j-tests-cpu clean test -rf :rl4j-core

View File

@ -34,5 +34,5 @@ jobs:
cmake --version cmake --version
protoc --version protoc --version
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test

View File

@ -15,7 +15,7 @@
<commons.dbutils.version>1.7</commons.dbutils.version> <commons.dbutils.version>1.7</commons.dbutils.version>
<lombok.version>1.18.8</lombok.version> <lombok.version>1.18.8</lombok.version>
<logback.version>1.1.7</logback.version> <logback.version>1.1.7</logback.version>
<junit.version>4.12</junit.version> <junit.version>5.8.0-M1</junit.version>
<junit-jupiter.version>5.4.2</junit-jupiter.version> <junit-jupiter.version>5.4.2</junit-jupiter.version>
<java.version>1.8</java.version> <java.version>1.8</java.version>
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version> <maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>

View File

@ -17,13 +17,14 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.nd4j.codegen.ir; package org.nd4j.codegen.ir;
public class SerializationTest { import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public static void main(String...args) { @DisplayName("Serialization Test")
class SerializationTest {
public static void main(String... args) {
} }
} }

View File

@ -17,29 +17,23 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.nd4j.codegen.dsl; package org.nd4j.codegen.dsl;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.codegen.impl.java.DocsGenerator; import org.nd4j.codegen.impl.java.DocsGenerator;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class DocsGeneratorTest { @DisplayName("Docs Generator Test")
class DocsGeneratorTest {
@Test @Test
public void testJDtoMDAdapter() { @DisplayName("Test J Dto MD Adapter")
String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" + void testJDtoMDAdapter() {
" eye:\n" + String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
" [ 1, 0]\n" + String expected = "{ INDArray eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
" [ 0, 1]\n" +
" [ 0, 0]}";
String expected = "{ INDArray eye = eye(3,2)\n" +
" eye:\n" +
" [ 1, 0]\n" +
" [ 0, 1]\n" +
" [ 0, 0]}";
DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original); DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original);
String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString(); String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString();
assertEquals(out, expected); assertEquals(out, expected);

View File

@ -34,6 +34,14 @@
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>
<dependencies> <dependencies>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
</dependency>
<dependency>
<groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId> <artifactId>commons-lang3</artifactId>
@ -101,10 +109,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -26,47 +25,38 @@ import org.datavec.api.records.reader.impl.csv.CSVLineSequenceRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.tests.BaseND4JTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File; import java.io.File;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Csv Line Sequence Record Reader Test")
class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
public class CSVLineSequenceRecordReaderTest extends BaseND4JTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void test() throws Exception { @DisplayName("Test")
void test(@TempDir Path testDir) throws Exception {
File f = testDir.newFolder(); File f = testDir.toFile();
File source = new File(f, "temp.csv"); File source = new File(f, "temp.csv");
String str = "a,b,c\n1,2,3,4"; String str = "a,b,c\n1,2,3,4";
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8); FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
SequenceRecordReader rr = new CSVLineSequenceRecordReader(); SequenceRecordReader rr = new CSVLineSequenceRecordReader();
rr.initialize(new FileSplit(source)); rr.initialize(new FileSplit(source));
List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
List<List<Writable>> exp0 = Arrays.asList( List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
Collections.<Writable>singletonList(new Text("a")), for (int i = 0; i < 3; i++) {
Collections.<Writable>singletonList(new Text("b")),
Collections.<Writable>singletonList(new Text("c")));
List<List<Writable>> exp1 = Arrays.asList(
Collections.<Writable>singletonList(new Text("1")),
Collections.<Writable>singletonList(new Text("2")),
Collections.<Writable>singletonList(new Text("3")),
Collections.<Writable>singletonList(new Text("4")));
for( int i=0; i<3; i++ ) {
int count = 0; int count = 0;
while (rr.hasNext()) { while (rr.hasNext()) {
List<List<Writable>> next = rr.sequenceRecord(); List<List<Writable>> next = rr.sequenceRecord();
@ -76,9 +66,7 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
assertEquals(exp1, next); assertEquals(exp1, next);
} }
} }
assertEquals(2, count); assertEquals(2, count);
rr.reset(); rr.reset();
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -26,33 +25,37 @@ import org.datavec.api.records.reader.impl.csv.CSVMultiSequenceRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.tests.BaseND4JTest;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File; import java.io.File;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Csv Multi Sequence Record Reader Test")
import static org.junit.Assert.assertFalse; class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testConcatMode() throws Exception { @DisplayName("Test Concat Mode")
for( int i=0; i<3; i++ ) { @Disabled
void testConcatMode() throws Exception {
for (int i = 0; i < 3; i++) {
String seqSep; String seqSep;
String seqSepRegex; String seqSepRegex;
switch (i){ switch(i) {
case 0: case 0:
seqSep = ""; seqSep = "";
seqSepRegex = "^$"; seqSepRegex = "^$";
@ -68,31 +71,23 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default: default:
throw new RuntimeException(); throw new RuntimeException();
} }
String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C"; String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C";
File f = testDir.newFile(); File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8); FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT); SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT);
seqRR.initialize(new FileSplit(f)); seqRR.initialize(new FileSplit(f));
List<List<Writable>> exp0 = new ArrayList<>(); List<List<Writable>> exp0 = new ArrayList<>();
for (String s : "a,b,c,1,2,3,4,x,y".split(",")) { for (String s : "a,b,c,1,2,3,4,x,y".split(",")) {
exp0.add(Collections.<Writable>singletonList(new Text(s))); exp0.add(Collections.<Writable>singletonList(new Text(s)));
} }
List<List<Writable>> exp1 = new ArrayList<>(); List<List<Writable>> exp1 = new ArrayList<>();
for (String s : "A,B,C".split(",")) { for (String s : "A,B,C".split(",")) {
exp1.add(Collections.<Writable>singletonList(new Text(s))); exp1.add(Collections.<Writable>singletonList(new Text(s)));
} }
assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext()); assertFalse(seqRR.hasNext());
seqRR.reset(); seqRR.reset();
assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext()); assertFalse(seqRR.hasNext());
@ -100,13 +95,13 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testEqualLength() throws Exception { @DisplayName("Test Equal Length")
@Disabled
for( int i=0; i<3; i++ ) { void testEqualLength() throws Exception {
for (int i = 0; i < 3; i++) {
String seqSep; String seqSep;
String seqSepRegex; String seqSepRegex;
switch (i) { switch(i) {
case 0: case 0:
seqSep = ""; seqSep = "";
seqSepRegex = "^$"; seqSepRegex = "^$";
@ -122,27 +117,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default: default:
throw new RuntimeException(); throw new RuntimeException();
} }
String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC"; String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC";
File f = testDir.newFile(); File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8); FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH); SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH);
seqRR.initialize(new FileSplit(f)); seqRR.initialize(new FileSplit(f));
List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
List<List<Writable>> exp0 = Arrays.asList(
Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C"))); List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext()); assertFalse(seqRR.hasNext());
seqRR.reset(); seqRR.reset();
assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext()); assertFalse(seqRR.hasNext());
@ -150,13 +135,13 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testPadding() throws Exception { @DisplayName("Test Padding")
@Disabled
for( int i=0; i<3; i++ ) { void testPadding() throws Exception {
for (int i = 0; i < 3; i++) {
String seqSep; String seqSep;
String seqSepRegex; String seqSepRegex;
switch (i) { switch(i) {
case 0: case 0:
seqSep = ""; seqSep = "";
seqSepRegex = "^$"; seqSepRegex = "^$";
@ -172,27 +157,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default: default:
throw new RuntimeException(); throw new RuntimeException();
} }
String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC"; String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC";
File f = testDir.newFile(); File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8); FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD")); SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD"));
seqRR.initialize(new FileSplit(f)); seqRR.initialize(new FileSplit(f));
List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
List<List<Writable>> exp0 = Arrays.asList(
Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C"))); List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext()); assertFalse(seqRR.hasNext());
seqRR.reset(); seqRR.reset();
assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext()); assertFalse(seqRR.hasNext());

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.SequenceRecord;
@ -27,61 +26,53 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Csvn Lines Sequence Record Reader Test")
class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testCSVNLinesSequenceRecordReader() throws Exception { @DisplayName("Test CSVN Lines Sequence Record Reader")
void testCSVNLinesSequenceRecordReader() throws Exception {
int nLinesPerSequence = 10; int nLinesPerSequence = 10;
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence); SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader(); CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int count = 0; int count = 0;
while (seqRR.hasNext()) { while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord(); List<List<Writable>> next = seqRR.sequenceRecord();
List<List<Writable>> expected = new ArrayList<>(); List<List<Writable>> expected = new ArrayList<>();
for (int i = 0; i < nLinesPerSequence; i++) { for (int i = 0; i < nLinesPerSequence; i++) {
expected.add(rr.next()); expected.add(rr.next());
} }
assertEquals(10, next.size()); assertEquals(10, next.size());
assertEquals(expected, next); assertEquals(expected, next);
count++; count++;
} }
assertEquals(150 / nLinesPerSequence, count); assertEquals(150 / nLinesPerSequence, count);
} }
@Test @Test
public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception { @DisplayName("Test CSV Nlines Sequence Record Reader Meta Data")
void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
int nLinesPerSequence = 10; int nLinesPerSequence = 10;
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence); SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader(); CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
List<List<List<Writable>>> out = new ArrayList<>(); List<List<List<Writable>>> out = new ArrayList<>();
while (seqRR.hasNext()) { while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord(); List<List<Writable>> next = seqRR.sequenceRecord();
out.add(next); out.add(next);
} }
seqRR.reset(); seqRR.reset();
List<List<List<Writable>>> out2 = new ArrayList<>(); List<List<List<Writable>>> out2 = new ArrayList<>();
List<SequenceRecord> out3 = new ArrayList<>(); List<SequenceRecord> out3 = new ArrayList<>();
@ -92,11 +83,8 @@ public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
meta.add(seq.getMetaData()); meta.add(seq.getMetaData());
out3.add(seq); out3.add(seq);
} }
assertEquals(out, out2); assertEquals(out, out2);
List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta); List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
assertEquals(out3, out4); assertEquals(out3, out4);
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -34,10 +33,10 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
@ -47,41 +46,44 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Csv Record Reader Test")
class CSVRecordReaderTest extends BaseND4JTest {
public class CSVRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testNext() throws Exception { @DisplayName("Test Next")
void testNext() throws Exception {
CSVRecordReader reader = new CSVRecordReader(); CSVRecordReader reader = new CSVRecordReader();
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1")); reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1"));
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> vals = reader.next(); List<Writable> vals = reader.next();
List<Writable> arr = new ArrayList<>(vals); List<Writable> arr = new ArrayList<>(vals);
assertEquals(23, vals.size(), "Entry count");
assertEquals("Entry count", 23, vals.size());
Text lastEntry = (Text) arr.get(arr.size() - 1); Text lastEntry = (Text) arr.get(arr.size() - 1);
assertEquals("Last entry garbage", 1, lastEntry.getLength()); assertEquals(1, lastEntry.getLength(), "Last entry garbage");
} }
} }
@Test @Test
public void testEmptyEntries() throws Exception { @DisplayName("Test Empty Entries")
void testEmptyEntries() throws Exception {
CSVRecordReader reader = new CSVRecordReader(); CSVRecordReader reader = new CSVRecordReader();
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,")); reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,"));
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> vals = reader.next(); List<Writable> vals = reader.next();
assertEquals("Entry count", 23, vals.size()); assertEquals(23, vals.size(), "Entry count");
} }
} }
@Test @Test
public void testReset() throws Exception { @DisplayName("Test Reset")
void testReset() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ','); CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int nResets = 5; int nResets = 5;
for (int i = 0; i < nResets; i++) { for (int i = 0; i < nResets; i++) {
int lineCount = 0; int lineCount = 0;
while (rr.hasNext()) { while (rr.hasNext()) {
List<Writable> line = rr.next(); List<Writable> line = rr.next();
@ -95,7 +97,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testResetWithSkipLines() throws Exception { @DisplayName("Test Reset With Skip Lines")
void testResetWithSkipLines() throws Exception {
CSVRecordReader rr = new CSVRecordReader(10, ','); CSVRecordReader rr = new CSVRecordReader(10, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int lineCount = 0; int lineCount = 0;
@ -114,7 +117,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testWrite() throws Exception { @DisplayName("Test Write")
void testWrite() throws Exception {
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
@ -130,81 +134,72 @@ public class CSVRecordReaderTest extends BaseND4JTest {
} }
list.add(temp); list.add(temp);
} }
String expected = sb.toString(); String expected = sb.toString();
Path p = Files.createTempFile("csvwritetest", "csv"); Path p = Files.createTempFile("csvwritetest", "csv");
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
FileRecordWriter writer = new CSVRecordWriter(); FileRecordWriter writer = new CSVRecordWriter();
FileSplit fileSplit = new FileSplit(p.toFile()); FileSplit fileSplit = new FileSplit(p.toFile());
writer.initialize(fileSplit,new NumberOfRecordsPartitioner()); writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
for (List<Writable> c : list) { for (List<Writable> c : list) {
writer.write(c); writer.write(c);
} }
writer.close(); writer.close();
// Read file back in; compare
//Read file back in; compare
String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name()); String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name());
// System.out.println(expected); // System.out.println(expected);
// System.out.println("----------"); // System.out.println("----------");
// System.out.println(fileContents); // System.out.println(fileContents);
assertEquals(expected, fileContents); assertEquals(expected, fileContents);
} }
@Test @Test
public void testTabsAsSplit1() throws Exception { @DisplayName("Test Tabs As Split 1")
void testTabsAsSplit1() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, '\t'); CSVRecordReader reader = new CSVRecordReader(0, '\t');
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile())); reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile()));
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> list = new ArrayList<>(reader.next()); List<Writable> list = new ArrayList<>(reader.next());
assertEquals(2, list.size()); assertEquals(2, list.size());
} }
} }
@Test @Test
public void testPipesAsSplit() throws Exception { @DisplayName("Test Pipes As Split")
void testPipesAsSplit() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, '|'); CSVRecordReader reader = new CSVRecordReader(0, '|');
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/issue414.csv").getFile())); reader.initialize(new FileSplit(new ClassPathResource("datavec-api/issue414.csv").getFile()));
int lineidx = 0; int lineidx = 0;
List<Integer> sixthColumn = Arrays.asList(13, 95, 15, 25); List<Integer> sixthColumn = Arrays.asList(13, 95, 15, 25);
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> list = new ArrayList<>(reader.next()); List<Writable> list = new ArrayList<>(reader.next());
assertEquals(10, list.size()); assertEquals(10, list.size());
assertEquals((long)sixthColumn.get(lineidx), list.get(5).toInt()); assertEquals((long) sixthColumn.get(lineidx), list.get(5).toInt());
lineidx++; lineidx++;
} }
} }
@Test @Test
public void testWithQuotes() throws Exception { @DisplayName("Test With Quotes")
void testWithQuotes() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, ',', '\"'); CSVRecordReader reader = new CSVRecordReader(0, ',', '\"');
reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\"")); reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\""));
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> vals = reader.next(); List<Writable> vals = reader.next();
assertEquals("Entry count", 6, vals.size()); assertEquals(6, vals.size(), "Entry count");
assertEquals("1", vals.get(0).toString()); assertEquals(vals.get(0).toString(), "1");
assertEquals("0", vals.get(1).toString()); assertEquals(vals.get(1).toString(), "0");
assertEquals("3", vals.get(2).toString()); assertEquals(vals.get(2).toString(), "3");
assertEquals("Braund, Mr. Owen Harris", vals.get(3).toString()); assertEquals(vals.get(3).toString(), "Braund, Mr. Owen Harris");
assertEquals("male", vals.get(4).toString()); assertEquals(vals.get(4).toString(), "male");
assertEquals("\"", vals.get(5).toString()); assertEquals(vals.get(5).toString(), "\"");
} }
} }
@Test @Test
public void testMeta() throws Exception { @DisplayName("Test Meta")
void testMeta() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ','); CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int lineCount = 0; int lineCount = 0;
List<RecordMetaData> metaList = new ArrayList<>(); List<RecordMetaData> metaList = new ArrayList<>();
List<List<Writable>> writables = new ArrayList<>(); List<List<Writable>> writables = new ArrayList<>();
@ -214,29 +209,24 @@ public class CSVRecordReaderTest extends BaseND4JTest {
lineCount++; lineCount++;
RecordMetaData meta = r.getMetaData(); RecordMetaData meta = r.getMetaData();
// System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI()); // System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
metaList.add(meta); metaList.add(meta);
writables.add(r.getRecord()); writables.add(r.getRecord());
} }
assertFalse(rr.hasNext()); assertFalse(rr.hasNext());
assertEquals(150, lineCount); assertEquals(150, lineCount);
rr.reset(); rr.reset();
System.out.println("\n\n\n--------------------------------"); System.out.println("\n\n\n--------------------------------");
List<Record> contents = rr.loadFromMetaData(metaList); List<Record> contents = rr.loadFromMetaData(metaList);
assertEquals(150, contents.size()); assertEquals(150, contents.size());
// for(Record r : contents ){ // for(Record r : contents ){
// System.out.println(r); // System.out.println(r);
// } // }
List<RecordMetaData> meta2 = new ArrayList<>(); List<RecordMetaData> meta2 = new ArrayList<>();
meta2.add(metaList.get(100)); meta2.add(metaList.get(100));
meta2.add(metaList.get(90)); meta2.add(metaList.get(90));
meta2.add(metaList.get(80)); meta2.add(metaList.get(80));
meta2.add(metaList.get(70)); meta2.add(metaList.get(70));
meta2.add(metaList.get(60)); meta2.add(metaList.get(60));
List<Record> contents2 = rr.loadFromMetaData(meta2); List<Record> contents2 = rr.loadFromMetaData(meta2);
assertEquals(writables.get(100), contents2.get(0).getRecord()); assertEquals(writables.get(100), contents2.get(0).getRecord());
assertEquals(writables.get(90), contents2.get(1).getRecord()); assertEquals(writables.get(90), contents2.get(1).getRecord());
@ -246,50 +236,49 @@ public class CSVRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testRegex() throws Exception { @DisplayName("Test Regex")
CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] {null, "(.+) (.+) (.+)"}); void testRegex() throws Exception {
CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] { null, "(.+) (.+) (.+)" });
reader.initialize(new StringSplit("normal,1.2.3.4 space separator")); reader.initialize(new StringSplit("normal,1.2.3.4 space separator"));
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> vals = reader.next(); List<Writable> vals = reader.next();
assertEquals("Entry count", 4, vals.size()); assertEquals(4, vals.size(), "Entry count");
assertEquals("normal", vals.get(0).toString()); assertEquals(vals.get(0).toString(), "normal");
assertEquals("1.2.3.4", vals.get(1).toString()); assertEquals(vals.get(1).toString(), "1.2.3.4");
assertEquals("space", vals.get(2).toString()); assertEquals(vals.get(2).toString(), "space");
assertEquals("separator", vals.get(3).toString()); assertEquals(vals.get(3).toString(), "separator");
} }
} }
@Test(expected = NoSuchElementException.class) @Test
public void testCsvSkipAllLines() throws IOException, InterruptedException { @DisplayName("Test Csv Skip All Lines")
void testCsvSkipAllLines() {
assertThrows(NoSuchElementException.class, () -> {
final int numLines = 4; final int numLines = 4;
final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
(Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
String header = ",one,two,three"; String header = ",one,two,three";
List<String> lines = new ArrayList<>(); List<String> lines = new ArrayList<>();
for (int i = 0; i < numLines; i++) for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
lines.add(Integer.toString(i) + header);
File tempFile = File.createTempFile("csvSkipLines", ".csv"); File tempFile = File.createTempFile("csvSkipLines", ".csv");
FileUtils.writeLines(tempFile, lines); FileUtils.writeLines(tempFile, lines);
CSVRecordReader rr = new CSVRecordReader(numLines, ','); CSVRecordReader rr = new CSVRecordReader(numLines, ',');
rr.initialize(new FileSplit(tempFile)); rr.initialize(new FileSplit(tempFile));
rr.reset(); rr.reset();
assertTrue(!rr.hasNext()); assertTrue(!rr.hasNext());
rr.next(); rr.next();
});
} }
@Test @Test
public void testCsvSkipAllButOneLine() throws IOException, InterruptedException { @DisplayName("Test Csv Skip All But One Line")
void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
final int numLines = 4; final int numLines = 4;
final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)), final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)), new Text("one"), new Text("two"), new Text("three"));
new Text("one"), new Text("two"), new Text("three"));
String header = ",one,two,three"; String header = ",one,two,three";
List<String> lines = new ArrayList<>(); List<String> lines = new ArrayList<>();
for (int i = 0; i < numLines; i++) for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
lines.add(Integer.toString(i) + header);
File tempFile = File.createTempFile("csvSkipLines", ".csv"); File tempFile = File.createTempFile("csvSkipLines", ".csv");
FileUtils.writeLines(tempFile, lines); FileUtils.writeLines(tempFile, lines);
CSVRecordReader rr = new CSVRecordReader(numLines - 1, ','); CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
rr.initialize(new FileSplit(tempFile)); rr.initialize(new FileSplit(tempFile));
rr.reset(); rr.reset();
@ -297,50 +286,45 @@ public class CSVRecordReaderTest extends BaseND4JTest {
assertEquals(rr.next(), lineList); assertEquals(rr.next(), lineList);
} }
@Test @Test
public void testStreamReset() throws Exception { @DisplayName("Test Stream Reset")
void testStreamReset() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ','); CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream())); rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));
int count = 0; int count = 0;
while(rr.hasNext()){ while (rr.hasNext()) {
assertNotNull(rr.next()); assertNotNull(rr.next());
count++; count++;
} }
assertEquals(150, count); assertEquals(150, count);
assertFalse(rr.resetSupported()); assertFalse(rr.resetSupported());
try {
try{
rr.reset(); rr.reset();
fail("Expected exception"); fail("Expected exception");
} catch (Exception e){ } catch (Exception e) {
String msg = e.getMessage(); String msg = e.getMessage();
String msg2 = e.getCause().getMessage(); String msg2 = e.getCause().getMessage();
assertTrue(msg, msg.contains("Error during LineRecordReader reset")); assertTrue(msg.contains("Error during LineRecordReader reset"),msg);
assertTrue(msg2, msg2.contains("Reset not supported from streams")); assertTrue(msg2.contains("Reset not supported from streams"),msg2);
// e.printStackTrace(); // e.printStackTrace();
} }
} }
@Test @Test
public void testUsefulExceptionNoInit(){ @DisplayName("Test Useful Exception No Init")
void testUsefulExceptionNoInit() {
CSVRecordReader rr = new CSVRecordReader(0, ','); CSVRecordReader rr = new CSVRecordReader(0, ',');
try {
try{
rr.hasNext(); rr.hasNext();
fail("Expected exception"); fail("Expected exception");
} catch (Exception e){ } catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("initialized")); assertTrue( e.getMessage().contains("initialized"),e.getMessage());
} }
try {
try{
rr.next(); rr.next();
fail("Expected exception"); fail("Expected exception");
} catch (Exception e){ } catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("initialized")); assertTrue(e.getMessage().contains("initialized"),e.getMessage());
} }
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.SequenceRecord;
@ -27,12 +26,11 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
@ -41,25 +39,27 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Csv Sequence Record Reader Test")
class CSVSequenceRecordReaderTest extends BaseND4JTest {
public class CSVSequenceRecordReaderTest extends BaseND4JTest { @TempDir
public Path tempDir;
@Rule
public TemporaryFolder tempDir = new TemporaryFolder();
@Test @Test
public void test() throws Exception { @DisplayName("Test")
void test() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ","); CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit()); seqReader.initialize(new TestInputSplit());
int sequenceCount = 0; int sequenceCount = 0;
while (seqReader.hasNext()) { while (seqReader.hasNext()) {
List<List<Writable>> sequence = seqReader.sequenceRecord(); List<List<Writable>> sequence = seqReader.sequenceRecord();
assertEquals(4, sequence.size()); //4 lines, plus 1 header line // 4 lines, plus 1 header line
assertEquals(4, sequence.size());
Iterator<List<Writable>> timeStepIter = sequence.iterator(); Iterator<List<Writable>> timeStepIter = sequence.iterator();
int lineCount = 0; int lineCount = 0;
while (timeStepIter.hasNext()) { while (timeStepIter.hasNext()) {
@ -80,19 +80,18 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testReset() throws Exception { @DisplayName("Test Reset")
void testReset() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ","); CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit()); seqReader.initialize(new TestInputSplit());
int nTests = 5; int nTests = 5;
for (int i = 0; i < nTests; i++) { for (int i = 0; i < nTests; i++) {
seqReader.reset(); seqReader.reset();
int sequenceCount = 0; int sequenceCount = 0;
while (seqReader.hasNext()) { while (seqReader.hasNext()) {
List<List<Writable>> sequence = seqReader.sequenceRecord(); List<List<Writable>> sequence = seqReader.sequenceRecord();
assertEquals(4, sequence.size()); //4 lines, plus 1 header line // 4 lines, plus 1 header line
assertEquals(4, sequence.size());
Iterator<List<Writable>> timeStepIter = sequence.iterator(); Iterator<List<Writable>> timeStepIter = sequence.iterator();
int lineCount = 0; int lineCount = 0;
while (timeStepIter.hasNext()) { while (timeStepIter.hasNext()) {
@ -107,15 +106,15 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testMetaData() throws Exception { @DisplayName("Test Meta Data")
void testMetaData() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ","); CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit()); seqReader.initialize(new TestInputSplit());
List<List<List<Writable>>> l = new ArrayList<>(); List<List<List<Writable>>> l = new ArrayList<>();
while (seqReader.hasNext()) { while (seqReader.hasNext()) {
List<List<Writable>> sequence = seqReader.sequenceRecord(); List<List<Writable>> sequence = seqReader.sequenceRecord();
assertEquals(4, sequence.size()); //4 lines, plus 1 header line // 4 lines, plus 1 header line
assertEquals(4, sequence.size());
Iterator<List<Writable>> timeStepIter = sequence.iterator(); Iterator<List<Writable>> timeStepIter = sequence.iterator();
int lineCount = 0; int lineCount = 0;
while (timeStepIter.hasNext()) { while (timeStepIter.hasNext()) {
@ -123,10 +122,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
lineCount++; lineCount++;
} }
assertEquals(4, lineCount); assertEquals(4, lineCount);
l.add(sequence); l.add(sequence);
} }
List<SequenceRecord> l2 = new ArrayList<>(); List<SequenceRecord> l2 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>();
seqReader.reset(); seqReader.reset();
@ -136,7 +133,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
meta.add(sr.getMetaData()); meta.add(sr.getMetaData());
} }
assertEquals(3, l2.size()); assertEquals(3, l2.size());
List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta); List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta);
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
assertEquals(l.get(i), l2.get(i).getSequenceRecord()); assertEquals(l.get(i), l2.get(i).getSequenceRecord());
@ -144,8 +140,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
} }
} }
private static class @DisplayName("Test Input Split")
TestInputSplit implements InputSplit { private static class TestInputSplit implements InputSplit {
@Override @Override
public boolean canWriteToLocation(URI location) { public boolean canWriteToLocation(URI location) {
@ -164,7 +160,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override @Override
public void updateSplitLocations(boolean reset) { public void updateSplitLocations(boolean reset) {
} }
@Override @Override
@ -174,7 +169,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override @Override
public void bootStrapForWrite() { public void bootStrapForWrite() {
} }
@Override @Override
@ -222,38 +216,30 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override @Override
public void reset() { public void reset() {
//No op // No op
} }
@Override @Override
public boolean resetSupported() { public boolean resetSupported() {
return true; return true;
} }
} }
@Test @Test
public void testCsvSeqAndNumberedFileSplit() throws Exception { @DisplayName("Test Csv Seq And Numbered File Split")
File baseDir = tempDir.newFolder(); void testCsvSeqAndNumberedFileSplit(@TempDir Path tempDir) throws Exception {
//Simple sanity check unit test File baseDir = tempDir.toFile();
// Simple sanity check unit test
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir); new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
} }
// Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
//Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
ClassPathResource resource = new ClassPathResource("csvsequence_0.txt"); ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath(); String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
while (featureReader.hasNext()) {
while(featureReader.hasNext()){
featureReader.nextSequence(); featureReader.nextSequence();
} }
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.SequenceRecordReader;
@ -25,94 +24,87 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader; import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Csv Variable Sliding Window Record Reader Test")
class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testCSVVariableSlidingWindowRecordReader() throws Exception { @DisplayName("Test CSV Variable Sliding Window Record Reader")
void testCSVVariableSlidingWindowRecordReader() throws Exception {
int maxLinesPerSequence = 3; int maxLinesPerSequence = 3;
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence); SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader(); CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int count = 0; int count = 0;
while (seqRR.hasNext()) { while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord(); List<List<Writable>> next = seqRR.sequenceRecord();
if (count == maxLinesPerSequence - 1) {
if(count==maxLinesPerSequence-1) {
LinkedList<List<Writable>> expected = new LinkedList<>(); LinkedList<List<Writable>> expected = new LinkedList<>();
for (int i = 0; i < maxLinesPerSequence; i++) { for (int i = 0; i < maxLinesPerSequence; i++) {
expected.addFirst(rr.next()); expected.addFirst(rr.next());
} }
assertEquals(expected, next); assertEquals(expected, next);
} }
if(count==maxLinesPerSequence) { if (count == maxLinesPerSequence) {
assertEquals(maxLinesPerSequence, next.size()); assertEquals(maxLinesPerSequence, next.size());
} }
if(count==0) { // first seq should be length 1 if (count == 0) {
// first seq should be length 1
assertEquals(1, next.size()); assertEquals(1, next.size());
} }
if(count>151) { // last seq should be length 1 if (count > 151) {
// last seq should be length 1
assertEquals(1, next.size()); assertEquals(1, next.size());
} }
count++; count++;
} }
assertEquals(152, count); assertEquals(152, count);
} }
@Test @Test
public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception { @DisplayName("Test CSV Variable Sliding Window Record Reader Stride")
void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
int maxLinesPerSequence = 3; int maxLinesPerSequence = 3;
int stride = 2; int stride = 2;
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride); SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader(); CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int count = 0; int count = 0;
while (seqRR.hasNext()) { while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord(); List<List<Writable>> next = seqRR.sequenceRecord();
if (count == maxLinesPerSequence - 1) {
if(count==maxLinesPerSequence-1) {
LinkedList<List<Writable>> expected = new LinkedList<>(); LinkedList<List<Writable>> expected = new LinkedList<>();
for(int s = 0; s < stride; s++) { for (int s = 0; s < stride; s++) {
expected = new LinkedList<>(); expected = new LinkedList<>();
for (int i = 0; i < maxLinesPerSequence; i++) { for (int i = 0; i < maxLinesPerSequence; i++) {
expected.addFirst(rr.next()); expected.addFirst(rr.next());
} }
} }
assertEquals(expected, next); assertEquals(expected, next);
} }
if(count==maxLinesPerSequence) { if (count == maxLinesPerSequence) {
assertEquals(maxLinesPerSequence, next.size()); assertEquals(maxLinesPerSequence, next.size());
} }
if(count==0) { // first seq should be length 2 if (count == 0) {
// first seq should be length 2
assertEquals(2, next.size()); assertEquals(2, next.size());
} }
if(count>151) { // last seq should be length 1 if (count > 151) {
// last seq should be length 1
assertEquals(1, next.size()); assertEquals(1, next.size());
} }
count++; count++;
} }
assertEquals(76, count); assertEquals(76, count);
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -28,45 +27,44 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.loader.FileBatch; import org.nd4j.common.loader.FileBatch;
import java.io.File; import java.io.File;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.*; @DisplayName("File Batch Record Reader Test")
public class FileBatchRecordReaderTest extends BaseND4JTest { public class FileBatchRecordReaderTest extends BaseND4JTest {
@TempDir Path testDir;
@Rule @ParameterizedTest
public TemporaryFolder testDir = new TemporaryFolder(); @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Csv")
@Test void testCsv(Nd4jBackend backend) throws Exception {
public void testCsv() throws Exception { // This is an unrealistic use case - one line/record per CSV
File baseDir = testDir.toFile();
//This is an unrealistic use case - one line/record per CSV
File baseDir = testDir.newFolder();
List<File> fileList = new ArrayList<>(); List<File> fileList = new ArrayList<>();
for( int i=0; i<10; i++ ){ for (int i = 0; i < 10; i++) {
String s = "file_" + i + "," + i + "," + i; String s = "file_" + i + "," + i + "," + i;
File f = new File(baseDir, "origFile" + i + ".csv"); File f = new File(baseDir, "origFile" + i + ".csv");
FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8); FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8);
fileList.add(f); fileList.add(f);
} }
FileBatch fb = FileBatch.forFiles(fileList); FileBatch fb = FileBatch.forFiles(fileList);
RecordReader rr = new CSVRecordReader(); RecordReader rr = new CSVRecordReader();
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb); FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
for (int test = 0; test < 3; test++) {
for( int test=0; test<3; test++) {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext()); assertTrue(fbrr.hasNext());
List<Writable> next = fbrr.next(); List<Writable> next = fbrr.next();
@ -82,16 +80,17 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
} }
} }
@Test @ParameterizedTest
public void testCsvSequence() throws Exception { @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
//CSV sequence - 3 lines per file, 10 files @DisplayName("Test Csv Sequence")
File baseDir = testDir.newFolder(); void testCsvSequence(Nd4jBackend backend) throws Exception {
// CSV sequence - 3 lines per file, 10 files
File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>(); List<File> fileList = new ArrayList<>();
for( int i=0; i<10; i++ ){ for (int i = 0; i < 10; i++) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
for( int j=0; j<3; j++ ){ for (int j = 0; j < 3; j++) {
if(j > 0) if (j > 0)
sb.append("\n"); sb.append("\n");
sb.append("file_" + i + "," + i + "," + j); sb.append("file_" + i + "," + i + "," + j);
} }
@ -99,19 +98,16 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8); FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8);
fileList.add(f); fileList.add(f);
} }
FileBatch fb = FileBatch.forFiles(fileList); FileBatch fb = FileBatch.forFiles(fileList);
SequenceRecordReader rr = new CSVSequenceRecordReader(); SequenceRecordReader rr = new CSVSequenceRecordReader();
FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb); FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb);
for (int test = 0; test < 3; test++) {
for( int test=0; test<3; test++) {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext()); assertTrue(fbrr.hasNext());
List<List<Writable>> next = fbrr.sequenceRecord(); List<List<Writable>> next = fbrr.sequenceRecord();
assertEquals(3, next.size()); assertEquals(3, next.size());
int count = 0; int count = 0;
for(List<Writable> step : next ){ for (List<Writable> step : next) {
String s1 = "file_" + i; String s1 = "file_" + i;
assertEquals(s1, step.get(0).toString()); assertEquals(s1, step.get(0).toString());
assertEquals(String.valueOf(i), step.get(1).toString()); assertEquals(String.valueOf(i), step.get(1).toString());
@ -123,5 +119,4 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
fbrr.reset(); fbrr.reset();
} }
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record; import org.datavec.api.records.Record;
@ -26,28 +25,28 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("File Record Reader Test")
import static org.junit.Assert.assertFalse; class FileRecordReaderTest extends BaseND4JTest {
public class FileRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testReset() throws Exception { @DisplayName("Test Reset")
void testReset() throws Exception {
FileRecordReader rr = new FileRecordReader(); FileRecordReader rr = new FileRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int nResets = 5; int nResets = 5;
for (int i = 0; i < nResets; i++) { for (int i = 0; i < nResets; i++) {
int lineCount = 0; int lineCount = 0;
while (rr.hasNext()) { while (rr.hasNext()) {
List<Writable> line = rr.next(); List<Writable> line = rr.next();
@ -61,25 +60,20 @@ public class FileRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testMeta() throws Exception { @DisplayName("Test Meta")
void testMeta() throws Exception {
FileRecordReader rr = new FileRecordReader(); FileRecordReader rr = new FileRecordReader();
URI[] arr = new URI[3]; URI[] arr = new URI[3];
arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI(); arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI();
arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI(); arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI();
arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI(); arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI();
InputSplit is = new CollectionInputSplit(Arrays.asList(arr)); InputSplit is = new CollectionInputSplit(Arrays.asList(arr));
rr.initialize(is); rr.initialize(is);
List<List<Writable>> out = new ArrayList<>(); List<List<Writable>> out = new ArrayList<>();
while (rr.hasNext()) { while (rr.hasNext()) {
out.add(rr.next()); out.add(rr.next());
} }
assertEquals(3, out.size()); assertEquals(3, out.size());
rr.reset(); rr.reset();
List<List<Writable>> out2 = new ArrayList<>(); List<List<Writable>> out2 = new ArrayList<>();
List<Record> out3 = new ArrayList<>(); List<Record> out3 = new ArrayList<>();
@ -90,13 +84,10 @@ public class FileRecordReaderTest extends BaseND4JTest {
out2.add(r.getRecord()); out2.add(r.getRecord());
out3.add(r); out3.add(r);
meta.add(r.getMetaData()); meta.add(r.getMetaData());
assertEquals(arr[count++], r.getMetaData().getURI()); assertEquals(arr[count++], r.getMetaData().getURI());
} }
assertEquals(out, out2); assertEquals(out, out2);
List<Record> fromMeta = rr.loadFromMetaData(meta); List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(out3, fromMeta); assertEquals(out3, fromMeta);
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
@ -28,97 +27,81 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.File; import java.io.File;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Jackson Line Record Reader Test")
class JacksonLineRecordReaderTest extends BaseND4JTest {
public class JacksonLineRecordReaderTest extends BaseND4JTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
public JacksonLineRecordReaderTest() { public JacksonLineRecordReaderTest() {
} }
private static FieldSelection getFieldSelection() { private static FieldSelection getFieldSelection() {
return new FieldSelection.Builder().addField("value1"). return new FieldSelection.Builder().addField("value1").addField("value2").addField("value3").addField("value4").addField("value5").addField("value6").addField("value7").addField("value8").addField("value9").addField("value10").build();
addField("value2").
addField("value3").
addField("value4").
addField("value5").
addField("value6").
addField("value7").
addField("value8").
addField("value9").
addField("value10").build();
} }
@Test @Test
public void testReadJSON() throws Exception { @DisplayName("Test Read JSON")
void testReadJSON() throws Exception {
RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory())); RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile()));
testJacksonRecordReader(rr); testJacksonRecordReader(rr);
} }
private static void testJacksonRecordReader(RecordReader rr) { private static void testJacksonRecordReader(RecordReader rr) {
while (rr.hasNext()) { while (rr.hasNext()) {
List<Writable> json0 = rr.next(); List<Writable> json0 = rr.next();
//System.out.println(json0); // System.out.println(json0);
assert(json0.size() > 0); assert (json0.size() > 0);
} }
} }
@Test @Test
public void testJacksonLineSequenceRecordReader() throws Exception { @DisplayName("Test Jackson Line Sequence Record Reader")
File dir = testDir.newFolder(); void testJacksonLineSequenceRecordReader(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir); new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
.addField(new Text("MISSING_CX"), "c", "x").build();
JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory())); JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
File[] files = dir.listFiles(); File[] files = dir.listFiles();
Arrays.sort(files); Arrays.sort(files);
URI[] u = new URI[files.length]; URI[] u = new URI[files.length];
for( int i=0; i<files.length; i++ ){ for (int i = 0; i < files.length; i++) {
u[i] = files[i].toURI(); u[i] = files[i].toURI();
} }
rr.initialize(new CollectionInputSplit(u)); rr.initialize(new CollectionInputSplit(u));
List<List<Writable>> expSeq0 = new ArrayList<>(); List<List<Writable>> expSeq0 = new ArrayList<>();
expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"))); expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"))); expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"))); expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
List<List<Writable>> expSeq1 = new ArrayList<>(); List<List<Writable>> expSeq1 = new ArrayList<>();
expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3"))); expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
int count = 0; int count = 0;
while(rr.hasNext()){ while (rr.hasNext()) {
List<List<Writable>> next = rr.sequenceRecord(); List<List<Writable>> next = rr.sequenceRecord();
if(count++ == 0){ if (count++ == 0) {
assertEquals(expSeq0, next); assertEquals(expSeq0, next);
} else { } else {
assertEquals(expSeq1, next); assertEquals(expSeq1, next);
} }
} }
assertEquals(2, count); assertEquals(2, count);
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.io.labels.PathLabelGenerator; import org.datavec.api.io.labels.PathLabelGenerator;
@ -31,114 +30,95 @@ import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory; import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import java.io.File; import java.io.File;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Jackson Record Reader Test")
import static org.junit.Assert.assertFalse; class JacksonRecordReaderTest extends BaseND4JTest {
public class JacksonRecordReaderTest extends BaseND4JTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testReadingJson() throws Exception { @DisplayName("Test Reading Json")
//Load 3 values from 3 JSON files void testReadingJson(@TempDir Path testDir) throws Exception {
//stricture: a:value, b:value, c:x:value, c:y:value // Load 3 values from 3 JSON files
//And we want to load only a:value, b:value and c:x:value // stricture: a:value, b:value, c:x:value, c:y:value
//For first JSON file: all values are present // And we want to load only a:value, b:value and c:x:value
//For second JSON file: b:value is missing // For first JSON file: all values are present
//For third JSON file: c:x:value is missing // For second JSON file: b:value is missing
// For third JSON file: c:x:value is missing
ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath(); String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory())); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(is); rr.initialize(is);
testJacksonRecordReader(rr); testJacksonRecordReader(rr);
} }
@Test @Test
public void testReadingYaml() throws Exception { @DisplayName("Test Reading Yaml")
//Exact same information as JSON format, but in YAML format void testReadingYaml(@TempDir Path testDir) throws Exception {
// Exact same information as JSON format, but in YAML format
ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/"); ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "yaml_test_%d.txt").getAbsolutePath(); String path = new File(f, "yaml_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory())); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory()));
rr.initialize(is); rr.initialize(is);
testJacksonRecordReader(rr); testJacksonRecordReader(rr);
} }
@Test @Test
public void testReadingXml() throws Exception { @DisplayName("Test Reading Xml")
//Exact same information as JSON format, but in XML format void testReadingXml(@TempDir Path testDir) throws Exception {
// Exact same information as JSON format, but in XML format
ClassPathResource cpr = new ClassPathResource("datavec-api/xml/"); ClassPathResource cpr = new ClassPathResource("datavec-api/xml/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "xml_test_%d.txt").getAbsolutePath(); String path = new File(f, "xml_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory())); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory()));
rr.initialize(is); rr.initialize(is);
testJacksonRecordReader(rr); testJacksonRecordReader(rr);
} }
private static FieldSelection getFieldSelection() { private static FieldSelection getFieldSelection() {
return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b") return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
.addField(new Text("MISSING_CX"), "c", "x").build();
} }
private static void testJacksonRecordReader(RecordReader rr) { private static void testJacksonRecordReader(RecordReader rr) {
List<Writable> json0 = rr.next(); List<Writable> json0 = rr.next();
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")); List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
assertEquals(exp0, json0); assertEquals(exp0, json0);
List<Writable> json1 = rr.next(); List<Writable> json1 = rr.next();
List<Writable> exp1 = List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
assertEquals(exp1, json1); assertEquals(exp1, json1);
List<Writable> json2 = rr.next(); List<Writable> json2 = rr.next();
List<Writable> exp2 = List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
assertEquals(exp2, json2); assertEquals(exp2, json2);
assertFalse(rr.hasNext()); assertFalse(rr.hasNext());
// Test reset
//Test reset
rr.reset(); rr.reset();
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
@ -147,72 +127,50 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testAppendingLabels() throws Exception { @DisplayName("Test Appending Labels")
void testAppendingLabels(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath(); String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
// Insert at the end:
//Insert at the end: RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen());
rr.initialize(is); rr.initialize(is);
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0));
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"),
new IntWritable(0));
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1));
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"),
new IntWritable(1));
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2));
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"),
new IntWritable(2));
assertEquals(exp2, rr.next()); assertEquals(exp2, rr.next());
// Insert at position 0:
//Insert at position 0: rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0);
rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen(), 0);
rr.initialize(is); rr.initialize(is);
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"),
new Text("cxValue0"));
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"),
new Text("cxValue1"));
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"),
new Text("MISSING_CX"));
assertEquals(exp2, rr.next()); assertEquals(exp2, rr.next());
} }
@Test @Test
public void testAppendingLabelsMetaData() throws Exception { @DisplayName("Test Appending Labels Meta Data")
void testAppendingLabelsMetaData(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath(); String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
// Insert at the end:
//Insert at the end: RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen());
rr.initialize(is); rr.initialize(is);
List<List<Writable>> out = new ArrayList<>(); List<List<Writable>> out = new ArrayList<>();
while (rr.hasNext()) { while (rr.hasNext()) {
out.add(rr.next()); out.add(rr.next());
} }
assertEquals(3, out.size()); assertEquals(3, out.size());
rr.reset(); rr.reset();
List<List<Writable>> out2 = new ArrayList<>(); List<List<Writable>> out2 = new ArrayList<>();
List<Record> outRecord = new ArrayList<>(); List<Record> outRecord = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>();
@ -222,14 +180,12 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
outRecord.add(r); outRecord.add(r);
meta.add(r.getMetaData()); meta.add(r.getMetaData());
} }
assertEquals(out, out2); assertEquals(out, out2);
List<Record> fromMeta = rr.loadFromMetaData(meta); List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(outRecord, fromMeta); assertEquals(outRecord, fromMeta);
} }
@DisplayName("Label Gen")
private static class LabelGen implements PathLabelGenerator { private static class LabelGen implements PathLabelGenerator {
@Override @Override
@ -252,5 +208,4 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
return true; return true;
} }
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*; import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class LibSvmRecordReaderTest extends BaseND4JTest { @DisplayName("Lib Svm Record Reader Test")
class LibSvmRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testBasicRecord() throws IOException, InterruptedException { @DisplayName("Test Basic Record")
void testBasicRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5 // 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2)));
// 33 // 33
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33)));
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -80,27 +66,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNoAppendLabel() throws IOException, InterruptedException { @DisplayName("Test No Append Label")
void testNoAppendLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5 // 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 33 // 33
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -117,33 +91,17 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNoLabel() throws IOException, InterruptedException { @DisplayName("Test No Label")
void testNoLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 2:1 4:2 6:3 8:4 10:5 // 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80 // qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 1:1.0 // 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
// //
correct.put(3, Arrays.asList(ZERO, ZERO, correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -160,33 +118,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testMultioutputRecord() throws IOException, InterruptedException { @DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5 // 7 2.45,9 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7), new DoubleWritable(2.45),
new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80 // 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2), new IntWritable(3),
new IntWritable(4)));
// 33,32.0,31.9 // 33,32.0,31.9
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33), new DoubleWritable(32.0),
new DoubleWritable(31.9)));
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -202,51 +142,20 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size()); assertEquals(i, correct.size());
} }
@Test @Test
public void testMultilabelRecord() throws IOException, InterruptedException { @DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5 // 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
// 1,2,4 // 1,2,4
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0 // 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
// //
correct.put(4, Arrays.asList(ZERO, ZERO, correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -265,63 +174,24 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testZeroBasedIndexing() throws IOException, InterruptedException { @DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5 // 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ZERO,
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(ZERO, correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
// 1,2,4 // 1,2,4
correct.put(2, Arrays.asList(ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0 // 1:1.0
correct.put(3, Arrays.asList(ZERO, correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
// //
correct.put(4, Arrays.asList(ZERO, correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
// Zero-based indexing is default // Zero-based indexing is default
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! // NOT STANDARD!
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11); config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true); config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
@ -336,58 +206,71 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size()); assertEquals(i, correct.size());
} }
@Test(expected = NoSuchElementException.class) @Test
public void testNoSuchElementException() throws Exception { @DisplayName("Test No Such Element Exception")
void testNoSuchElementException() {
assertThrows(NoSuchElementException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11); config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) while (rr.hasNext()) rr.next();
rr.next();
rr.next(); rr.next();
});
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void failedToSetNumFeaturesException() throws Exception { @DisplayName("Failed To Set Num Features Exception")
void failedToSetNumFeaturesException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) while (rr.hasNext()) rr.next();
rr.next(); });
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void testInconsistentNumLabelsException() throws Exception { @DisplayName("Test Inconsistent Num Labels Exception")
void testInconsistentNumLabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext()) while (rr.hasNext()) rr.next();
rr.next(); });
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void testInconsistentNumMultiabelsException() throws Exception { @DisplayName("Test Inconsistent Num Multiabels Exception")
void testInconsistentNumMultiabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.MULTILABEL, false); config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext()) while (rr.hasNext()) rr.next();
rr.next(); });
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testFeatureIndexExceedsNumFeatures() throws Exception { @DisplayName("Test Feature Index Exceeds Num Features")
void testFeatureIndexExceedsNumFeatures() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9); config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testLabelIndexExceedsNumLabels() throws Exception { @DisplayName("Test Label Index Exceeds Num Labels")
void testLabelIndexExceedsNumLabels() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
@ -395,10 +278,13 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
config.setInt(LibSvmRecordReader.NUM_LABELS, 6); config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception { @DisplayName("Test Zero Index Feature Without Using Zero Indexing")
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -406,10 +292,13 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10); config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception { @DisplayName("Test Zero Index Label Without Using Zero Indexing")
void testZeroIndexLabelWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
@ -418,5 +307,6 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
config.setInt(LibSvmRecordReader.NUM_LABELS, 2); config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next(); rr.next();
});
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -30,11 +29,10 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.tests.BaseND4JTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.FileOutputStream; import java.io.FileOutputStream;
@ -45,34 +43,31 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.zip.GZIPInputStream; import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream; import java.util.zip.GZIPOutputStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Line Reader Test")
class LineReaderTest extends BaseND4JTest {
public class LineReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testLineReader() throws Exception { @DisplayName("Test Line Reader")
File tmpdir = testDir.newFolder(); void testLineReader(@TempDir Path tmpDir) throws Exception {
File tmpdir = tmpDir.toFile();
if (tmpdir.exists()) if (tmpdir.exists())
tmpdir.delete(); tmpdir.delete();
tmpdir.mkdir(); tmpdir.mkdir();
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt")); File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt")); File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt")); File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3")); FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6")); FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9")); FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
InputSplit split = new FileSplit(tmpdir); InputSplit split = new FileSplit(tmpdir);
RecordReader reader = new LineRecordReader(); RecordReader reader = new LineRecordReader();
reader.initialize(split); reader.initialize(split);
int count = 0; int count = 0;
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
while (reader.hasNext()) { while (reader.hasNext()) {
@ -81,34 +76,27 @@ public class LineReaderTest extends BaseND4JTest {
list.add(l); list.add(l);
count++; count++;
} }
assertEquals(9, count); assertEquals(9, count);
} }
@Test @Test
public void testLineReaderMetaData() throws Exception { @DisplayName("Test Line Reader Meta Data")
File tmpdir = testDir.newFolder(); void testLineReaderMetaData(@TempDir Path tmpDir) throws Exception {
File tmpdir = tmpDir.toFile();
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt")); File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt")); File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt")); File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3")); FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6")); FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9")); FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
InputSplit split = new FileSplit(tmpdir); InputSplit split = new FileSplit(tmpdir);
RecordReader reader = new LineRecordReader(); RecordReader reader = new LineRecordReader();
reader.initialize(split); reader.initialize(split);
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
while (reader.hasNext()) { while (reader.hasNext()) {
list.add(reader.next()); list.add(reader.next());
} }
assertEquals(9, list.size()); assertEquals(9, list.size());
List<List<Writable>> out2 = new ArrayList<>(); List<List<Writable>> out2 = new ArrayList<>();
List<Record> out3 = new ArrayList<>(); List<Record> out3 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>();
@ -124,13 +112,10 @@ public class LineReaderTest extends BaseND4JTest {
assertEquals(uri, split.locations()[fileIdx]); assertEquals(uri, split.locations()[fileIdx]);
count++; count++;
} }
assertEquals(list, out2); assertEquals(list, out2);
List<Record> fromMeta = reader.loadFromMetaData(meta); List<Record> fromMeta = reader.loadFromMetaData(meta);
assertEquals(out3, fromMeta); assertEquals(out3, fromMeta);
// try: second line of second and third files only...
//try: second line of second and third files only...
List<RecordMetaData> subsetMeta = new ArrayList<>(); List<RecordMetaData> subsetMeta = new ArrayList<>();
subsetMeta.add(meta.get(4)); subsetMeta.add(meta.get(4));
subsetMeta.add(meta.get(7)); subsetMeta.add(meta.get(7));
@ -141,27 +126,22 @@ public class LineReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testLineReaderWithInputStreamInputSplit() throws Exception { @DisplayName("Test Line Reader With Input Stream Input Split")
File tmpdir = testDir.newFolder(); void testLineReaderWithInputStreamInputSplit(@TempDir Path testDir) throws Exception {
File tmpdir = testDir.toFile();
File tmp1 = new File(tmpdir, "tmp1.txt.gz"); File tmp1 = new File(tmpdir, "tmp1.txt.gz");
OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false)); OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os); IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
os.flush(); os.flush();
os.close(); os.close();
InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1))); InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
RecordReader reader = new LineRecordReader(); RecordReader reader = new LineRecordReader();
reader.initialize(split); reader.initialize(split);
int count = 0; int count = 0;
while (reader.hasNext()) { while (reader.hasNext()) {
assertEquals(1, reader.next().size()); assertEquals(1, reader.next().size());
count++; count++;
} }
assertEquals(9, count); assertEquals(9, count);
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record; import org.datavec.api.records.Record;
@ -33,44 +32,41 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Regex Record Reader Test")
import static org.junit.Assert.assertFalse; class RegexRecordReaderTest extends BaseND4JTest {
public class RegexRecordReaderTest extends BaseND4JTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testRegexLineRecordReader() throws Exception { @DisplayName("Test Regex Line Record Reader")
void testRegexLineRecordReader() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
RecordReader rr = new RegexLineRecordReader(regex, 1); RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"));
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"));
new Text("DEBUG"), new Text("First entry message!")); List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"));
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"),
new Text("INFO"), new Text("Second entry message!"));
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"),
new Text("WARN"), new Text("Third entry message!"));
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
assertEquals(exp2, rr.next()); assertEquals(exp2, rr.next());
assertFalse(rr.hasNext()); assertFalse(rr.hasNext());
// Test reset:
//Test reset:
rr.reset(); rr.reset();
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
@ -79,74 +75,57 @@ public class RegexRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testRegexLineRecordReaderMeta() throws Exception { @DisplayName("Test Regex Line Record Reader Meta")
void testRegexLineRecordReaderMeta() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
RecordReader rr = new RegexLineRecordReader(regex, 1); RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
while (rr.hasNext()) { while (rr.hasNext()) {
list.add(rr.next()); list.add(rr.next());
} }
assertEquals(3, list.size()); assertEquals(3, list.size());
List<Record> list2 = new ArrayList<>(); List<Record> list2 = new ArrayList<>();
List<List<Writable>> list3 = new ArrayList<>(); List<List<Writable>> list3 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>();
rr.reset(); rr.reset();
int count = 1; //Start by skipping 1 line // Start by skipping 1 line
int count = 1;
while (rr.hasNext()) { while (rr.hasNext()) {
Record r = rr.nextRecord(); Record r = rr.nextRecord();
list2.add(r); list2.add(r);
list3.add(r.getRecord()); list3.add(r.getRecord());
meta.add(r.getMetaData()); meta.add(r.getMetaData());
assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber()); assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber());
} }
List<Record> fromMeta = rr.loadFromMetaData(meta); List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(list, list3); assertEquals(list, list3);
assertEquals(list2, fromMeta); assertEquals(list2, fromMeta);
} }
@Test @Test
public void testRegexSequenceRecordReader() throws Exception { @DisplayName("Test Regex Sequence Record Reader")
void testRegexSequenceRecordReader(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/"); ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath(); String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 1); InputSplit is = new NumberedFileInputSplit(path, 0, 1);
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1); SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is); rr.initialize(is);
List<List<Writable>> exp0 = new ArrayList<>(); List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!")));
new Text("First entry message!"))); exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!")));
new Text("Second entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"),
new Text("Third entry message!")));
List<List<Writable>> exp1 = new ArrayList<>(); List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!")));
new Text("First entry message!"))); exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!")));
new Text("Second entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"),
new Text("Third entry message!")));
assertEquals(exp0, rr.sequenceRecord()); assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord()); assertEquals(exp1, rr.sequenceRecord());
assertFalse(rr.hasNext()); assertFalse(rr.hasNext());
// Test resetting:
//Test resetting:
rr.reset(); rr.reset();
assertEquals(exp0, rr.sequenceRecord()); assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord()); assertEquals(exp1, rr.sequenceRecord());
@ -154,24 +133,20 @@ public class RegexRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testRegexSequenceRecordReaderMeta() throws Exception { @DisplayName("Test Regex Sequence Record Reader Meta")
void testRegexSequenceRecordReaderMeta(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/"); ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath(); String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 1); InputSplit is = new NumberedFileInputSplit(path, 0, 1);
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1); SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is); rr.initialize(is);
List<List<List<Writable>>> out = new ArrayList<>(); List<List<List<Writable>>> out = new ArrayList<>();
while (rr.hasNext()) { while (rr.hasNext()) {
out.add(rr.sequenceRecord()); out.add(rr.sequenceRecord());
} }
assertEquals(2, out.size()); assertEquals(2, out.size());
List<List<List<Writable>>> out2 = new ArrayList<>(); List<List<List<Writable>>> out2 = new ArrayList<>();
List<SequenceRecord> out3 = new ArrayList<>(); List<SequenceRecord> out3 = new ArrayList<>();
@ -183,11 +158,8 @@ public class RegexRecordReaderTest extends BaseND4JTest {
out3.add(seqr); out3.add(seqr);
meta.add(seqr.getMetaData()); meta.add(seqr.getMetaData());
} }
List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta); List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);
assertEquals(out, out2); assertEquals(out, out2);
assertEquals(out3, fromMeta); assertEquals(out3, fromMeta);
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*; import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class SVMLightRecordReaderTest extends BaseND4JTest { @DisplayName("Svm Light Record Reader Test")
class SVMLightRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testBasicRecord() throws IOException, InterruptedException { @DisplayName("Test Basic Record")
void testBasicRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5 // 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2)));
// 33 // 33
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33)));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -79,27 +65,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNoAppendLabel() throws IOException, InterruptedException { @DisplayName("Test No Append Label")
void testNoAppendLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5 // 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 33 // 33
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -116,33 +90,17 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNoLabel() throws IOException, InterruptedException { @DisplayName("Test No Label")
void testNoLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 2:1 4:2 6:3 8:4 10:5 // 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80 // qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 1:1.0 // 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
// //
correct.put(3, Arrays.asList(ZERO, ZERO, correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -159,33 +117,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testMultioutputRecord() throws IOException, InterruptedException { @DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5 // 7 2.45,9 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7), new DoubleWritable(2.45),
new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80 // 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2), new IntWritable(3),
new IntWritable(4)));
// 33,32.0,31.9 // 33,32.0,31.9
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33), new DoubleWritable(32.0),
new DoubleWritable(31.9)));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -200,51 +140,20 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size()); assertEquals(i, correct.size());
} }
@Test @Test
public void testMultilabelRecord() throws IOException, InterruptedException { @DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5 // 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
// 1,2,4 // 1,2,4
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0 // 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
// //
correct.put(4, Arrays.asList(ZERO, ZERO, correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -262,63 +171,24 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testZeroBasedIndexing() throws IOException, InterruptedException { @DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5 // 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ZERO,
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(ZERO, correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
// 1,2,4 // 1,2,4
correct.put(2, Arrays.asList(ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0 // 1:1.0
correct.put(3, Arrays.asList(ZERO, correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
// //
correct.put(4, Arrays.asList(ZERO, correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
// Zero-based indexing is default // Zero-based indexing is default
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! // NOT STANDARD!
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11); config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true); config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 5); config.setInt(SVMLightRecordReader.NUM_LABELS, 5);
@ -333,20 +203,19 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNextRecord() throws IOException, InterruptedException { @DisplayName("Test Next Record")
void testNextRecord() throws IOException, InterruptedException {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false); config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
Record record = rr.nextRecord(); Record record = rr.nextRecord();
List<Writable> recordList = record.getRecord(); List<Writable> recordList = record.getRecord();
assertEquals(new DoubleWritable(1.0), recordList.get(1)); assertEquals(new DoubleWritable(1.0), recordList.get(1));
assertEquals(new DoubleWritable(3.0), recordList.get(5)); assertEquals(new DoubleWritable(3.0), recordList.get(5));
assertEquals(new DoubleWritable(4.0), recordList.get(7)); assertEquals(new DoubleWritable(4.0), recordList.get(7));
record = rr.nextRecord(); record = rr.nextRecord();
recordList = record.getRecord(); recordList = record.getRecord();
assertEquals(new DoubleWritable(0.1), recordList.get(0)); assertEquals(new DoubleWritable(0.1), recordList.get(0));
@ -354,76 +223,95 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(new DoubleWritable(80.0), recordList.get(7)); assertEquals(new DoubleWritable(80.0), recordList.get(7));
} }
@Test(expected = NoSuchElementException.class) @Test
public void testNoSuchElementException() throws Exception { @DisplayName("Test No Such Element Exception")
void testNoSuchElementException() {
assertThrows(NoSuchElementException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11); config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) while (rr.hasNext()) rr.next();
rr.next();
rr.next(); rr.next();
});
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void failedToSetNumFeaturesException() throws Exception { @DisplayName("Failed To Set Num Features Exception")
void failedToSetNumFeaturesException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) while (rr.hasNext()) rr.next();
rr.next(); });
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void testInconsistentNumLabelsException() throws Exception { @DisplayName("Test Inconsistent Num Labels Exception")
void testInconsistentNumLabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext()) while (rr.hasNext()) rr.next();
rr.next(); });
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void failedToSetNumMultiabelsException() throws Exception { @DisplayName("Failed To Set Num Multiabels Exception")
void failedToSetNumMultiabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext()) while (rr.hasNext()) rr.next();
rr.next(); });
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testFeatureIndexExceedsNumFeatures() throws Exception { @DisplayName("Test Feature Index Exceeds Num Features")
void testFeatureIndexExceedsNumFeatures() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9); config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testLabelIndexExceedsNumLabels() throws Exception { @DisplayName("Test Label Index Exceeds Num Labels")
void testLabelIndexExceedsNumLabels() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setInt(SVMLightRecordReader.NUM_LABELS, 6); config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception { @DisplayName("Test Zero Index Feature Without Using Zero Indexing")
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception { @DisplayName("Test Zero Index Label Without Using Zero Indexing")
void testZeroIndexLabelWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
@ -431,5 +319,6 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
config.setInt(SVMLightRecordReader.NUM_LABELS, 2); config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next(); rr.next();
});
} }
} }

View File

@ -26,14 +26,14 @@ import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader; import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestCollectionRecordReaders extends BaseND4JTest { public class TestCollectionRecordReaders extends BaseND4JTest {

View File

@ -23,11 +23,11 @@ package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestConcatenatingRecordReader extends BaseND4JTest { public class TestConcatenatingRecordReader extends BaseND4JTest {

View File

@ -37,7 +37,7 @@ import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
@ -47,7 +47,7 @@ import java.io.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSerialization extends BaseND4JTest { public class TestSerialization extends BaseND4JTest {

View File

@ -30,7 +30,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -38,8 +38,8 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TransformProcessRecordReaderTests extends BaseND4JTest { public class TransformProcessRecordReaderTests extends BaseND4JTest {

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.writer.impl; package org.datavec.api.records.writer.impl;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@ -26,44 +25,42 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Csv Record Writer Test")
class CSVRecordWriterTest extends BaseND4JTest {
public class CSVRecordWriterTest extends BaseND4JTest {
@Before
public void setUp() throws Exception {
@BeforeEach
void setUp() throws Exception {
} }
@Test @Test
public void testWrite() throws Exception { @DisplayName("Test Write")
void testWrite() throws Exception {
File tempFile = File.createTempFile("datavec", "writer"); File tempFile = File.createTempFile("datavec", "writer");
tempFile.deleteOnExit(); tempFile.deleteOnExit();
FileSplit fileSplit = new FileSplit(tempFile); FileSplit fileSplit = new FileSplit(tempFile);
CSVRecordWriter writer = new CSVRecordWriter(); CSVRecordWriter writer = new CSVRecordWriter();
writer.initialize(fileSplit,new NumberOfRecordsPartitioner()); writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
List<Writable> collection = new ArrayList<>(); List<Writable> collection = new ArrayList<>();
collection.add(new Text("12")); collection.add(new Text("12"));
collection.add(new Text("13")); collection.add(new Text("13"));
collection.add(new Text("14")); collection.add(new Text("14"));
writer.write(collection); writer.write(collection);
CSVRecordReader reader = new CSVRecordReader(0); CSVRecordReader reader = new CSVRecordReader(0);
reader.initialize(new FileSplit(tempFile)); reader.initialize(new FileSplit(tempFile));
int cnt = 0; int cnt = 0;
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> line = new ArrayList<>(reader.next()); List<Writable> line = new ArrayList<>(reader.next());
assertEquals(3, line.size()); assertEquals(3, line.size());
assertEquals(12, line.get(0).toInt()); assertEquals(12, line.get(0).toInt());
assertEquals(13, line.get(1).toInt()); assertEquals(13, line.get(1).toInt());
assertEquals(14, line.get(2).toInt()); assertEquals(14, line.get(2).toInt());

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.writer.impl; package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -30,93 +29,90 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertEquals; @DisplayName("Lib Svm Record Writer Test")
class LibSvmRecordWriterTest extends BaseND4JTest {
public class LibSvmRecordWriterTest extends BaseND4JTest {
@Test @Test
public void testBasic() throws Exception { @DisplayName("Test Basic")
void testBasic() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testNoLabel() throws Exception { @DisplayName("Test No Label")
void testNoLabel() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testMultioutputRecord() throws Exception { @DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testMultilabelRecord() throws Exception { @DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true); configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4); configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testZeroBasedIndexing() throws Exception { @DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true); configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5); configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@ -127,10 +123,9 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
rr.initialize(configReader, new FileSplit(inputFile)); rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) { while (rr.hasNext()) {
@ -138,7 +133,6 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
writer.write(record); writer.write(record);
} }
} }
Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX)); Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX));
List<String> linesOriginal = new ArrayList<>(); List<String> linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) { for (String line : FileUtils.readLines(inputFile)) {
@ -159,7 +153,8 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
} }
@Test @Test
public void testNDArrayWritables() throws Exception { @DisplayName("Test ND Array Writables")
void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@ -167,35 +162,28 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13); arr3.putScalar(0, 13);
arr3.putScalar(1, 14); arr3.putScalar(1, 14);
arr3.putScalar(2, 15); arr3.putScalar(2, 15);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new IntWritable(4));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNDArrayWritablesMultilabel() throws Exception { @DisplayName("Test ND Array Writables Multilabel")
void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@ -203,36 +191,29 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0); arr3.putScalar(0, 0);
arr3.putScalar(1, 1); arr3.putScalar(1, 1);
arr3.putScalar(2, 0); arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNDArrayWritablesZeroIndex() throws Exception { @DisplayName("Test ND Array Writables Zero Index")
void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@ -240,99 +221,91 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0); arr3.putScalar(0, 0);
arr3.putScalar(1, 1); arr3.putScalar(1, 1);
arr3.putScalar(2, 0); arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0"; String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD! // NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
// NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNonIntegerButValidMultilabel() throws Exception { @DisplayName("Test Non Integer But Valid Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), void testNonIntegerButValidMultilabel() throws Exception {
new IntWritable(2), List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
new DoubleWritable(1.0));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
} }
@Test(expected = NumberFormatException.class) @Test
public void nonIntegerMultilabel() throws Exception { @DisplayName("Non Integer Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), void nonIntegerMultilabel() {
new IntWritable(2), assertThrows(NumberFormatException.class, () -> {
new DoubleWritable(1.2)); List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
});
} }
@Test(expected = NumberFormatException.class) @Test
public void nonBinaryMultilabel() throws Exception { @DisplayName("Non Binary Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), void nonBinaryMultilabel() {
new IntWritable(1), assertThrows(NumberFormatException.class, () -> {
new IntWritable(2)); List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN,0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN,1); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL,true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
});
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.writer.impl; package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -27,93 +26,90 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertEquals; @DisplayName("Svm Light Record Writer Test")
class SVMLightRecordWriterTest extends BaseND4JTest {
public class SVMLightRecordWriterTest extends BaseND4JTest {
@Test @Test
public void testBasic() throws Exception { @DisplayName("Test Basic")
void testBasic() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testNoLabel() throws Exception { @DisplayName("Test No Label")
void testNoLabel() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testMultioutputRecord() throws Exception { @DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testMultilabelRecord() throws Exception { @DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true); configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4); configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testZeroBasedIndexing() throws Exception { @DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true); configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5); configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@ -124,10 +120,9 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
rr.initialize(configReader, new FileSplit(inputFile)); rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) { while (rr.hasNext()) {
@ -135,7 +130,6 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
writer.write(record); writer.write(record);
} }
} }
Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX)); Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX));
List<String> linesOriginal = new ArrayList<>(); List<String> linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) { for (String line : FileUtils.readLines(inputFile)) {
@ -156,7 +150,8 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
} }
@Test @Test
public void testNDArrayWritables() throws Exception { @DisplayName("Test ND Array Writables")
void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@ -164,35 +159,28 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13); arr3.putScalar(0, 13);
arr3.putScalar(1, 14); arr3.putScalar(1, 14);
arr3.putScalar(2, 15); arr3.putScalar(2, 15);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new IntWritable(4));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNDArrayWritablesMultilabel() throws Exception { @DisplayName("Test ND Array Writables Multilabel")
void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@ -200,36 +188,29 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0); arr3.putScalar(0, 0);
arr3.putScalar(1, 1); arr3.putScalar(1, 1);
arr3.putScalar(2, 0); arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNDArrayWritablesZeroIndex() throws Exception { @DisplayName("Test ND Array Writables Zero Index")
void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@ -237,99 +218,91 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0); arr3.putScalar(0, 0);
arr3.putScalar(1, 1); arr3.putScalar(1, 1);
arr3.putScalar(2, 0); arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0"; String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD! // NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
// NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNonIntegerButValidMultilabel() throws Exception { @DisplayName("Test Non Integer But Valid Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), void testNonIntegerButValidMultilabel() throws Exception {
new IntWritable(2), List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
new DoubleWritable(1.0));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
} }
@Test(expected = NumberFormatException.class) @Test
public void nonIntegerMultilabel() throws Exception { @DisplayName("Non Integer Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), void nonIntegerMultilabel() {
new IntWritable(2), assertThrows(NumberFormatException.class, () -> {
new DoubleWritable(1.2)); List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
});
} }
@Test(expected = NumberFormatException.class) @Test
public void nonBinaryMultilabel() throws Exception { @DisplayName("Non Binary Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), void nonBinaryMultilabel() {
new IntWritable(1), assertThrows(NumberFormatException.class, () -> {
new IntWritable(2)); List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
});
} }
} }

View File

@ -26,7 +26,7 @@ import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter; import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.io.labels.PatternPathLabelGenerator; import org.datavec.api.io.labels.PatternPathLabelGenerator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.io.*; import java.io.*;
import java.net.URI; import java.net.URI;
@ -34,8 +34,9 @@ import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Random; import java.util.Random;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* *

View File

@ -20,13 +20,12 @@
package org.datavec.api.split; package org.datavec.api.split;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI; import java.net.URI;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.Assert.assertTrue;
public class NumberedFileInputSplitTests extends BaseND4JTest { public class NumberedFileInputSplitTests extends BaseND4JTest {
@Test @Test
@ -69,60 +68,81 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithLeadingSpaces() { public void testNumberedFileInputSplitWithLeadingSpaces() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix-%5d.suffix"; String baseString = "/path/to/files/prefix-%5d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() { public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() {
assertThrows(IllegalArgumentException.class, () -> {
String baseString = "/path/to/files/prefix%5d.suffix"; String baseString = "/path/to/files/prefix%5d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithLeadingPlusInPadding() { public void testNumberedFileInputSplitWithLeadingPlusInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%+5d.suffix"; String baseString = "/path/to/files/prefix%+5d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithLeadingMinusInPadding() { public void testNumberedFileInputSplitWithLeadingMinusInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%-5d.suffix"; String baseString = "/path/to/files/prefix%-5d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithTwoDigitsInPadding() { public void testNumberedFileInputSplitWithTwoDigitsInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%011d.suffix"; String baseString = "/path/to/files/prefix%011d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithInnerZerosInPadding() { public void testNumberedFileInputSplitWithInnerZerosInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%101d.suffix"; String baseString = "/path/to/files/prefix%101d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() { public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%0505d.suffix"; String baseString = "/path/to/files/prefix%0505d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@ -135,7 +155,7 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
String path = locs[j++].getPath(); String path = locs[j++].getPath();
String exp = String.format(baseString, i); String exp = String.format(baseString, i);
String msg = exp + " vs " + path; String msg = exp + " vs " + path;
assertTrue(msg, path.endsWith(exp)); //Note: on Windows, Java can prepend drive to path - "/C:/" assertTrue(path.endsWith(exp),msg); //Note: on Windows, Java can prepend drive to path - "/C:/"
} }
} }
} }

View File

@ -25,9 +25,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.function.Function; import org.nd4j.common.function.Function;
@ -37,22 +38,22 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.URI; import java.net.URI;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
public class TestStreamInputSplit extends BaseND4JTest { public class TestStreamInputSplit extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testCsvSimple() throws Exception { public void testCsvSimple(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder(); File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt"); File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt"); File f2 = new File(dir, "file2.txt");
@ -93,9 +94,9 @@ public class TestStreamInputSplit extends BaseND4JTest {
@Test @Test
public void testCsvSequenceSimple() throws Exception { public void testCsvSequenceSimple(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder(); File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt"); File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt"); File f2 = new File(dir, "file2.txt");
@ -137,8 +138,8 @@ public class TestStreamInputSplit extends BaseND4JTest {
} }
@Test @Test
public void testShuffle() throws Exception { public void testShuffle(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder(); File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt"); File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt"); File f2 = new File(dir, "file2.txt");
File f3 = new File(dir, "file3.txt"); File f3 = new File(dir, "file3.txt");

View File

@ -17,44 +17,43 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.split; package org.datavec.api.split;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.Collection; import java.util.Collection;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static org.junit.Assert.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/** /**
* @author Ede Meijer * @author Ede Meijer
*/ */
public class TransformSplitTest extends BaseND4JTest { @DisplayName("Transform Split Test")
@Test class TransformSplitTest extends BaseND4JTest {
public void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
@Test
@DisplayName("Test Transform")
void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() { InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() {
@Override @Override
public URI apply(URI uri) throws URISyntaxException { public URI apply(URI uri) throws URISyntaxException {
return uri.normalize(); return uri.normalize();
} }
}); });
assertArrayEquals(new URI[] { new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv") }, SUT.locations());
assertArrayEquals(new URI[] {new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv")}, SUT.locations());
} }
@Test @Test
public void testSearchReplace() throws URISyntaxException { @DisplayName("Test Search Replace")
void testSearchReplace() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv")); Collection<URI> inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv"));
InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv"); InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv");
assertArrayEquals(new URI[] { new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv") }, SUT.locations());
assertArrayEquals(new URI[] {new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv")},
SUT.locations());
} }
} }

View File

@ -27,14 +27,12 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.split.partition.PartitionMetaData; import org.datavec.api.split.partition.PartitionMetaData;
import org.datavec.api.split.partition.Partitioner; import org.datavec.api.split.partition.Partitioner;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
import java.io.OutputStream; import java.io.OutputStream;
import static junit.framework.TestCase.assertTrue; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class PartitionerTests extends BaseND4JTest { public class PartitionerTests extends BaseND4JTest {
@Test @Test

View File

@ -29,12 +29,12 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestTransformProcess extends BaseND4JTest { public class TestTransformProcess extends BaseND4JTest {

View File

@ -27,13 +27,13 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.TestTransforms; import org.datavec.api.transform.transform.TestTransforms;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConditions extends BaseND4JTest { public class TestConditions extends BaseND4JTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
@ -36,8 +36,8 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestFilters extends BaseND4JTest { public class TestFilters extends BaseND4JTest {

View File

@ -26,19 +26,22 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class TestJoin extends BaseND4JTest { public class TestJoin extends BaseND4JTest {
@Test @Test
public void testJoin() { public void testJoin(@TempDir Path testDir) {
Schema firstSchema = Schema firstSchema =
new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build(); new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build();
@ -46,20 +49,20 @@ public class TestJoin extends BaseND4JTest {
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build(); Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build();
List<List<Writable>> first = new ArrayList<>(); List<List<Writable>> first = new ArrayList<>();
first.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1))); first.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1)));
first.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11))); first.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11)));
List<List<Writable>> second = new ArrayList<>(); List<List<Writable>> second = new ArrayList<>();
second.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(100))); second.add(Arrays.asList(new Text("key0"), new IntWritable(100)));
second.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(110))); second.add(Arrays.asList(new Text("key1"), new IntWritable(110)));
Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn") Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn")
.setSchemas(firstSchema, secondSchema).build(); .setSchemas(firstSchema, secondSchema).build();
List<List<Writable>> expected = new ArrayList<>(); List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1), expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1),
new IntWritable(100))); new IntWritable(100)));
expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11), expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11),
new IntWritable(110))); new IntWritable(110)));
@ -94,9 +97,9 @@ public class TestJoin extends BaseND4JTest {
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testJoinValidation() { public void testJoinValidation() {
assertThrows(IllegalArgumentException.class,() -> {
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1") Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build(); .build();
@ -104,11 +107,13 @@ public class TestJoin extends BaseND4JTest {
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist") new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
.setSchemas(firstSchema, secondSchema).build(); .setSchemas(firstSchema, secondSchema).build();
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testJoinValidation2() { public void testJoinValidation2() {
assertThrows(IllegalArgumentException.class,() -> {
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1") Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build(); .build();
@ -116,5 +121,7 @@ public class TestJoin extends BaseND4JTest {
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema) new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
.build(); .build();
});
} }
} }

View File

@ -17,32 +17,25 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.ops; package org.datavec.api.transform.ops;
import com.tngtech.archunit.core.importer.ImportOption; import com.tngtech.archunit.core.importer.ImportOption;
import com.tngtech.archunit.junit.AnalyzeClasses; import com.tngtech.archunit.junit.AnalyzeClasses;
import com.tngtech.archunit.junit.ArchTest; import com.tngtech.archunit.junit.ArchTest;
import com.tngtech.archunit.junit.ArchUnitRunner;
import com.tngtech.archunit.lang.ArchRule; import com.tngtech.archunit.lang.ArchRule;
import com.tngtech.archunit.lang.extension.ArchUnitExtension;
import com.tngtech.archunit.lang.extension.ArchUnitExtensions;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.Serializable; import java.io.Serializable;
import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes; import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(ArchUnitRunner.class) @AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = { ImportOption.DoNotIncludeTests.class })
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = {ImportOption.DoNotIncludeTests.class}) @DisplayName("Aggregable Multi Op Arch Test")
public class AggregableMultiOpArchTest extends BaseND4JTest { class AggregableMultiOpArchTest extends BaseND4JTest {
@ArchTest @ArchTest
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes() public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes().that().resideInAPackage("org.datavec.api.transform.ops").and().doNotHaveSimpleName("AggregatorImpls").and().doNotHaveSimpleName("IAggregableReduceOp").and().doNotHaveSimpleName("StringAggregatorImpls").and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1").should().implement(Serializable.class).because("All aggregate ops must be serializable.");
.that().resideInAPackage("org.datavec.api.transform.ops")
.and().doNotHaveSimpleName("AggregatorImpls")
.and().doNotHaveSimpleName("IAggregableReduceOp")
.and().doNotHaveSimpleName("StringAggregatorImpls")
.and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1")
.should().implement(Serializable.class)
.because("All aggregate ops must be serializable.");
} }

View File

@ -17,52 +17,46 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.ops; package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertTrue; @DisplayName("Aggregable Multi Op Test")
class AggregableMultiOpTest extends BaseND4JTest {
public class AggregableMultiOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@Test @Test
public void testMulti() throws Exception { @DisplayName("Test Multi")
void testMulti() throws Exception {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as)); AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
assertTrue(multi.getOperations().size() == 2); assertTrue(multi.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
multi.accept(intList.get(i)); multi.accept(intList.get(i));
} }
// mutablility // mutablility
assertTrue(as.get().toDouble() == 45D); assertTrue(as.get().toDouble() == 45D);
assertTrue(af.get().toInt() == 1); assertTrue(af.get().toInt() == 1);
List<Writable> res = multi.get(); List<Writable> res = multi.get();
assertTrue(res.get(1).toDouble() == 45D); assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1); assertTrue(res.get(0).toInt() == 1);
AggregatorImpls.AggregableFirst<Integer> rf = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> rf = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> rs = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> rs = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs)); AggregableMultiOp<Integer> reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs));
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
List<Writable> revRes = reverse.get(); List<Writable> revRes = reverse.get();
assertTrue(revRes.get(1).toDouble() == 45D); assertTrue(revRes.get(1).toDouble() == 45D);
assertTrue(revRes.get(0).toInt() == 9); assertTrue(revRes.get(0).toInt() == 9);
multi.combine(reverse); multi.combine(reverse);
List<Writable> combinedRes = multi.get(); List<Writable> combinedRes = multi.get();
assertTrue(combinedRes.get(1).toDouble() == 90D); assertTrue(combinedRes.get(1).toDouble() == 90D);

View File

@ -17,41 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.ops; package org.datavec.api.transform.ops;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import org.junit.jupiter.api.DisplayName;
import static org.junit.Assert.assertTrue;
public class AggregatorImplsTest extends BaseND4JTest { import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Aggregator Impls Test")
class AggregatorImplsTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test @Test
public void aggregableFirstTest() { @DisplayName("Aggregable First Test")
void aggregableFirstTest() {
AggregatorImpls.AggregableFirst<Integer> first = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> first = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
first.accept(intList.get(i)); first.accept(intList.get(i));
} }
assertEquals(1, first.get().toInt()); assertEquals(1, first.get().toInt());
AggregatorImpls.AggregableFirst<String> firstS = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<String> firstS = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < stringList.size(); i++) { for (int i = 0; i < stringList.size(); i++) {
firstS.accept(stringList.get(i)); firstS.accept(stringList.get(i));
} }
assertTrue(firstS.get().toString().equals("arakoa")); assertTrue(firstS.get().toString().equals("arakoa"));
AggregatorImpls.AggregableFirst<Integer> reverse = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> reverse = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -60,22 +58,19 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(1, first.get().toInt()); assertEquals(1, first.get().toInt());
} }
@Test @Test
public void aggregableLastTest() { @DisplayName("Aggregable Last Test")
void aggregableLastTest() {
AggregatorImpls.AggregableLast<Integer> last = new AggregatorImpls.AggregableLast<>(); AggregatorImpls.AggregableLast<Integer> last = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
last.accept(intList.get(i)); last.accept(intList.get(i));
} }
assertEquals(9, last.get().toInt()); assertEquals(9, last.get().toInt());
AggregatorImpls.AggregableLast<String> lastS = new AggregatorImpls.AggregableLast<>(); AggregatorImpls.AggregableLast<String> lastS = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < stringList.size(); i++) { for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i)); lastS.accept(stringList.get(i));
} }
assertTrue(lastS.get().toString().equals("acceptance")); assertTrue(lastS.get().toString().equals("acceptance"));
AggregatorImpls.AggregableLast<Integer> reverse = new AggregatorImpls.AggregableLast<>(); AggregatorImpls.AggregableLast<Integer> reverse = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -85,20 +80,18 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableCountTest() { @DisplayName("Aggregable Count Test")
void aggregableCountTest() {
AggregatorImpls.AggregableCount<Integer> cnt = new AggregatorImpls.AggregableCount<>(); AggregatorImpls.AggregableCount<Integer> cnt = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
cnt.accept(intList.get(i)); cnt.accept(intList.get(i));
} }
assertEquals(9, cnt.get().toInt()); assertEquals(9, cnt.get().toInt());
AggregatorImpls.AggregableCount<String> lastS = new AggregatorImpls.AggregableCount<>(); AggregatorImpls.AggregableCount<String> lastS = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < stringList.size(); i++) { for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i)); lastS.accept(stringList.get(i));
} }
assertEquals(4, lastS.get().toInt()); assertEquals(4, lastS.get().toInt());
AggregatorImpls.AggregableCount<Integer> reverse = new AggregatorImpls.AggregableCount<>(); AggregatorImpls.AggregableCount<Integer> reverse = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -108,14 +101,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableMaxTest() { @DisplayName("Aggregable Max Test")
void aggregableMaxTest() {
AggregatorImpls.AggregableMax<Integer> mx = new AggregatorImpls.AggregableMax<>(); AggregatorImpls.AggregableMax<Integer> mx = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i)); mx.accept(intList.get(i));
} }
assertEquals(9, mx.get().toInt()); assertEquals(9, mx.get().toInt());
AggregatorImpls.AggregableMax<Integer> reverse = new AggregatorImpls.AggregableMax<>(); AggregatorImpls.AggregableMax<Integer> reverse = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -124,16 +116,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, mx.get().toInt()); assertEquals(9, mx.get().toInt());
} }
@Test @Test
public void aggregableRangeTest() { @DisplayName("Aggregable Range Test")
void aggregableRangeTest() {
AggregatorImpls.AggregableRange<Integer> mx = new AggregatorImpls.AggregableRange<>(); AggregatorImpls.AggregableRange<Integer> mx = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i)); mx.accept(intList.get(i));
} }
assertEquals(8, mx.get().toInt()); assertEquals(8, mx.get().toInt());
AggregatorImpls.AggregableRange<Integer> reverse = new AggregatorImpls.AggregableRange<>(); AggregatorImpls.AggregableRange<Integer> reverse = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1) + 9); reverse.accept(intList.get(intList.size() - i - 1) + 9);
@ -143,14 +133,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableMinTest() { @DisplayName("Aggregable Min Test")
void aggregableMinTest() {
AggregatorImpls.AggregableMin<Integer> mn = new AggregatorImpls.AggregableMin<>(); AggregatorImpls.AggregableMin<Integer> mn = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i)); mn.accept(intList.get(i));
} }
assertEquals(1, mn.get().toInt()); assertEquals(1, mn.get().toInt());
AggregatorImpls.AggregableMin<Integer> reverse = new AggregatorImpls.AggregableMin<>(); AggregatorImpls.AggregableMin<Integer> reverse = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -160,14 +149,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableSumTest() { @DisplayName("Aggregable Sum Test")
void aggregableSumTest() {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i)); sm.accept(intList.get(i));
} }
assertEquals(45, sm.get().toInt()); assertEquals(45, sm.get().toInt());
AggregatorImpls.AggregableSum<Integer> reverse = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> reverse = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -176,17 +164,15 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(90, sm.get().toInt()); assertEquals(90, sm.get().toInt());
} }
@Test @Test
public void aggregableMeanTest() { @DisplayName("Aggregable Mean Test")
void aggregableMeanTest() {
AggregatorImpls.AggregableMean<Integer> mn = new AggregatorImpls.AggregableMean<>(); AggregatorImpls.AggregableMean<Integer> mn = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i)); mn.accept(intList.get(i));
} }
assertEquals(9l, (long) mn.getCount()); assertEquals(9l, (long) mn.getCount());
assertEquals(5D, mn.get().toDouble(), 0.001); assertEquals(5D, mn.get().toDouble(), 0.001);
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>(); AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -197,80 +183,73 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableStdDevTest() { @DisplayName("Aggregable Std Dev Test")
void aggregableStdDevTest() {
AggregatorImpls.AggregableStdDev<Integer> sd = new AggregatorImpls.AggregableStdDev<>(); AggregatorImpls.AggregableStdDev<Integer> sd = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i)); sd.accept(intList.get(i));
} }
assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001);
AggregatorImpls.AggregableStdDev<Integer> reverse = new AggregatorImpls.AggregableStdDev<>(); AggregatorImpls.AggregableStdDev<Integer> reverse = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
sd.combine(reverse); sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8787) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 1.8787) < 0.0001,"" + sd.get().toDouble());
} }
@Test @Test
public void aggregableVariance() { @DisplayName("Aggregable Variance")
void aggregableVariance() {
AggregatorImpls.AggregableVariance<Integer> sd = new AggregatorImpls.AggregableVariance<>(); AggregatorImpls.AggregableVariance<Integer> sd = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i)); sd.accept(intList.get(i));
} }
assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001);
AggregatorImpls.AggregableVariance<Integer> reverse = new AggregatorImpls.AggregableVariance<>(); AggregatorImpls.AggregableVariance<Integer> reverse = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
sd.combine(reverse); sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 3.5294) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 3.5294) < 0.0001,"" + sd.get().toDouble());
} }
@Test @Test
public void aggregableUncorrectedStdDevTest() { @DisplayName("Aggregable Uncorrected Std Dev Test")
void aggregableUncorrectedStdDevTest() {
AggregatorImpls.AggregableUncorrectedStdDev<Integer> sd = new AggregatorImpls.AggregableUncorrectedStdDev<>(); AggregatorImpls.AggregableUncorrectedStdDev<Integer> sd = new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i)); sd.accept(intList.get(i));
} }
assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001);
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse = new AggregatorImpls.AggregableUncorrectedStdDev<>();
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse =
new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
sd.combine(reverse); sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8257) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 1.8257) < 0.0001,"" + sd.get().toDouble());
} }
@Test @Test
public void aggregablePopulationVariance() { @DisplayName("Aggregable Population Variance")
void aggregablePopulationVariance() {
AggregatorImpls.AggregablePopulationVariance<Integer> sd = new AggregatorImpls.AggregablePopulationVariance<>(); AggregatorImpls.AggregablePopulationVariance<Integer> sd = new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i)); sd.accept(intList.get(i));
} }
assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001);
AggregatorImpls.AggregablePopulationVariance<Integer> reverse = new AggregatorImpls.AggregablePopulationVariance<>();
AggregatorImpls.AggregablePopulationVariance<Integer> reverse =
new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
sd.combine(reverse); sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001,"" + sd.get().toDouble());
} }
@Test @Test
public void aggregableCountUniqueTest() { @DisplayName("Aggregable Count Unique Test")
void aggregableCountUniqueTest() {
// at this low range, it's linear counting // at this low range, it's linear counting
AggregatorImpls.AggregableCountUnique<Integer> cu = new AggregatorImpls.AggregableCountUnique<>(); AggregatorImpls.AggregableCountUnique<Integer> cu = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
cu.accept(intList.get(i)); cu.accept(intList.get(i));
@ -278,7 +257,6 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, cu.get().toInt()); assertEquals(9, cu.get().toInt());
cu.accept(1); cu.accept(1);
assertEquals(9, cu.get().toInt()); assertEquals(9, cu.get().toInt());
AggregatorImpls.AggregableCountUnique<Integer> reverse = new AggregatorImpls.AggregableCountUnique<>(); AggregatorImpls.AggregableCountUnique<Integer> reverse = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -287,26 +265,25 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, cu.get().toInt()); assertEquals(9, cu.get().toInt());
} }
@Rule
public final ExpectedException exception = ExpectedException.none();
@Test @Test
public void incompatibleAggregatorTest() { @DisplayName("Incompatible Aggregator Test")
void incompatibleAggregatorTest() {
assertThrows(UnsupportedOperationException.class,() -> {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i)); sm.accept(intList.get(i));
} }
assertEquals(45, sm.get().toInt()); assertEquals(45, sm.get().toInt());
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>(); AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
exception.expect(UnsupportedOperationException.class);
sm.combine(reverse); sm.combine(reverse);
assertEquals(45, sm.get().toInt()); assertEquals(45, sm.get().toInt());
} });
}
} }

View File

@ -17,77 +17,65 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.ops; package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertTrue; @DisplayName("Dispatch Op Test")
class DispatchOpTest extends BaseND4JTest {
public class DispatchOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test @Test
public void testDispatchSimple() { @DisplayName("Test Dispatch Simple")
void testDispatchSimple() {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multiaf = AggregableMultiOp<Integer> multiaf = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af)); AggregableMultiOp<Integer> multias = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
AggregableMultiOp<Integer> multias = DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
DispatchOp<Integer, Writable> parallel =
new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
assertTrue(multiaf.getOperations().size() == 1); assertTrue(multiaf.getOperations().size() == 1);
assertTrue(multias.getOperations().size() == 1); assertTrue(multias.getOperations().size() == 1);
assertTrue(parallel.getOperations().size() == 2); assertTrue(parallel.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
parallel.accept(Arrays.asList(intList.get(i), intList.get(i))); parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
} }
List<Writable> res = parallel.get(); List<Writable> res = parallel.get();
assertTrue(res.get(1).toDouble() == 45D); assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1); assertTrue(res.get(0).toInt() == 1);
} }
@Test @Test
public void testDispatchFlatMap() { @DisplayName("Test Dispatch Flat Map")
void testDispatchFlatMap() {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as)); AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
AggregatorImpls.AggregableLast<Integer> al = new AggregatorImpls.AggregableLast<>(); AggregatorImpls.AggregableLast<Integer> al = new AggregatorImpls.AggregableLast<>();
AggregatorImpls.AggregableMax<Integer> amax = new AggregatorImpls.AggregableMax<>(); AggregatorImpls.AggregableMax<Integer> amax = new AggregatorImpls.AggregableMax<>();
AggregableMultiOp<Integer> otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax)); AggregableMultiOp<Integer> otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(
Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
assertTrue(multi.getOperations().size() == 2); assertTrue(multi.getOperations().size() == 2);
assertTrue(otherMulti.getOperations().size() == 2); assertTrue(otherMulti.getOperations().size() == 2);
assertTrue(parallel.getOperations().size() == 2); assertTrue(parallel.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
parallel.accept(Arrays.asList(intList.get(i), intList.get(i))); parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
} }
List<Writable> res = parallel.get(); List<Writable> res = parallel.get();
assertTrue(res.get(1).toDouble() == 45D); assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1); assertTrue(res.get(0).toInt() == 1);
assertTrue(res.get(3).toDouble() == 9); assertTrue(res.get(3).toDouble() == 9);
assertTrue(res.get(2).toInt() == 9); assertTrue(res.get(2).toInt() == 9);
} }
} }

View File

@ -32,13 +32,14 @@ import org.datavec.api.transform.ops.AggregableMultiOp;
import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.Test; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
public class TestMultiOpReduce extends BaseND4JTest { public class TestMultiOpReduce extends BaseND4JTest {
@ -46,10 +47,10 @@ public class TestMultiOpReduce extends BaseND4JTest {
public void testMultiOpReducerDouble() { public void testMultiOpReducerDouble() {
List<List<Writable>> inputs = new ArrayList<>(); List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(0))); inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(0)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(1))); inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(1)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2))); inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2))); inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
Map<ReduceOp, Double> exp = new LinkedHashMap<>(); Map<ReduceOp, Double> exp = new LinkedHashMap<>();
exp.put(ReduceOp.Min, 0.0); exp.put(ReduceOp.Min, 0.0);
@ -82,7 +83,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey")); assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString(); String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5); assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
} }
} }
@ -126,19 +127,20 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey")); assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString(); String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5); assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
} }
} }
@Test @Test
@Disabled
public void testReduceString() { public void testReduceString() {
List<List<Writable>> inputs = new ArrayList<>(); List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1"))); inputs.add(Arrays.asList(new Text("someKey"), new Text("1")));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2"))); inputs.add(Arrays.asList(new Text("someKey"), new Text("2")));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3"))); inputs.add(Arrays.asList(new Text("someKey"), new Text("3")));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4"))); inputs.add(Arrays.asList(new Text("someKey"), new Text("4")));
Map<ReduceOp, String> exp = new LinkedHashMap<>(); Map<ReduceOp, String> exp = new LinkedHashMap<>();
exp.put(ReduceOp.Append, "1234"); exp.put(ReduceOp.Append, "1234");
@ -210,7 +212,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey")); assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString(); String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5); assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
} }
for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean, for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean,

View File

@ -24,13 +24,13 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction; import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReductions extends BaseND4JTest { public class TestReductions extends BaseND4JTest {

View File

@ -22,10 +22,10 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJsonYaml extends BaseND4JTest { public class TestJsonYaml extends BaseND4JTest {

View File

@ -21,10 +21,10 @@
package org.datavec.api.transform.schema; package org.datavec.api.transform.schema;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSchemaMethods extends BaseND4JTest { public class TestSchemaMethods extends BaseND4JTest {

View File

@ -33,7 +33,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
@ -41,7 +41,7 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReduceSequenceByWindowFunction extends BaseND4JTest { public class TestReduceSequenceByWindowFunction extends BaseND4JTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
@ -35,7 +35,7 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSequenceSplit extends BaseND4JTest { public class TestSequenceSplit extends BaseND4JTest {

View File

@ -29,7 +29,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
@ -37,7 +37,7 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWindowFunctions extends BaseND4JTest { public class TestWindowFunctions extends BaseND4JTest {

View File

@ -26,10 +26,10 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.testClasses.CustomCondition; import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter; import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform; import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCustomTransformJsonYaml extends BaseND4JTest { public class TestCustomTransformJsonYaml extends BaseND4JTest {

View File

@ -64,13 +64,13 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestYamlJsonSerde extends BaseND4JTest { public class TestYamlJsonSerde extends BaseND4JTest {

View File

@ -24,12 +24,12 @@ import org.datavec.api.transform.StringReduceOp;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReduce extends BaseND4JTest { public class TestReduce extends BaseND4JTest {

View File

@ -50,7 +50,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.comparator.LongWritableComparator; import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -61,7 +61,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class RegressionTestJson extends BaseND4JTest { public class RegressionTestJson extends BaseND4JTest {

View File

@ -50,13 +50,13 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.comparator.LongWritableComparator; import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJsonYaml extends BaseND4JTest { public class TestJsonYaml extends BaseND4JTest {

View File

@ -58,8 +58,8 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Assert;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -71,8 +71,8 @@ import java.io.ObjectOutputStream;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertEquals;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestTransforms extends BaseND4JTest { public class TestTransforms extends BaseND4JTest {
@ -277,22 +277,22 @@ public class TestTransforms extends BaseND4JTest {
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS); List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
outputColumns.add(NEW_COLUMN); outputColumns.add(NEW_COLUMN);
Schema newSchema = transform.transform(schema); Schema newSchema = transform.transform(schema);
Assert.assertEquals(outputColumns, newSchema.getColumnNames()); assertEquals(outputColumns, newSchema.getColumnNames());
List<Writable> input = new ArrayList<>(); List<Writable> input = new ArrayList<>();
input.addAll(COLUMN_VALUES); input.addAll(COLUMN_VALUES);
transform.setInputSchema(schema); transform.setInputSchema(schema);
List<Writable> transformed = transform.map(input); List<Writable> transformed = transform.map(input);
Assert.assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString()); assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES); List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES);
outputColumnValues.add(new Text(NEW_COLUMN_VALUE)); outputColumnValues.add(new Text(NEW_COLUMN_VALUE));
Assert.assertEquals(outputColumnValues, transformed); assertEquals(outputColumnValues, transformed);
String s = JsonMappers.getMapper().writeValueAsString(transform); String s = JsonMappers.getMapper().writeValueAsString(transform);
Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class); Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class);
Assert.assertEquals(transform, transform2); assertEquals(transform, transform2);
} }
@Test @Test
@ -309,7 +309,7 @@ public class TestTransforms extends BaseND4JTest {
transform.setInputSchema(schema); transform.setInputSchema(schema);
Schema newSchema = transform.transform(schema); Schema newSchema = transform.transform(schema);
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS); List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
Assert.assertEquals(outputColumns, newSchema.getColumnNames()); assertEquals(outputColumns, newSchema.getColumnNames());
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER); transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER);
transform.setInputSchema(schema); transform.setInputSchema(schema);
@ -320,8 +320,8 @@ public class TestTransforms extends BaseND4JTest {
output.add(new Text(TEXT_LOWER_CASE)); output.add(new Text(TEXT_LOWER_CASE));
output.add(new Text(TEXT_MIXED_CASE)); output.add(new Text(TEXT_MIXED_CASE));
List<Writable> transformed = transform.map(input); List<Writable> transformed = transform.map(input);
Assert.assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE); assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
Assert.assertEquals(transformed, output); assertEquals(transformed, output);
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER); transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER);
transform.setInputSchema(schema); transform.setInputSchema(schema);
@ -329,12 +329,12 @@ public class TestTransforms extends BaseND4JTest {
output.add(new Text(TEXT_UPPER_CASE)); output.add(new Text(TEXT_UPPER_CASE));
output.add(new Text(TEXT_MIXED_CASE)); output.add(new Text(TEXT_MIXED_CASE));
transformed = transform.map(input); transformed = transform.map(input);
Assert.assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE); assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
Assert.assertEquals(transformed, output); assertEquals(transformed, output);
String s = JsonMappers.getMapper().writeValueAsString(transform); String s = JsonMappers.getMapper().writeValueAsString(transform);
Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class); Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class);
Assert.assertEquals(transform, transform2); assertEquals(transform, transform2);
} }
@Test @Test
@ -1530,7 +1530,7 @@ public class TestTransforms extends BaseND4JTest {
String json = JsonMappers.getMapper().writeValueAsString(t); String json = JsonMappers.getMapper().writeValueAsString(t);
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class); Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);
Assert.assertEquals(t, transform2); assertEquals(t, transform2);
} }
@ -1551,7 +1551,7 @@ public class TestTransforms extends BaseND4JTest {
String json = JsonMappers.getMapper().writeValueAsString(t); String json = JsonMappers.getMapper().writeValueAsString(t);
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class); Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class);
Assert.assertEquals(t, transform2); assertEquals(t, transform2);
} }

View File

@ -29,7 +29,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,7 +39,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayWritableTransforms extends BaseND4JTest { public class TestNDArrayWritableTransforms extends BaseND4JTest {

View File

@ -30,13 +30,13 @@ import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer; import org.datavec.api.transform.serde.YamlSerializer;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestYamlJsonSerde extends BaseND4JTest { public class TestYamlJsonSerde extends BaseND4JTest {

View File

@ -17,29 +17,29 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.transform.parse; package org.datavec.api.transform.transform.parse;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Parse Double Transform Test")
class ParseDoubleTransformTest extends BaseND4JTest {
public class ParseDoubleTransformTest extends BaseND4JTest {
@Test @Test
public void testDoubleTransform() { @DisplayName("Test Double Transform")
void testDoubleTransform() {
List<Writable> record = new ArrayList<>(); List<Writable> record = new ArrayList<>();
record.add(new Text("0.0")); record.add(new Text("0.0"));
List<Writable> transformed = Arrays.<Writable>asList(new DoubleWritable(0.0)); List<Writable> transformed = Arrays.<Writable>asList(new DoubleWritable(0.0));
assertEquals(transformed, new ParseDoubleTransform().map(record)); assertEquals(transformed, new ParseDoubleTransform().map(record));
} }
} }

View File

@ -35,26 +35,26 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.File; import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestUI extends BaseND4JTest { public class TestUI extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testUI() throws Exception { public void testUI(@TempDir Path testDir) throws Exception {
Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn") Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn")
.addColumnInteger("IntColumn2").addColumnInteger("IntColumn3") .addColumnInteger("IntColumn2").addColumnInteger("IntColumn3")
.addColumnTime("TimeColumn", DateTimeZone.UTC).build(); .addColumnTime("TimeColumn", DateTimeZone.UTC).build();
@ -92,7 +92,7 @@ public class TestUI extends BaseND4JTest {
DataAnalysis da = new DataAnalysis(schema, list); DataAnalysis da = new DataAnalysis(schema, list);
File fDir = testDir.newFolder(); File fDir = testDir.toFile();
String tempDir = fDir.getAbsolutePath(); String tempDir = fDir.getAbsolutePath();
String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html"); String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html");
System.out.println(outPath); System.out.println(outPath);
@ -143,7 +143,7 @@ public class TestUI extends BaseND4JTest {
@Test @Test
@Ignore @Disabled
public void testSequencePlot() throws Exception { public void testSequencePlot() throws Exception {
Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx") Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx")

View File

@ -17,30 +17,31 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.util; package org.datavec.api.util;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.File; import java.io.File;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.Assert.assertTrue;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.AnyOf.anyOf; import static org.hamcrest.core.AnyOf.anyOf;
import static org.hamcrest.core.IsEqual.equalTo; import static org.hamcrest.core.IsEqual.equalTo;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class ClassPathResourceTest extends BaseND4JTest { @DisplayName("Class Path Resource Test")
class ClassPathResourceTest extends BaseND4JTest {
private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows // File sizes are reported slightly different on Linux vs. Windows
private boolean isWindows = false;
@Before @BeforeEach
public void setUp() throws Exception { void setUp() throws Exception {
String osname = System.getProperty("os.name"); String osname = System.getProperty("os.name");
if (osname != null && osname.toLowerCase().contains("win")) { if (osname != null && osname.toLowerCase().contains("win")) {
isWindows = true; isWindows = true;
@ -48,9 +49,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
} }
@Test @Test
public void testGetFile1() throws Exception { @DisplayName("Test Get File 1")
void testGetFile1() throws Exception {
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile(); File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
assertTrue(intFile.exists()); assertTrue(intFile.exists());
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L))); assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
@ -60,9 +61,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
} }
@Test @Test
public void testGetFileSlash1() throws Exception { @DisplayName("Test Get File Slash 1")
void testGetFileSlash1() throws Exception {
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile(); File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
assertTrue(intFile.exists()); assertTrue(intFile.exists());
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L))); assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
@ -72,11 +73,10 @@ public class ClassPathResourceTest extends BaseND4JTest {
} }
@Test @Test
public void testGetFileWithSpace1() throws Exception { @DisplayName("Test Get File With Space 1")
void testGetFileWithSpace1() throws Exception {
File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile(); File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile();
assertTrue(intFile.exists()); assertTrue(intFile.exists());
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L))); assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else { } else {
@ -85,16 +85,15 @@ public class ClassPathResourceTest extends BaseND4JTest {
} }
@Test @Test
public void testInputStream() throws Exception { @DisplayName("Test Input Stream")
void testInputStream() throws Exception {
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt"); ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
File intFile = resource.getFile(); File intFile = resource.getFile();
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L))); assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else { } else {
assertEquals(60, intFile.length()); assertEquals(60, intFile.length());
} }
InputStream stream = resource.getInputStream(); InputStream stream = resource.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = ""; String line = "";
@ -102,21 +101,19 @@ public class ClassPathResourceTest extends BaseND4JTest {
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
cnt++; cnt++;
} }
assertEquals(5, cnt); assertEquals(5, cnt);
} }
@Test @Test
public void testInputStreamSlash() throws Exception { @DisplayName("Test Input Stream Slash")
void testInputStreamSlash() throws Exception {
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt"); ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
File intFile = resource.getFile(); File intFile = resource.getFile();
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L))); assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else { } else {
assertEquals(60, intFile.length()); assertEquals(60, intFile.length());
} }
InputStream stream = resource.getInputStream(); InputStream stream = resource.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = ""; String line = "";
@ -124,7 +121,6 @@ public class ClassPathResourceTest extends BaseND4JTest {
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
cnt++; cnt++;
} }
assertEquals(5, cnt); assertEquals(5, cnt);
} }
} }

View File

@ -17,44 +17,41 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.util; package org.datavec.api.util;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertArrayEquals; @DisplayName("Time Series Utils Test")
class TimeSeriesUtilsTest extends BaseND4JTest {
public class TimeSeriesUtilsTest extends BaseND4JTest {
@Test @Test
public void testTimeSeriesCreation() { @DisplayName("Test Time Series Creation")
void testTimeSeriesCreation() {
List<List<List<Writable>>> test = new ArrayList<>(); List<List<List<Writable>>> test = new ArrayList<>();
List<List<Writable>> timeStep = new ArrayList<>(); List<List<Writable>> timeStep = new ArrayList<>();
for(int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
timeStep.add(getRecord(5)); timeStep.add(getRecord(5));
} }
test.add(timeStep); test.add(timeStep);
INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst(); INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst();
assertArrayEquals(new long[]{1,5,5},arr.shape()); assertArrayEquals(new long[] { 1, 5, 5 }, arr.shape());
} }
private List<Writable> getRecord(int length) { private List<Writable> getRecord(int length) {
List<Writable> ret = new ArrayList<>(); List<Writable> ret = new ArrayList<>();
for(int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
ret.add(new DoubleWritable(1.0)); ret.add(new DoubleWritable(1.0));
} }
return ret; return ret;
} }
} }

View File

@ -17,52 +17,50 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.writable; package org.datavec.api.writable;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.shade.guava.collect.Lists; import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.util.ndarray.RecordConverter;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.TimeZone; import java.util.TimeZone;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Record Converter Test")
class RecordConverterTest extends BaseND4JTest {
public class RecordConverterTest extends BaseND4JTest {
@Test @Test
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() { @DisplayName("To Records _ Pass In Classification Data Set _ Expect ND Array And Int Writables")
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT); INDArray feature1 = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT); INDArray feature2 = Nd4j.create(new double[] { 11, .7, -1.3, 4 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT); INDArray label1 = Nd4j.create(new double[] { 0, 0, 1, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), INDArray label2 = Nd4j.create(new double[] { 0, 1, 0, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
Nd4j.vstack(Lists.newArrayList(label1, label2))); DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), Nd4j.vstack(Lists.newArrayList(label1, label2)));
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet); List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
assertEquals(2, writableList.size()); assertEquals(2, writableList.size());
testClassificationWritables(feature1, 2, writableList.get(0)); testClassificationWritables(feature1, 2, writableList.get(0));
testClassificationWritables(feature2, 1, writableList.get(1)); testClassificationWritables(feature2, 1, writableList.get(1));
} }
@Test @Test
public void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() { @DisplayName("To Records _ Pass In Regression Data Set _ Expect ND Array And Double Writables")
INDArray feature = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
INDArray label = Nd4j.create(new double[]{.5, 2, 3, .5}, new long[]{1, 4}, DataType.FLOAT); INDArray feature = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label = Nd4j.create(new double[] { .5, 2, 3, .5 }, new long[] { 1, 4 }, DataType.FLOAT);
DataSet dataSet = new DataSet(feature, label); DataSet dataSet = new DataSet(feature, label);
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet); List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
List<Writable> results = writableList.get(0); List<Writable> results = writableList.get(0);
NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0); NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0);
assertEquals(1, writableList.size()); assertEquals(1, writableList.size());
assertEquals(5, results.size()); assertEquals(5, results.size());
assertEquals(feature, ndArrayWritable.get()); assertEquals(feature, ndArrayWritable.get());
@ -72,62 +70,39 @@ public class RecordConverterTest extends BaseND4JTest {
} }
} }
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, List<Writable> writables) {
List<Writable> writables) {
NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0); NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0);
IntWritable intWritable = (IntWritable) writables.get(1); IntWritable intWritable = (IntWritable) writables.get(1);
assertEquals(2, writables.size()); assertEquals(2, writables.size());
assertEquals(expectedFeatureVector, ndArrayWritable.get()); assertEquals(expectedFeatureVector, ndArrayWritable.get());
assertEquals(expectLabelIndex, intWritable.get()); assertEquals(expectLabelIndex, intWritable.get());
} }
@Test @Test
public void testNDArrayWritableConcat() { @DisplayName("Test ND Array Writable Concat")
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), void testNDArrayWritableConcat() {
new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1, 3}, DataType.FLOAT)), new DoubleWritable(5), List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 }, new long[] { 1, 3 }, DataType.FLOAT)), new IntWritable(9), new IntWritable(1));
new NDArrayWritable(Nd4j.create(new double[]{6, 7, 8}, new long[]{1, 3}, DataType.FLOAT)), new IntWritable(9), INDArray exp = Nd4j.create(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1 }, new long[] { 1, 10 }, DataType.FLOAT);
new IntWritable(1));
INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT);
INDArray act = RecordConverter.toArray(DataType.FLOAT, l); INDArray act = RecordConverter.toArray(DataType.FLOAT, l);
assertEquals(exp, act); assertEquals(exp, act);
} }
@Test @Test
public void testNDArrayWritableConcatToMatrix(){ @DisplayName("Test ND Array Writable Concat To Matrix")
void testNDArrayWritableConcatToMatrix() {
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5)); List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5));
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[]{7, 8, 9}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(10)); List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(10));
INDArray exp = Nd4j.create(new double[][] { { 1, 2, 3, 4, 5 }, { 6, 7, 8, 9, 10 } }).castTo(DataType.FLOAT);
INDArray exp = Nd4j.create(new double[][]{ INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1, l2));
{1,2,3,4,5},
{6,7,8,9,10}}).castTo(DataType.FLOAT);
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2));
assertEquals(exp, act); assertEquals(exp, act);
} }
@Test @Test
public void testToRecordWithListOfObject(){ @DisplayName("Test To Record With List Of Object")
final List<Object> list = Arrays.asList((Object)3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L); void testToRecordWithListOfObject() {
final Schema schema = new Schema.Builder() final List<Object> list = Arrays.asList((Object) 3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
.addColumnInteger("a") final Schema schema = new Schema.Builder().addColumnInteger("a").addColumnFloat("b").addColumnString("c").addColumnCategorical("d", "Bar", "Baz").addColumnDouble("e").addColumnFloat("f").addColumnLong("g").addColumnInteger("h").addColumnTime("i", TimeZone.getDefault()).build();
.addColumnFloat("b")
.addColumnString("c")
.addColumnCategorical("d", "Bar", "Baz")
.addColumnDouble("e")
.addColumnFloat("f")
.addColumnLong("g")
.addColumnInteger("h")
.addColumnTime("i", TimeZone.getDefault())
.build();
final List<Writable> record = RecordConverter.toRecord(schema, list); final List<Writable> record = RecordConverter.toRecord(schema, list);
assertEquals(record.get(0).toInt(), 3); assertEquals(record.get(0).toInt(), 3);
assertEquals(record.get(1).toFloat(), 7f, 1e-6); assertEquals(record.get(1).toFloat(), 7f, 1e-6);
assertEquals(record.get(2).toString(), "Foo"); assertEquals(record.get(2).toString(), "Foo");
@ -137,7 +112,5 @@ public class RecordConverterTest extends BaseND4JTest {
assertEquals(record.get(6).toLong(), 3L); assertEquals(record.get(6).toLong(), 3L);
assertEquals(record.get(7).toInt(), 7); assertEquals(record.get(7).toInt(), 7);
assertEquals(record.get(8).toLong(), 0); assertEquals(record.get(8).toLong(), 0);
} }
} }

View File

@ -21,14 +21,14 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.io.*; import java.io.*;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestNDArrayWritableAndSerialization extends BaseND4JTest { public class TestNDArrayWritableAndSerialization extends BaseND4JTest {

View File

@ -17,38 +17,38 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.writable; package org.datavec.api.writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.nio.Buffer; import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import org.junit.jupiter.api.DisplayName;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class WritableTest extends BaseND4JTest { @DisplayName("Writable Test")
class WritableTest extends BaseND4JTest {
@Test @Test
public void testWritableEqualityReflexive() { @DisplayName("Test Writable Equality Reflexive")
void testWritableEqualityReflexive() {
assertEquals(new IntWritable(1), new IntWritable(1)); assertEquals(new IntWritable(1), new IntWritable(1));
assertEquals(new LongWritable(1), new LongWritable(1)); assertEquals(new LongWritable(1), new LongWritable(1));
assertEquals(new DoubleWritable(1), new DoubleWritable(1)); assertEquals(new DoubleWritable(1), new DoubleWritable(1));
assertEquals(new FloatWritable(1), new FloatWritable(1)); assertEquals(new FloatWritable(1), new FloatWritable(1));
assertEquals(new Text("Hello"), new Text("Hello")); assertEquals(new Text("Hello"), new Text("Hello"));
assertEquals(new BytesWritable("Hello".getBytes()),new BytesWritable("Hello".getBytes())); assertEquals(new BytesWritable("Hello".getBytes()), new BytesWritable("Hello".getBytes()));
INDArray ndArray = Nd4j.rand(new int[]{1, 100}); INDArray ndArray = Nd4j.rand(new int[] { 1, 100 });
assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray)); assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray));
assertEquals(new NullWritable(), new NullWritable()); assertEquals(new NullWritable(), new NullWritable());
assertEquals(new BooleanWritable(true), new BooleanWritable(true)); assertEquals(new BooleanWritable(true), new BooleanWritable(true));
@ -56,9 +56,9 @@ public class WritableTest extends BaseND4JTest {
assertEquals(new ByteWritable(b), new ByteWritable(b)); assertEquals(new ByteWritable(b), new ByteWritable(b));
} }
@Test @Test
public void testBytesWritableIndexing() { @DisplayName("Test Bytes Writable Indexing")
void testBytesWritableIndexing() {
byte[] doubleWrite = new byte[16]; byte[] doubleWrite = new byte[16];
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite); ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
Buffer buffer = (Buffer) wrapped; Buffer buffer = (Buffer) wrapped;
@ -66,53 +66,51 @@ public class WritableTest extends BaseND4JTest {
wrapped.putDouble(2.0); wrapped.putDouble(2.0);
buffer.rewind(); buffer.rewind();
BytesWritable byteWritable = new BytesWritable(doubleWrite); BytesWritable byteWritable = new BytesWritable(doubleWrite);
assertEquals(2,byteWritable.getDouble(1),1e-1); assertEquals(2, byteWritable.getDouble(1), 1e-1);
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2}); DataBuffer dataBuffer = Nd4j.createBuffer(new double[] { 1, 2 });
double[] d1 = dataBuffer.asDouble(); double[] d1 = dataBuffer.asDouble();
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE,8).asDouble(); double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE, 8).asDouble();
assertArrayEquals(d1, d2, 0.0); assertArrayEquals(d1, d2, 0.0);
} }
@Test @Test
public void testByteWritable() { @DisplayName("Test Byte Writable")
void testByteWritable() {
byte b = 0xfffffffe; byte b = 0xfffffffe;
assertEquals(new IntWritable(-2), new ByteWritable(b)); assertEquals(new IntWritable(-2), new ByteWritable(b));
assertEquals(new LongWritable(-2), new ByteWritable(b)); assertEquals(new LongWritable(-2), new ByteWritable(b));
assertEquals(new ByteWritable(b), new IntWritable(-2)); assertEquals(new ByteWritable(b), new IntWritable(-2));
assertEquals(new ByteWritable(b), new LongWritable(-2)); assertEquals(new ByteWritable(b), new LongWritable(-2));
// those would cast to the same Int // those would cast to the same Int
byte minus126 = 0xffffff82; byte minus126 = 0xffffff82;
assertNotEquals(new ByteWritable(minus126), new IntWritable(130)); assertNotEquals(new ByteWritable(minus126), new IntWritable(130));
} }
@Test @Test
public void testIntLongWritable() { @DisplayName("Test Int Long Writable")
void testIntLongWritable() {
assertEquals(new IntWritable(1), new LongWritable(1l)); assertEquals(new IntWritable(1), new LongWritable(1l));
assertEquals(new LongWritable(2l), new IntWritable(2)); assertEquals(new LongWritable(2l), new IntWritable(2));
long l = 1L << 34; long l = 1L << 34;
// those would cast to the same Int // those would cast to the same Int
assertNotEquals(new LongWritable(l), new IntWritable(4)); assertNotEquals(new LongWritable(l), new IntWritable(4));
} }
@Test @Test
public void testDoubleFloatWritable() { @DisplayName("Test Double Float Writable")
void testDoubleFloatWritable() {
assertEquals(new DoubleWritable(1d), new FloatWritable(1f)); assertEquals(new DoubleWritable(1d), new FloatWritable(1f));
assertEquals(new FloatWritable(2f), new DoubleWritable(2d)); assertEquals(new FloatWritable(2f), new DoubleWritable(2d));
// we defer to Java equality for Floats // we defer to Java equality for Floats
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f)); assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f));
// same idea as above // same idea as above
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float)1.1d)); assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float) 1.1d));
assertNotEquals(new DoubleWritable((double) Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
assertNotEquals(new DoubleWritable((double)Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
} }
@Test @Test
public void testFuzzies() { @DisplayName("Test Fuzzies")
void testFuzzies() {
assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d)); assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d));
assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d)); assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d));
byte b = 0xfffffffe; byte b = 0xfffffffe;
@ -122,62 +120,57 @@ public class WritableTest extends BaseND4JTest {
assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d)); assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d));
} }
@Test @Test
public void testNDArrayRecordBatch(){ @DisplayName("Test ND Array Record Batch")
void testNDArrayRecordBatch() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
// Outer list over writables/columns, inner list over examples
List<List<INDArray>> orig = new ArrayList<>(); //Outer list over writables/columns, inner list over examples List<List<INDArray>> orig = new ArrayList<>();
for( int i=0; i<3; i++ ){ for (int i = 0; i < 3; i++) {
orig.add(new ArrayList<INDArray>()); orig.add(new ArrayList<INDArray>());
} }
for (int i = 0; i < 5; i++) {
for( int i=0; i<5; i++ ){ orig.get(0).add(Nd4j.rand(1, 10));
orig.get(0).add(Nd4j.rand(1,10)); orig.get(1).add(Nd4j.rand(new int[] { 1, 5, 6 }));
orig.get(1).add(Nd4j.rand(new int[]{1,5,6})); orig.get(2).add(Nd4j.rand(new int[] { 1, 3, 4, 5 }));
orig.get(2).add(Nd4j.rand(new int[]{1,3,4,5}));
} }
// Outer list over examples, inner list over writables
List<List<INDArray>> origByExample = new ArrayList<>(); //Outer list over examples, inner list over writables List<List<INDArray>> origByExample = new ArrayList<>();
for( int i=0; i<5; i++ ){ for (int i = 0; i < 5; i++) {
origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i))); origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i)));
} }
List<INDArray> batched = new ArrayList<>(); List<INDArray> batched = new ArrayList<>();
for(List<INDArray> l : orig){ for (List<INDArray> l : orig) {
batched.add(Nd4j.concat(0, l.toArray(new INDArray[5]))); batched.add(Nd4j.concat(0, l.toArray(new INDArray[5])));
} }
NDArrayRecordBatch batch = new NDArrayRecordBatch(batched); NDArrayRecordBatch batch = new NDArrayRecordBatch(batched);
assertEquals(5, batch.size()); assertEquals(5, batch.size());
for( int i=0; i<5; i++ ){ for (int i = 0; i < 5; i++) {
List<Writable> act = batch.get(i); List<Writable> act = batch.get(i);
List<INDArray> unboxed = new ArrayList<>(); List<INDArray> unboxed = new ArrayList<>();
for(Writable w : act){ for (Writable w : act) {
unboxed.add(((NDArrayWritable)w).get()); unboxed.add(((NDArrayWritable) w).get());
} }
List<INDArray> exp = origByExample.get(i); List<INDArray> exp = origByExample.get(i);
assertEquals(exp.size(), unboxed.size()); assertEquals(exp.size(), unboxed.size());
for( int j=0; j<exp.size(); j++ ){ for (int j = 0; j < exp.size(); j++) {
assertEquals(exp.get(j), unboxed.get(j)); assertEquals(exp.get(j), unboxed.get(j));
} }
} }
Iterator<List<Writable>> iter = batch.iterator(); Iterator<List<Writable>> iter = batch.iterator();
int count = 0; int count = 0;
while(iter.hasNext()){ while (iter.hasNext()) {
List<Writable> next = iter.next(); List<Writable> next = iter.next();
List<INDArray> unboxed = new ArrayList<>(); List<INDArray> unboxed = new ArrayList<>();
for(Writable w : next){ for (Writable w : next) {
unboxed.add(((NDArrayWritable)w).get()); unboxed.add(((NDArrayWritable) w).get());
} }
List<INDArray> exp = origByExample.get(count++); List<INDArray> exp = origByExample.get(count++);
assertEquals(exp.size(), unboxed.size()); assertEquals(exp.size(), unboxed.size());
for( int j=0; j<exp.size(); j++ ){ for (int j = 0; j < exp.size(); j++) {
assertEquals(exp.get(j), unboxed.get(j)); assertEquals(exp.get(j), unboxed.get(j));
} }
} }
assertEquals(5, count); assertEquals(5, count);
} }
} }

View File

@ -60,10 +60,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.arrow; package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -42,461 +41,397 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File; import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
import static java.nio.channels.Channels.newChannel; import static java.nio.channels.Channels.newChannel;
import static junit.framework.TestCase.assertTrue; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import org.junit.jupiter.api.DisplayName;
import static org.junit.Assert.assertFalse; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
public class ArrowConverterTest extends BaseND4JTest { @DisplayName("Arrow Converter Test")
class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@Rule @TempDir
public TemporaryFolder testDir = new TemporaryFolder(); public Path testDir;
@Test @Test
public void testToArrayFromINDArray() { @DisplayName("Test To Array From IND Array")
void testToArrayFromINDArray() {
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
schemaBuilder.addColumnNDArray("outputArray",new long[]{1,4}); schemaBuilder.addColumnNDArray("outputArray", new long[] { 1, 4 });
Schema schema = schemaBuilder.build(); Schema schema = schemaBuilder.build();
int numRows = 4; int numRows = 4;
List<List<Writable>> ret = new ArrayList<>(numRows); List<List<Writable>> ret = new ArrayList<>(numRows);
for(int i = 0; i < numRows; i++) { for (int i = 0; i < numRows; i++) {
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1,4,4).reshape(1, 4)))); ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4))));
} }
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret);
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors,schema); ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors, schema);
INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch); INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch);
assertArrayEquals(new long[]{4,4},array.shape()); assertArrayEquals(new long[] { 4, 4 }, array.shape());
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1, 4, 4), 4).reshape(4, 4);
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1,4,4),4).reshape(4,4); assertEquals(assertion, array);
assertEquals(assertion,array);
} }
@Test @Test
public void testArrowColumnINDArray() { @DisplayName("Test Arrow Column IND Array")
void testArrowColumnINDArray() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>(); List<String> single = new ArrayList<>();
int numCols = 2; int numCols = 2;
INDArray arr = Nd4j.linspace(1,4,4); INDArray arr = Nd4j.linspace(1, 4, 4);
for(int i = 0; i < numCols; i++) { for (int i = 0; i < numCols; i++) {
schema.addColumnNDArray(String.valueOf(i),new long[]{1,4}); schema.addColumnNDArray(String.valueOf(i), new long[] { 1, 4 });
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
Schema buildSchema = schema.build(); Schema buildSchema = schema.build();
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>(); List<Writable> firstRow = new ArrayList<>();
for(int i = 0 ; i < numCols; i++) { for (int i = 0; i < numCols; i++) {
firstRow.add(new NDArrayWritable(arr)); firstRow.add(new NDArrayWritable(arr));
} }
list.add(firstRow); list.add(firstRow);
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list);
assertEquals(numCols,fieldVectors.size()); assertEquals(numCols, fieldVectors.size());
assertEquals(1,fieldVectors.get(0).getValueCount()); assertEquals(1, fieldVectors.get(0).getValueCount());
assertFalse(fieldVectors.get(0).isNull(0)); assertFalse(fieldVectors.get(0).isNull(0));
ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema); ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema);
assertEquals(1,arrowWritableRecordBatch.size()); assertEquals(1, arrowWritableRecordBatch.size());
Writable writable = arrowWritableRecordBatch.get(0).get(0); Writable writable = arrowWritableRecordBatch.get(0).get(0);
assertTrue(writable instanceof NDArrayWritable); assertTrue(writable instanceof NDArrayWritable);
NDArrayWritable ndArrayWritable = (NDArrayWritable) writable; NDArrayWritable ndArrayWritable = (NDArrayWritable) writable;
assertEquals(arr,ndArrayWritable.get()); assertEquals(arr, ndArrayWritable.get());
Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray); Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray);
NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1; NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1;
System.out.println(ndArrayWritablewritable1.get()); System.out.println(ndArrayWritablewritable1.get());
} }
@Test @Test
public void testArrowColumnString() { @DisplayName("Test Arrow Column String")
void testArrowColumnString() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>(); List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i)); schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single);
List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build()); List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
List<List<Writable>> assertion = new ArrayList<>(); List<List<Writable>> assertion = new ArrayList<>();
assertion.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1))); assertion.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)));
assertEquals(assertion,records); assertEquals(assertion, records);
List<List<String>> batch = new ArrayList<>(); List<List<String>> batch = new ArrayList<>();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
batch.add(Arrays.asList(String.valueOf(i),String.valueOf(i))); batch.add(Arrays.asList(String.valueOf(i), String.valueOf(i)));
} }
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch); List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build()); List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
List<List<Writable>> assertionBatch = new ArrayList<>(); List<List<Writable>> assertionBatch = new ArrayList<>();
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(0))); assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0)));
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1),new IntWritable(1))); assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1)));
assertEquals(assertionBatch,batchRecords); assertEquals(assertionBatch, batchRecords);
} }
@Test @Test
public void testArrowBatchSetTime() { @DisplayName("Test Arrow Batch Set Time")
void testArrowBatchSetTime() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>(); List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schema.addColumnTime(String.valueOf(i),TimeZone.getDefault()); schema.addColumnTime(String.valueOf(i), TimeZone.getDefault());
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
List<List<Writable>> input = Arrays.asList( List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
Arrays.<Writable>asList(new LongWritable(0),new LongWritable(1)), ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
Arrays.<Writable>asList(new LongWritable(2),new LongWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)); List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4),new LongWritable(5))); writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1); List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest); assertEquals(assertion, recordTest);
} }
@Test @Test
public void testArrowBatchSet() { @DisplayName("Test Arrow Batch Set")
void testArrowBatchSet() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>(); List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i)); schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
List<List<Writable>> input = Arrays.asList( List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)), ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
Arrays.<Writable>asList(new IntWritable(2),new IntWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)); List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5))); writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1); List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest); assertEquals(assertion, recordTest);
} }
@Test @Test
public void testArrowColumnsStringTimeSeries() { @DisplayName("Test Arrow Columns String Time Series")
void testArrowColumnsStringTimeSeries() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>(); List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i)); schema.addColumnInteger(String.valueOf(i));
} }
for (int i = 0; i < 5; i++) {
for(int i = 0; i < 5; i++) {
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr); entries.add(arr);
} }
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
assertEquals(3,fieldVectors.size()); assertEquals(3, fieldVectors.size());
assertEquals(5,fieldVectors.get(0).getValueCount()); assertEquals(5, fieldVectors.get(0).getValueCount());
INDArray exp = Nd4j.create(5, 3); INDArray exp = Nd4j.create(5, 3);
for( int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
exp.getRow(i).assign(i); exp.getRow(i).assign(i);
} }
//Convert to ArrowWritableRecordBatch - note we can't do this in general with time series... // Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build()); ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
INDArray arr = ArrowConverter.toArray(wri); INDArray arr = ArrowConverter.toArray(wri);
assertArrayEquals(new long[] {5,3}, arr.shape()); assertArrayEquals(new long[] { 5, 3 }, arr.shape());
assertEquals(exp, arr); assertEquals(exp, arr);
} }
@Test @Test
public void testConvertVector() { @DisplayName("Test Convert Vector")
void testConvertVector() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>(); List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i)); schema.addColumnInteger(String.valueOf(i));
} }
for (int i = 0; i < 5; i++) {
for(int i = 0; i < 5; i++) {
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr); entries.add(arr);
} }
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0),schema.build().getType(0)); INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0), schema.build().getType(0));
assertEquals(5,arr.length()); assertEquals(5, arr.length());
} }
@Test @Test
public void testCreateNDArray() throws Exception { @DisplayName("Test Create ND Array")
void testCreateNDArray() throws Exception {
val recordsToWrite = recordToWrite(); val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream); ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
File f = testDir.toFile();
File f = testDir.newFolder();
File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrorw"); File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrorw");
FileOutputStream outputStream = new FileOutputStream(tmpFile); FileOutputStream outputStream = new FileOutputStream(tmpFile);
tmpFile.deleteOnExit(); tmpFile.deleteOnExit();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),outputStream); ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), outputStream);
outputStream.flush(); outputStream.flush();
outputStream.close(); outputStream.close();
Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile); Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile);
assertEquals(recordsToWrite.getFirst(),schemaArrowWritableRecordBatchPair.getFirst()); assertEquals(recordsToWrite.getFirst(), schemaArrowWritableRecordBatchPair.getFirst());
assertEquals(recordsToWrite.getRight(),schemaArrowWritableRecordBatchPair.getRight().toArrayList()); assertEquals(recordsToWrite.getRight(), schemaArrowWritableRecordBatchPair.getRight().toArrayList());
byte[] arr = byteArrayOutputStream.toByteArray(); byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr); val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read); assertEquals(recordsToWrite, read);
// send file
//send file
File tmp = tmpDataFile(recordsToWrite); File tmp = tmpDataFile(recordsToWrite);
ArrowRecordReader recordReader = new ArrowRecordReader(); ArrowRecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp)); recordReader.initialize(new FileSplit(tmp));
recordReader.next(); recordReader.next();
ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch(); ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch();
INDArray arr2 = ArrowConverter.toArray(currentBatch); INDArray arr2 = ArrowConverter.toArray(currentBatch);
assertEquals(2,arr2.rows()); assertEquals(2, arr2.rows());
assertEquals(2,arr2.columns()); assertEquals(2, arr2.columns());
}
@Test
public void testConvertToArrowVectors() {
INDArray matrix = Nd4j.linspace(1,4,4).reshape(2,2);
val vectors = ArrowConverter.convertToArrowVector(matrix,Arrays.asList("test","test2"), ColumnType.Double,bufferAllocator);
assertEquals(matrix.rows(),vectors.size());
INDArray vector = Nd4j.linspace(1,4,4);
val vectors2 = ArrowConverter.convertToArrowVector(vector,Arrays.asList("test"), ColumnType.Double,bufferAllocator);
assertEquals(1,vectors2.size());
assertEquals(matrix.length(),vectors2.get(0).getValueCount());
} }
@Test @Test
public void testSchemaConversionBasic() { @DisplayName("Test Convert To Arrow Vectors")
void testConvertToArrowVectors() {
INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
val vectors = ArrowConverter.convertToArrowVector(matrix, Arrays.asList("test", "test2"), ColumnType.Double, bufferAllocator);
assertEquals(matrix.rows(), vectors.size());
INDArray vector = Nd4j.linspace(1, 4, 4);
val vectors2 = ArrowConverter.convertToArrowVector(vector, Arrays.asList("test"), ColumnType.Double, bufferAllocator);
assertEquals(1, vectors2.size());
assertEquals(matrix.length(), vectors2.get(0).getValueCount());
}
@Test
@DisplayName("Test Schema Conversion Basic")
void testSchemaConversionBasic() {
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schemaBuilder.addColumnDouble("test-" + i); schemaBuilder.addColumnDouble("test-" + i);
schemaBuilder.addColumnInteger("testi-" + i); schemaBuilder.addColumnInteger("testi-" + i);
schemaBuilder.addColumnLong("testl-" + i); schemaBuilder.addColumnLong("testl-" + i);
schemaBuilder.addColumnFloat("testf-" + i); schemaBuilder.addColumnFloat("testf-" + i);
} }
Schema schema = schemaBuilder.build(); Schema schema = schemaBuilder.build();
val schema2 = ArrowConverter.toArrowSchema(schema); val schema2 = ArrowConverter.toArrowSchema(schema);
assertEquals(8,schema2.getFields().size()); assertEquals(8, schema2.getFields().size());
val convertedSchema = ArrowConverter.toDatavecSchema(schema2); val convertedSchema = ArrowConverter.toDatavecSchema(schema2);
assertEquals(schema,convertedSchema); assertEquals(schema, convertedSchema);
} }
@Test @Test
public void testReadSchemaAndRecordsFromByteArray() throws Exception { @DisplayName("Test Read Schema And Records From Byte Array")
void testReadSchemaAndRecordsFromByteArray() throws Exception {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
int valueCount = 3; int valueCount = 3;
List<Field> fields = new ArrayList<>(); List<Field> fields = new ArrayList<>();
fields.add(ArrowConverter.field("field1",new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))); fields.add(ArrowConverter.field("field1", new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
fields.add(ArrowConverter.intField("field2")); fields.add(ArrowConverter.intField("field2"));
List<FieldVector> fieldVectors = new ArrayList<>(); List<FieldVector> fieldVectors = new ArrayList<>();
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field1",new float[] {1,2,3})); fieldVectors.add(ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 }));
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field2",new int[] {1,2,3})); fieldVectors.add(ArrowConverter.vectorFor(allocator, "field2", new int[] { 1, 2, 3 }));
org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields); org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields);
VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount); VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount);
VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1); VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
vectorUnloader.getRecordBatch(); vectorUnloader.getRecordBatch();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) { try (ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1, null, newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch(); arrowFileWriter.writeBatch();
} catch (IOException e) { } catch (IOException e) {
log.error("",e); log.error("", e);
} }
byte[] arr = byteArrayOutputStream.toByteArray(); byte[] arr = byteArrayOutputStream.toByteArray();
val arr2 = ArrowConverter.readFromBytes(arr); val arr2 = ArrowConverter.readFromBytes(arr);
assertEquals(2,arr2.getFirst().numColumns()); assertEquals(2, arr2.getFirst().numColumns());
assertEquals(3,arr2.getRight().size()); assertEquals(3, arr2.getRight().size());
val arrowCols = ArrowConverter.toArrowColumns(allocator, arr2.getFirst(), arr2.getRight());
val arrowCols = ArrowConverter.toArrowColumns(allocator,arr2.getFirst(),arr2.getRight()); assertEquals(2, arrowCols.size());
assertEquals(2,arrowCols.size()); assertEquals(valueCount, arrowCols.get(0).getValueCount());
assertEquals(valueCount,arrowCols.get(0).getValueCount());
} }
@Test @Test
public void testVectorForEdgeCases() { @DisplayName("Test Vector For Edge Cases")
void testVectorForEdgeCases() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{Float.MIN_VALUE,Float.MAX_VALUE}); val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { Float.MIN_VALUE, Float.MAX_VALUE });
assertEquals(Float.MIN_VALUE,vector.get(0),1e-2); assertEquals(Float.MIN_VALUE, vector.get(0), 1e-2);
assertEquals(Float.MAX_VALUE,vector.get(1),1e-2); assertEquals(Float.MAX_VALUE, vector.get(1), 1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { Integer.MIN_VALUE, Integer.MAX_VALUE });
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{Integer.MIN_VALUE,Integer.MAX_VALUE}); assertEquals(Integer.MIN_VALUE, vectorInt.get(0), 1e-2);
assertEquals(Integer.MIN_VALUE,vectorInt.get(0),1e-2); assertEquals(Integer.MAX_VALUE, vectorInt.get(1), 1e-2);
assertEquals(Integer.MAX_VALUE,vectorInt.get(1),1e-2);
} }
@Test @Test
public void testVectorFor() { @DisplayName("Test Vector For")
void testVectorFor() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 });
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{1,2,3}); assertEquals(3, vector.getValueCount());
assertEquals(3,vector.getValueCount()); assertEquals(1, vector.get(0), 1e-2);
assertEquals(1,vector.get(0),1e-2); assertEquals(2, vector.get(1), 1e-2);
assertEquals(2,vector.get(1),1e-2); assertEquals(3, vector.get(2), 1e-2);
assertEquals(3,vector.get(2),1e-2); val vectorLong = ArrowConverter.vectorFor(allocator, "field1", new long[] { 1, 2, 3 });
assertEquals(3, vectorLong.getValueCount());
val vectorLong = ArrowConverter.vectorFor(allocator,"field1",new long[]{1,2,3}); assertEquals(1, vectorLong.get(0), 1e-2);
assertEquals(3,vectorLong.getValueCount()); assertEquals(2, vectorLong.get(1), 1e-2);
assertEquals(1,vectorLong.get(0),1e-2); assertEquals(3, vectorLong.get(2), 1e-2);
assertEquals(2,vectorLong.get(1),1e-2); val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { 1, 2, 3 });
assertEquals(3,vectorLong.get(2),1e-2); assertEquals(3, vectorInt.getValueCount());
assertEquals(1, vectorInt.get(0), 1e-2);
assertEquals(2, vectorInt.get(1), 1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{1,2,3}); assertEquals(3, vectorInt.get(2), 1e-2);
assertEquals(3,vectorInt.getValueCount()); val vectorDouble = ArrowConverter.vectorFor(allocator, "field1", new double[] { 1, 2, 3 });
assertEquals(1,vectorInt.get(0),1e-2); assertEquals(3, vectorDouble.getValueCount());
assertEquals(2,vectorInt.get(1),1e-2); assertEquals(1, vectorDouble.get(0), 1e-2);
assertEquals(3,vectorInt.get(2),1e-2); assertEquals(2, vectorDouble.get(1), 1e-2);
assertEquals(3, vectorDouble.get(2), 1e-2);
val vectorDouble = ArrowConverter.vectorFor(allocator,"field1",new double[]{1,2,3}); val vectorBool = ArrowConverter.vectorFor(allocator, "field1", new boolean[] { true, true, false });
assertEquals(3,vectorDouble.getValueCount()); assertEquals(3, vectorBool.getValueCount());
assertEquals(1,vectorDouble.get(0),1e-2); assertEquals(1, vectorBool.get(0), 1e-2);
assertEquals(2,vectorDouble.get(1),1e-2); assertEquals(1, vectorBool.get(1), 1e-2);
assertEquals(3,vectorDouble.get(2),1e-2); assertEquals(0, vectorBool.get(2), 1e-2);
val vectorBool = ArrowConverter.vectorFor(allocator,"field1",new boolean[]{true,true,false});
assertEquals(3,vectorBool.getValueCount());
assertEquals(1,vectorBool.get(0),1e-2);
assertEquals(1,vectorBool.get(1),1e-2);
assertEquals(0,vectorBool.get(2),1e-2);
} }
@Test @Test
public void testRecordReaderAndWriteFile() throws Exception { @DisplayName("Test Record Reader And Write File")
void testRecordReaderAndWriteFile() throws Exception {
val recordsToWrite = recordToWrite(); val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream); ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
byte[] arr = byteArrayOutputStream.toByteArray(); byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr); val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read); assertEquals(recordsToWrite, read);
// send file
//send file
File tmp = tmpDataFile(recordsToWrite); File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader(); RecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp)); recordReader.initialize(new FileSplit(tmp));
List<Writable> record = recordReader.next(); List<Writable> record = recordReader.next();
assertEquals(2,record.size()); assertEquals(2, record.size());
} }
@Test @Test
public void testRecordReaderMetaDataList() throws Exception { @DisplayName("Test Record Reader Meta Data List")
void testRecordReaderMetaDataList() throws Exception {
val recordsToWrite = recordToWrite(); val recordsToWrite = recordToWrite();
//send file // send file
File tmp = tmpDataFile(recordsToWrite); File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader(); RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class); RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
recordReader.loadFromMetaData(Arrays.<RecordMetaData>asList(recordMetaDataIndex)); recordReader.loadFromMetaData(Arrays.<RecordMetaData>asList(recordMetaDataIndex));
Record record = recordReader.nextRecord(); Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size()); assertEquals(2, record.getRecord().size());
} }
@Test @Test
public void testDates() { @DisplayName("Test Dates")
void testDates() {
Date now = new Date(); Date now = new Date();
BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[]{now}); TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[] { now });
assertEquals(now.getTime(),timeStampMilliVector.get(0)); assertEquals(now.getTime(), timeStampMilliVector.get(0));
} }
@Test @Test
public void testRecordReaderMetaData() throws Exception { @DisplayName("Test Record Reader Meta Data")
void testRecordReaderMetaData() throws Exception {
val recordsToWrite = recordToWrite(); val recordsToWrite = recordToWrite();
//send file // send file
File tmp = tmpDataFile(recordsToWrite); File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader(); RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class); RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
recordReader.loadFromMetaData(recordMetaDataIndex); recordReader.loadFromMetaData(recordMetaDataIndex);
Record record = recordReader.nextRecord(); Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size()); assertEquals(2, record.getRecord().size());
} }
private File tmpDataFile(Pair<Schema,List<List<Writable>>> recordsToWrite) throws IOException { private File tmpDataFile(Pair<Schema, List<List<Writable>>> recordsToWrite) throws IOException {
File f = testDir.toFile();
File f = testDir.newFolder(); // send file
File tmp = new File(f, "tmp-file-" + UUID.randomUUID().toString());
//send file
File tmp = new File(f,"tmp-file-" + UUID.randomUUID().toString());
tmp.mkdirs(); tmp.mkdirs();
File tmpFile = new File(tmp,"data.arrow"); File tmpFile = new File(tmp, "data.arrow");
tmpFile.deleteOnExit(); tmpFile.deleteOnExit();
FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile); FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),bufferedOutputStream); ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), bufferedOutputStream);
bufferedOutputStream.flush(); bufferedOutputStream.flush();
bufferedOutputStream.close(); bufferedOutputStream.close();
return tmp; return tmp;
} }
private Pair<Schema,List<List<Writable>>> recordToWrite() { private Pair<Schema, List<List<Writable>>> recordToWrite() {
List<List<Writable>> records = new ArrayList<>(); List<List<Writable>> records = new ArrayList<>();
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schemaBuilder.addColumnFloat("col-" + i); schemaBuilder.addColumnFloat("col-" + i);
} }
return Pair.of(schemaBuilder.build(), records);
return Pair.of(schemaBuilder.build(),records);
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.arrow; package org.datavec.arrow;
import lombok.val; import lombok.val;
@ -34,132 +33,98 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowRecordWriter; import org.datavec.arrow.recordreader.ArrowRecordWriter;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.primitives.Triple; import org.nd4j.common.primitives.Triple;
import java.io.File; import java.io.File;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Record Mapper Test")
class RecordMapperTest extends BaseND4JTest {
public class RecordMapperTest extends BaseND4JTest {
@Test @Test
public void testMultiWrite() throws Exception { @DisplayName("Test Multi Write")
void testMultiWrite() throws Exception {
val recordsPair = records(); val recordsPair = records();
Path p = Files.createTempFile("arrowwritetest", ".arrow"); Path p = Files.createTempFile("arrowwritetest", ".arrow");
FileUtils.write(p.toFile(),recordsPair.getFirst()); FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
int numReaders = 2; int numReaders = 2;
RecordReader[] readers = new RecordReader[numReaders]; RecordReader[] readers = new RecordReader[numReaders];
InputSplit[] splits = new InputSplit[numReaders]; InputSplit[] splits = new InputSplit[numReaders];
for(int i = 0; i < readers.length; i++) { for (int i = 0; i < readers.length; i++) {
FileSplit split = new FileSplit(p.toFile()); FileSplit split = new FileSplit(p.toFile());
ArrowRecordReader arrowRecordReader = new ArrowRecordReader(); ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
readers[i] = arrowRecordReader; readers[i] = arrowRecordReader;
splits[i] = split; splits[i] = split;
} }
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle()); ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
FileSplit split = new FileSplit(p.toFile()); FileSplit split = new FileSplit(p.toFile());
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner()); arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
arrowRecordWriter.writeBatch(recordsPair.getRight()); arrowRecordWriter.writeBatch(recordsPair.getRight());
CSVRecordWriter csvRecordWriter = new CSVRecordWriter(); CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
Path p2 = Files.createTempFile("arrowwritetest", ".csv"); Path p2 = Files.createTempFile("arrowwritetest", ".csv");
FileUtils.write(p2.toFile(),recordsPair.getFirst()); FileUtils.write(p2.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
FileSplit outputCsv = new FileSplit(p2.toFile()); FileSplit outputCsv = new FileSplit(p2.toFile());
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers).splitPerReader(splits).recordWriter(csvRecordWriter).build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
.outputUrl(outputCsv)
.partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers)
.splitPerReader(splits)
.recordWriter(csvRecordWriter)
.build();
mapper.copy(); mapper.copy();
} }
@Test @Test
public void testCopyFromArrowToCsv() throws Exception { @DisplayName("Test Copy From Arrow To Csv")
void testCopyFromArrowToCsv() throws Exception {
val recordsPair = records(); val recordsPair = records();
Path p = Files.createTempFile("arrowwritetest", ".arrow"); Path p = Files.createTempFile("arrowwritetest", ".arrow");
FileUtils.write(p.toFile(),recordsPair.getFirst()); FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle()); ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
FileSplit split = new FileSplit(p.toFile()); FileSplit split = new FileSplit(p.toFile());
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner()); arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
arrowRecordWriter.writeBatch(recordsPair.getRight()); arrowRecordWriter.writeBatch(recordsPair.getRight());
ArrowRecordReader arrowRecordReader = new ArrowRecordReader(); ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
arrowRecordReader.initialize(split); arrowRecordReader.initialize(split);
CSVRecordWriter csvRecordWriter = new CSVRecordWriter(); CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
Path p2 = Files.createTempFile("arrowwritetest", ".csv"); Path p2 = Files.createTempFile("arrowwritetest", ".csv");
FileUtils.write(p2.toFile(),recordsPair.getFirst()); FileUtils.write(p2.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
FileSplit outputCsv = new FileSplit(p2.toFile()); FileSplit outputCsv = new FileSplit(p2.toFile());
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).recordReader(arrowRecordReader).recordWriter(csvRecordWriter).build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
.outputUrl(outputCsv)
.partitioner(new NumberOfRecordsPartitioner())
.recordReader(arrowRecordReader).recordWriter(csvRecordWriter)
.build();
mapper.copy(); mapper.copy();
CSVRecordReader recordReader = new CSVRecordReader(); CSVRecordReader recordReader = new CSVRecordReader();
recordReader.initialize(outputCsv); recordReader.initialize(outputCsv);
List<List<Writable>> loadedCSvRecords = recordReader.next(10); List<List<Writable>> loadedCSvRecords = recordReader.next(10);
assertEquals(10,loadedCSvRecords.size()); assertEquals(10, loadedCSvRecords.size());
} }
@Test @Test
public void testCopyFromCsvToArrow() throws Exception { @DisplayName("Test Copy From Csv To Arrow")
void testCopyFromCsvToArrow() throws Exception {
val recordsPair = records(); val recordsPair = records();
Path p = Files.createTempFile("csvwritetest", ".csv"); Path p = Files.createTempFile("csvwritetest", ".csv");
FileUtils.write(p.toFile(),recordsPair.getFirst()); FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
CSVRecordReader recordReader = new CSVRecordReader(); CSVRecordReader recordReader = new CSVRecordReader();
FileSplit fileSplit = new FileSplit(p.toFile()); FileSplit fileSplit = new FileSplit(p.toFile());
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle()); ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
File outputFile = Files.createTempFile("outputarrow","arrow").toFile(); File outputFile = Files.createTempFile("outputarrow", "arrow").toFile();
FileSplit outputFileSplit = new FileSplit(outputFile); FileSplit outputFileSplit = new FileSplit(outputFile);
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit) RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit).outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner()).recordReader(recordReader).recordWriter(arrowRecordWriter).build();
.outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner())
.recordReader(recordReader).recordWriter(arrowRecordWriter)
.build();
mapper.copy(); mapper.copy();
ArrowRecordReader arrowRecordReader = new ArrowRecordReader(); ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
arrowRecordReader.initialize(outputFileSplit); arrowRecordReader.initialize(outputFileSplit);
List<List<Writable>> next = arrowRecordReader.next(10); List<List<Writable>> next = arrowRecordReader.next(10);
System.out.println(next); System.out.println(next);
assertEquals(10,next.size()); assertEquals(10, next.size());
} }
private Triple<String,Schema,List<List<Writable>>> records() { private Triple<String, Schema, List<List<Writable>>> records() {
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
int numColumns = 3; int numColumns = 3;
@ -176,15 +141,10 @@ public class RecordMapperTest extends BaseND4JTest {
} }
list.add(temp); list.add(temp);
} }
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < numColumns; i++) { for (int i = 0; i < numColumns; i++) {
schemaBuilder.addColumnInteger(String.valueOf(i)); schemaBuilder.addColumnInteger(String.valueOf(i));
} }
return Triple.of(sb.toString(), schemaBuilder.build(), list);
return Triple.of(sb.toString(),schemaBuilder.build(),list);
} }
} }

View File

@ -29,16 +29,16 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter; import org.datavec.arrow.ArrowConverter;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@ -46,6 +46,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@Test @Test
@Disabled
public void testBasicIndexing() { public void testBasicIndexing() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
for(int i = 0; i < 3; i++) { for(int i = 0; i < 3; i++) {
@ -54,9 +55,9 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
List<List<Writable>> timeStep = Arrays.asList( List<List<Writable>> timeStep = Arrays.asList(
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)), Arrays.asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)),
Arrays.<Writable>asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)), Arrays.asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)),
Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5),new IntWritable(6)) Arrays.asList(new IntWritable(4),new IntWritable(5),new IntWritable(6))
); );
int numTimeSteps = 5; int numTimeSteps = 5;
@ -69,7 +70,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
assertEquals(3,fieldVectors.size()); assertEquals(3,fieldVectors.size());
for(FieldVector fieldVector : fieldVectors) { for(FieldVector fieldVector : fieldVectors) {
for(int i = 0; i < fieldVector.getValueCount(); i++) { for(int i = 0; i < fieldVector.getValueCount(); i++) {
assertFalse("Index " + i + " was null for field vector " + fieldVector, fieldVector.isNull(i)); assertFalse( fieldVector.isNull(i),"Index " + i + " was null for field vector " + fieldVector);
} }
} }
@ -79,7 +80,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@Test @Test
//not worried about this till after next release //not worried about this till after next release
@Ignore @Disabled
public void testVariableLengthTS() { public void testVariableLengthTS() {
Schema.Builder schema = new Schema.Builder() Schema.Builder schema = new Schema.Builder()
.addColumnString("str") .addColumnString("str")

View File

@ -119,10 +119,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -17,41 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.image; package org.datavec.image;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Label Generator Test")
import static org.junit.Assert.assertTrue; class LabelGeneratorTest {
public class LabelGeneratorTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testParentPathLabelGenerator() throws Exception { @DisplayName("Test Parent Path Label Generator")
//https://github.com/deeplearning4j/DataVec/issues/273 @Disabled
void testParentPathLabelGenerator(@TempDir Path testDir) throws Exception {
File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile(); File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
for (String dirPrefix : new String[] { "m.", "m" }) {
for(String dirPrefix : new String[]{"m.", "m"}) { File f = testDir.toFile();
File f = testDir.newFolder();
int numDirs = 3; int numDirs = 3;
int filesPerDir = 4; int filesPerDir = 4;
for (int i = 0; i < numDirs; i++) { for (int i = 0; i < numDirs; i++) {
File currentLabelDir = new File(f, dirPrefix + i); File currentLabelDir = new File(f, dirPrefix + i);
currentLabelDir.mkdirs(); currentLabelDir.mkdirs();
@ -61,14 +59,11 @@ public class LabelGeneratorTest {
assertTrue(f3.exists()); assertTrue(f3.exists());
} }
} }
ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator()); ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
rr.initialize(new FileSplit(f)); rr.initialize(new FileSplit(f));
List<String> labelsAct = rr.getLabels(); List<String> labelsAct = rr.getLabels();
List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2"); List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2");
assertEquals(labelsExp, labelsAct); assertEquals(labelsExp, labelsAct);
int expCount = numDirs * filesPerDir; int expCount = numDirs * filesPerDir;
int actCount = 0; int actCount = 0;
while (rr.hasNext()) { while (rr.hasNext()) {

View File

@ -22,8 +22,8 @@ package org.datavec.image.loader;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import java.io.File; import java.io.File;
@ -32,9 +32,9 @@ import java.io.InputStream;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* *
@ -182,7 +182,7 @@ public class LoaderTests {
} }
@Ignore // Use when confirming data is getting stored @Disabled // Use when confirming data is getting stored
@Test @Test
public void testProcessCifar() { public void testProcessCifar() {
int row = 32; int row = 32;
@ -208,15 +208,15 @@ public class LoaderTests {
int minibatch = 100; int minibatch = 100;
int nMinibatches = 50000 / minibatch; int nMinibatches = 50000 / minibatch;
for( int i=0; i<nMinibatches; i++ ){ for( int i=0; i < nMinibatches; i++) {
DataSet ds = loader.next(minibatch); DataSet ds = loader.next(minibatch);
String s = String.valueOf(i); String s = String.valueOf(i);
assertNotNull(s, ds.getFeatures()); assertNotNull(ds.getFeatures(),s);
assertNotNull(s, ds.getLabels()); assertNotNull(ds.getLabels(),s);
assertEquals(s, minibatch, ds.getFeatures().size(0)); assertEquals(minibatch, ds.getFeatures().size(0),s);
assertEquals(s, minibatch, ds.getLabels().size(0)); assertEquals(minibatch, ds.getLabels().size(0),s);
assertEquals(s, 10, ds.getLabels().size(1)); assertEquals(10, ds.getLabels().size(1),s);
} }
} }

View File

@ -21,7 +21,7 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -32,7 +32,7 @@ import java.io.FileInputStream;
import java.io.InputStream; import java.io.InputStream;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestImageLoader { public class TestImageLoader {

View File

@ -30,9 +30,10 @@ import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,16 +43,17 @@ import org.nd4j.common.io.ClassPathResource;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.*; import java.io.*;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.Random; import java.util.Random;
import org.bytedeco.leptonica.*; import org.bytedeco.leptonica.*;
import org.bytedeco.opencv.opencv_core.*; import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.leptonica.global.lept.*; import static org.bytedeco.leptonica.global.lept.*;
import static org.bytedeco.opencv.global.opencv_core.*; import static org.bytedeco.opencv.global.opencv_core.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
/** /**
* *
@ -62,8 +64,6 @@ public class TestNativeImageLoader {
static final long seed = 10; static final long seed = 10;
static final Random rng = new Random(seed); static final Random rng = new Random(seed);
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testConvertPix() throws Exception { public void testConvertPix() throws Exception {
@ -566,8 +566,8 @@ public class TestNativeImageLoader {
@Test @Test
public void testNativeImageLoaderEmptyStreams() throws Exception { public void testNativeImageLoaderEmptyStreams(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder(); File dir = testDir.toFile();
File f = new File(dir, "myFile.jpg"); File f = new File(dir, "myFile.jpg");
f.createNewFile(); f.createNewFile();
@ -578,7 +578,7 @@ public class TestNativeImageLoader {
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg.contains("decode image"),msg);
} }
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
@ -586,7 +586,7 @@ public class TestNativeImageLoader {
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg.contains("decode image"),msg);
} }
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
@ -594,7 +594,7 @@ public class TestNativeImageLoader {
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg.contains("decode image"),msg);
} }
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
@ -603,7 +603,7 @@ public class TestNativeImageLoader {
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue( msg.contains("decode image"),msg);
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.image.recordreader; package org.datavec.image.recordreader;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -28,61 +27,56 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.loader.NativeImageLoader;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.loader.FileBatch; import org.nd4j.common.loader.FileBatch;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*; @DisplayName("File Batch Record Reader Test")
class FileBatchRecordReaderTest {
public class FileBatchRecordReaderTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testCsv() throws Exception { @DisplayName("Test Csv")
File extractedSourceDir = testDir.newFolder(); void testCsv(@TempDir Path testDir,@TempDir Path baseDirPath) throws Exception {
File extractedSourceDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir); new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir);
File baseDir = testDir.newFolder(); File baseDir = baseDirPath.toFile();
List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true)); List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true));
assertEquals(6, c.size()); assertEquals(6, c.size());
Collections.sort(c, new Comparator<File>() { Collections.sort(c, new Comparator<File>() {
@Override @Override
public int compare(File o1, File o2) { public int compare(File o1, File o2) {
return o1.getPath().compareTo(o2.getPath()); return o1.getPath().compareTo(o2.getPath());
} }
}); });
FileBatch fb = FileBatch.forFiles(c); FileBatch fb = FileBatch.forFiles(c);
File saveFile = new File(baseDir, "saved.zip"); File saveFile = new File(baseDir, "saved.zip");
fb.writeAsZip(saveFile); fb.writeAsZip(saveFile);
fb = FileBatch.readFromZip(saveFile); fb = FileBatch.readFromZip(saveFile);
PathLabelGenerator labelMaker = new ParentPathLabelGenerator(); PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker); ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker);
rr.setLabels(Arrays.asList("class0", "class1")); rr.setLabels(Arrays.asList("class0", "class1"));
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb); FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
NativeImageLoader il = new NativeImageLoader(32, 32, 1); NativeImageLoader il = new NativeImageLoader(32, 32, 1);
for( int test=0; test<3; test++) { for (int test = 0; test < 3; test++) {
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
assertTrue(fbrr.hasNext()); assertTrue(fbrr.hasNext());
List<Writable> next = fbrr.next(); List<Writable> next = fbrr.next();
assertEquals(2, next.size()); assertEquals(2, next.size());
INDArray exp; INDArray exp;
switch (i){ switch(i) {
case 0: case 0:
exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg")); exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg"));
break; break;
@ -105,8 +99,7 @@ public class FileBatchRecordReaderTest {
throw new RuntimeException(); throw new RuntimeException();
} }
Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1)); Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1));
assertEquals(((NDArrayWritable) next.get(0)).get(), exp);
assertEquals(((NDArrayWritable)next.get(0)).get(), exp);
assertEquals(expLabel, next.get(1)); assertEquals(expLabel, next.get(1));
} }
assertFalse(fbrr.hasNext()); assertFalse(fbrr.hasNext());
@ -114,5 +107,4 @@ public class FileBatchRecordReaderTest {
fbrr.reset(); fbrr.reset();
} }
} }
} }

View File

@ -36,9 +36,10 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -46,28 +47,30 @@ import org.nd4j.common.io.ClassPathResource;
import java.io.*; import java.io.*;
import java.net.URI; import java.net.URI;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestImageRecordReader { public class TestImageRecordReader {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test(expected = IllegalArgumentException.class) @Test()
public void testEmptySplit() throws IOException { public void testEmptySplit() throws IOException {
InputSplit data = new CollectionInputSplit(new ArrayList<URI>()); assertThrows(IllegalArgumentException.class,() -> {
InputSplit data = new CollectionInputSplit(new ArrayList<>());
new ImageRecordReader().initialize(data, null); new ImageRecordReader().initialize(data, null);
});
} }
@Test @Test
public void testMetaData() throws IOException { public void testMetaData(@TempDir Path testDir) throws IOException {
File parentDir = testDir.newFolder(); File parentDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir);
// System.out.println(f.getAbsolutePath()); // System.out.println(f.getAbsolutePath());
// System.out.println(f.getParentFile().getParentFile().getAbsolutePath()); // System.out.println(f.getParentFile().getParentFile().getAbsolutePath());
@ -104,11 +107,11 @@ public class TestImageRecordReader {
} }
@Test @Test
public void testImageRecordReaderLabelsOrder() throws Exception { public void testImageRecordReaderLabelsOrder(@TempDir Path testDir) throws Exception {
//Labels order should be consistent, regardless of file iteration order //Labels order should be consistent, regardless of file iteration order
//Idea: labels order should be consistent regardless of input file order //Idea: labels order should be consistent regardless of input file order
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
File f0 = new File(f, "/class0/0.jpg"); File f0 = new File(f, "/class0/0.jpg");
File f1 = new File(f, "/class1/A.jpg"); File f1 = new File(f, "/class1/A.jpg");
@ -135,11 +138,11 @@ public class TestImageRecordReader {
@Test @Test
public void testImageRecordReaderRandomization() throws Exception { public void testImageRecordReaderRandomization(@TempDir Path testDir) throws Exception {
//Order of FileSplit+ImageRecordReader should be different after reset //Order of FileSplit+ImageRecordReader should be different after reset
//Idea: labels order should be consistent regardless of input file order //Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder(); File f0 = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs = new FileSplit(f0, new Random(12345)); FileSplit fs = new FileSplit(f0, new Random(12345));
@ -189,13 +192,13 @@ public class TestImageRecordReader {
@Test @Test
public void testImageRecordReaderRegression() throws Exception { public void testImageRecordReaderRegression(@TempDir Path testDir) throws Exception {
PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen(); PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen();
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen); ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen);
File rootDir = testDir.newFolder(); File rootDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
FileSplit fs = new FileSplit(rootDir); FileSplit fs = new FileSplit(rootDir);
rr.initialize(fs); rr.initialize(fs);
@ -244,10 +247,10 @@ public class TestImageRecordReader {
} }
@Test @Test
public void testListenerInvocationBatch() throws IOException { public void testListenerInvocationBatch(@TempDir Path testDir) throws IOException {
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker); ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
File parent = f; File parent = f;
@ -260,10 +263,10 @@ public class TestImageRecordReader {
} }
@Test @Test
public void testListenerInvocationSingle() throws IOException { public void testListenerInvocationSingle(@TempDir Path testDir) throws IOException {
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker); ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
File parent = testDir.newFolder(); File parent = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent); new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent);
int numFiles = parent.list().length; int numFiles = parent.list().length;
rr.initialize(new FileSplit(parent)); rr.initialize(new FileSplit(parent));
@ -315,7 +318,7 @@ public class TestImageRecordReader {
@Test @Test
public void testImageRecordReaderPathMultiLabelGenerator() throws Exception { public void testImageRecordReaderPathMultiLabelGenerator(@TempDir Path testDir) throws Exception {
Nd4j.setDataType(DataType.FLOAT); Nd4j.setDataType(DataType.FLOAT);
//Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively //Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively
// PLUS single value (Writable) regression label // PLUS single value (Writable) regression label
@ -324,7 +327,7 @@ public class TestImageRecordReader {
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen); ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen);
File rootDir = testDir.newFolder(); File rootDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
FileSplit fs = new FileSplit(rootDir); FileSplit fs = new FileSplit(rootDir);
rr.initialize(fs); rr.initialize(fs);
@ -471,9 +474,9 @@ public class TestImageRecordReader {
@Test @Test
public void testNCHW_NCHW() throws Exception { public void testNCHW_NCHW(@TempDir Path testDir) throws Exception {
//Idea: labels order should be consistent regardless of input file order //Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder(); File f0 = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs0 = new FileSplit(f0, new Random(12345)); FileSplit fs0 = new FileSplit(f0, new Random(12345));

View File

@ -35,9 +35,10 @@ import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform; import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform; import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform; import org.datavec.image.transform.ResizeImageTransform;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.BooleanIndexing;
@ -46,24 +47,24 @@ import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.net.URI; import java.net.URI;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestObjectDetectionRecordReader { public class TestObjectDetectionRecordReader {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void test() throws Exception { public void test(@TempDir Path testDir) throws Exception {
for(boolean nchw : new boolean[]{true, false}) { for(boolean nchw : new boolean[]{true, false}) {
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider(); ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f); new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
String path = new File(f, "000012.jpg").getParent(); String path = new File(f, "000012.jpg").getParent();

View File

@ -21,27 +21,27 @@
package org.datavec.image.recordreader.objdetect; package org.datavec.image.recordreader.objdetect;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider; import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.nio.file.Path;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestVocLabelProvider { public class TestVocLabelProvider {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testVocLabelProvider() throws Exception { public void testVocLabelProvider(@TempDir Path testDir) throws Exception {
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f); new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f);
String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent(); String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent();

View File

@ -17,106 +17,70 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.image.transform; package org.datavec.image.transform;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Json Yaml Test")
import static org.junit.Assert.assertTrue; class JsonYamlTest {
public class JsonYamlTest {
@Test @Test
public void testJsonYamlImageTransformProcess() throws IOException { @DisplayName("Test Json Yaml Image Transform Process")
void testJsonYamlImageTransformProcess() throws IOException {
int seed = 12345; int seed = 12345;
Random random = new Random(seed); Random random = new Random(seed);
// from org.bytedeco.javacpp.opencv_imgproc
//from org.bytedeco.javacpp.opencv_imgproc
int COLOR_BGR2Luv = 50; int COLOR_BGR2Luv = 50;
int CV_BGR2GRAY = 6; int CV_BGR2GRAY = 6;
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv).cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0).resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3).warpImageTransform((float) 0.5).build();
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv)
.cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0)
.resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3)
.warpImageTransform((float) 0.5)
// Note : since randomCropTransform use random value
// the results from each case(json, yaml, ImageTransformProcess)
// can be different
// don't use the below line
// if you uncomment it, you will get fail from below assertions
// .randomCropTransform(seed, 50, 50)
// Note : you will get "java.lang.NoClassDefFoundError: Could not initialize class org.bytedeco.javacpp.avutil"
// it needs to add the below dependency
// <dependency>
// <groupId>org.bytedeco</groupId>
// <artifactId>ffmpeg-platform</artifactId>
// </dependency>
// FFmpeg has license issues, be careful to use it
//.filterImageTransform("noise=alls=20:allf=t+u,format=rgba", 100, 100, 4)
.build();
String asJson = itp.toJson(); String asJson = itp.toJson();
String asYaml = itp.toYaml(); String asYaml = itp.toYaml();
// System.out.println(asJson);
// System.out.println(asJson); // System.out.println("\n\n\n");
// System.out.println("\n\n\n"); // System.out.println(asYaml);
// System.out.println(asYaml);
ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3); ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
ImageWritable imgJson = new ImageWritable(img.getFrame().clone()); ImageWritable imgJson = new ImageWritable(img.getFrame().clone());
ImageWritable imgYaml = new ImageWritable(img.getFrame().clone()); ImageWritable imgYaml = new ImageWritable(img.getFrame().clone());
ImageWritable imgAll = new ImageWritable(img.getFrame().clone()); ImageWritable imgAll = new ImageWritable(img.getFrame().clone());
ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson); ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson);
ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml); ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml);
List<ImageTransform> transformList = itp.getTransformList(); List<ImageTransform> transformList = itp.getTransformList();
List<ImageTransform> transformListJson = itpFromJson.getTransformList(); List<ImageTransform> transformListJson = itpFromJson.getTransformList();
List<ImageTransform> transformListYaml = itpFromYaml.getTransformList(); List<ImageTransform> transformListYaml = itpFromYaml.getTransformList();
for (int i = 0; i < transformList.size(); i++) { for (int i = 0; i < transformList.size(); i++) {
ImageTransform it = transformList.get(i); ImageTransform it = transformList.get(i);
ImageTransform itJson = transformListJson.get(i); ImageTransform itJson = transformListJson.get(i);
ImageTransform itYaml = transformListYaml.get(i); ImageTransform itYaml = transformListYaml.get(i);
System.out.println(i + "\t" + it); System.out.println(i + "\t" + it);
img = it.transform(img); img = it.transform(img);
imgJson = itJson.transform(imgJson); imgJson = itJson.transform(imgJson);
imgYaml = itYaml.transform(imgYaml); imgYaml = itYaml.transform(imgYaml);
if (it instanceof RandomCropTransform) { if (it instanceof RandomCropTransform) {
assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight); assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight);
assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth); assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth);
assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight); assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight);
assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth); assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth);
} else if (it instanceof FilterImageTransform) { } else if (it instanceof FilterImageTransform) {
assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight); assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight);
assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth); assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth);
assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels); assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels);
assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight); assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight);
assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth); assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth);
assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels); assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels);
} else { } else {
assertEquals(img, imgJson); assertEquals(img, imgJson);
assertEquals(img, imgYaml); assertEquals(img, imgYaml);
} }
} }
imgAll = itp.execute(imgAll); imgAll = itp.execute(imgAll);
assertEquals(imgAll, img); assertEquals(imgAll, img);
} }
} }

View File

@ -17,56 +17,50 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.image.transform; package org.datavec.image.transform;
import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Frame;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Resize Image Transform Test")
class ResizeImageTransformTest {
public class ResizeImageTransformTest {
@Before
public void setUp() throws Exception {
@BeforeEach
void setUp() throws Exception {
} }
@Test @Test
public void testResizeUpscale1() throws Exception { @DisplayName("Test Resize Upscale 1")
void testResizeUpscale1() throws Exception {
ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3); ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3);
ResizeImageTransform transform = new ResizeImageTransform(200, 200); ResizeImageTransform transform = new ResizeImageTransform(200, 200);
ImageWritable dstImg = transform.transform(srcImg); ImageWritable dstImg = transform.transform(srcImg);
Frame f = dstImg.getFrame(); Frame f = dstImg.getFrame();
assertEquals(f.imageWidth, 200); assertEquals(f.imageWidth, 200);
assertEquals(f.imageHeight, 200); assertEquals(f.imageHeight, 200);
float[] coordinates = { 100, 200 };
float[] coordinates = {100, 200};
float[] transformed = transform.query(coordinates); float[] transformed = transform.query(coordinates);
assertEquals(200f * 100 / 32, transformed[0], 0); assertEquals(200f * 100 / 32, transformed[0], 0);
assertEquals(200f * 200 / 32, transformed[1], 0); assertEquals(200f * 200 / 32, transformed[1], 0);
} }
@Test @Test
public void testResizeDownscale() throws Exception { @DisplayName("Test Resize Downscale")
void testResizeDownscale() throws Exception {
ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3); ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3);
ResizeImageTransform transform = new ResizeImageTransform(200, 200); ResizeImageTransform transform = new ResizeImageTransform(200, 200);
ImageWritable dstImg = transform.transform(srcImg); ImageWritable dstImg = transform.transform(srcImg);
Frame f = dstImg.getFrame(); Frame f = dstImg.getFrame();
assertEquals(f.imageWidth, 200); assertEquals(f.imageWidth, 200);
assertEquals(f.imageHeight, 200); assertEquals(f.imageHeight, 200);
float[] coordinates = { 300, 400 };
float[] coordinates = {300, 400};
float[] transformed = transform.query(coordinates); float[] transformed = transform.query(coordinates);
assertEquals(200f * 300 / 443, transformed[0], 0); assertEquals(200f * 300 / 443, transformed[0], 0);
assertEquals(200f * 400 / 571, transformed[1], 0); assertEquals(200f * 400 / 571, transformed[1], 0);
} }
} }

View File

@ -28,8 +28,8 @@ import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.loader.NativeImageLoader;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.awt.*; import java.awt.*;
import java.util.LinkedList; import java.util.LinkedList;
@ -40,7 +40,7 @@ import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_core.*; import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*; import static org.bytedeco.opencv.global.opencv_imgproc.*;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
/** /**
* *
@ -255,7 +255,7 @@ public class TestImageTransform {
assertEquals(22, transformed[1], 0); assertEquals(22, transformed[1], 0);
} }
@Ignore @Disabled
@Test @Test
public void testFilterImageTransform() throws Exception { public void testFilterImageTransform() throws Exception {
ImageWritable writable = makeRandomImage(0, 0, 4); ImageWritable writable = makeRandomImage(0, 0, 4);

View File

@ -59,10 +59,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -57,10 +57,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -17,37 +17,34 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.poi.excel; package org.datavec.poi.excel;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Excel Record Reader Test")
import static org.junit.Assert.assertTrue; class ExcelRecordReaderTest {
public class ExcelRecordReaderTest {
@Test @Test
public void testSimple() throws Exception { @DisplayName("Test Simple")
void testSimple() throws Exception {
RecordReader excel = new ExcelRecordReader(); RecordReader excel = new ExcelRecordReader();
excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile())); excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile()));
assertTrue(excel.hasNext()); assertTrue(excel.hasNext());
List<Writable> next = excel.next(); List<Writable> next = excel.next();
assertEquals(3,next.size()); assertEquals(3, next.size());
RecordReader headerReader = new ExcelRecordReader(1); RecordReader headerReader = new ExcelRecordReader(1);
headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile())); headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile()));
assertTrue(excel.hasNext()); assertTrue(excel.hasNext());
List<Writable> next2 = excel.next(); List<Writable> next2 = excel.next();
assertEquals(3,next2.size()); assertEquals(3, next2.size());
} }
} }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.poi.excel; package org.datavec.poi.excel;
import lombok.val; import lombok.val;
@ -26,44 +25,45 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.primitives.Triple;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.primitives.Triple;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Excel Record Writer Test")
class ExcelRecordWriterTest {
public class ExcelRecordWriterTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testWriter() throws Exception { @DisplayName("Test Writer")
void testWriter() throws Exception {
ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter(); ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter();
val records = records(); val records = records();
File tmpDir = testDir.newFolder(); File tmpDir = testDir.toFile();
File outputFile = new File(tmpDir,"testexcel.xlsx"); File outputFile = new File(tmpDir, "testexcel.xlsx");
outputFile.deleteOnExit(); outputFile.deleteOnExit();
FileSplit fileSplit = new FileSplit(outputFile); FileSplit fileSplit = new FileSplit(outputFile);
excelRecordWriter.initialize(fileSplit,new NumberOfRecordsPartitioner()); excelRecordWriter.initialize(fileSplit, new NumberOfRecordsPartitioner());
excelRecordWriter.writeBatch(records.getRight()); excelRecordWriter.writeBatch(records.getRight());
excelRecordWriter.close(); excelRecordWriter.close();
File parentFile = outputFile.getParentFile(); File parentFile = outputFile.getParentFile();
assertEquals(1,parentFile.list().length); assertEquals(1, parentFile.list().length);
ExcelRecordReader excelRecordReader = new ExcelRecordReader(); ExcelRecordReader excelRecordReader = new ExcelRecordReader();
excelRecordReader.initialize(fileSplit); excelRecordReader.initialize(fileSplit);
List<List<Writable>> next = excelRecordReader.next(10); List<List<Writable>> next = excelRecordReader.next(10);
assertEquals(10,next.size()); assertEquals(10, next.size());
} }
private Triple<String,Schema,List<List<Writable>>> records() { private Triple<String, Schema, List<List<Writable>>> records() {
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
int numColumns = 3; int numColumns = 3;
@ -80,13 +80,10 @@ public class ExcelRecordWriterTest {
} }
list.add(temp); list.add(temp);
} }
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < numColumns; i++) { for (int i = 0; i < numColumns; i++) {
schemaBuilder.addColumnInteger(String.valueOf(i)); schemaBuilder.addColumnInteger(String.valueOf(i));
} }
return Triple.of(sb.toString(), schemaBuilder.build(), list);
return Triple.of(sb.toString(),schemaBuilder.build(),list);
} }
} }

View File

@ -65,10 +65,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -17,14 +17,12 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.File; import java.io.File;
import java.net.URI; import java.net.URI;
import java.sql.Connection; import java.sql.Connection;
@ -49,53 +47,57 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.After; import org.junit.jupiter.api.AfterEach;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
public class JDBCRecordReaderTest { import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
@Rule @DisplayName("Jdbc Record Reader Test")
public TemporaryFolder testDir = new TemporaryFolder(); class JDBCRecordReaderTest {
@TempDir
public Path testDir;
Connection conn; Connection conn;
EmbeddedDataSource dataSource; EmbeddedDataSource dataSource;
private final String dbName = "datavecTests"; private final String dbName = "datavecTests";
private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver"; private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver";
@Before @BeforeEach
public void setUp() throws Exception { void setUp() throws Exception {
File f = testDir.newFolder(); File f = testDir.toFile();
System.setProperty("derby.system.home", f.getAbsolutePath()); System.setProperty("derby.system.home", f.getAbsolutePath());
dataSource = new EmbeddedDataSource(); dataSource = new EmbeddedDataSource();
dataSource.setDatabaseName(dbName); dataSource.setDatabaseName(dbName);
dataSource.setCreateDatabase("create"); dataSource.setCreateDatabase("create");
conn = dataSource.getConnection(); conn = dataSource.getConnection();
TestDb.dropTables(conn); TestDb.dropTables(conn);
TestDb.buildCoffeeTable(conn); TestDb.buildCoffeeTable(conn);
} }
@After @AfterEach
public void tearDown() throws Exception { void tearDown() throws Exception {
DbUtils.closeQuietly(conn); DbUtils.closeQuietly(conn);
} }
@Test @Test
public void testSimpleIter() throws Exception { @DisplayName("Test Simple Iter")
void testSimpleIter() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
List<List<Writable>> records = new ArrayList<>(); List<List<Writable>> records = new ArrayList<>();
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> values = reader.next(); List<Writable> values = reader.next();
records.add(values); records.add(values);
} }
assertFalse(records.isEmpty()); assertFalse(records.isEmpty());
List<Writable> first = records.get(0); List<Writable> first = records.get(0);
assertEquals(new Text("Bolivian Dark"), first.get(0)); assertEquals(new Text("Bolivian Dark"), first.get(0));
assertEquals(new Text("14-001"), first.get(1)); assertEquals(new Text("14-001"), first.get(1));
@ -104,39 +106,43 @@ public class JDBCRecordReaderTest {
} }
@Test @Test
public void testSimpleWithListener() throws Exception { @DisplayName("Test Simple With Listener")
void testSimpleWithListener() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordListener recordListener = new LogRecordListener(); RecordListener recordListener = new LogRecordListener();
reader.setListeners(recordListener); reader.setListeners(recordListener);
reader.next(); reader.next();
assertTrue(recordListener.invoked()); assertTrue(recordListener.invoked());
} }
} }
@Test @Test
public void testReset() throws Exception { @DisplayName("Test Reset")
void testReset() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
List<List<Writable>> records = new ArrayList<>(); List<List<Writable>> records = new ArrayList<>();
records.add(reader.next()); records.add(reader.next());
reader.reset(); reader.reset();
records.add(reader.next()); records.add(reader.next());
assertEquals(2, records.size()); assertEquals(2, records.size());
assertEquals(new Text("Bolivian Dark"), records.get(0).get(0)); assertEquals(new Text("Bolivian Dark"), records.get(0).get(0));
assertEquals(new Text("Bolivian Dark"), records.get(1).get(0)); assertEquals(new Text("Bolivian Dark"), records.get(1).get(0));
} }
} }
@Test(expected = IllegalStateException.class) @Test
public void testLackingDataSourceShouldFail() throws Exception { @DisplayName("Test Lacking Data Source Should Fail")
void testLackingDataSourceShouldFail() {
assertThrows(IllegalStateException.class, () -> {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
reader.initialize(null); reader.initialize(null);
} }
});
} }
@Test @Test
public void testConfigurationDataSourceInitialization() throws Exception { @DisplayName("Test Configuration Data Source Initialization")
void testConfigurationDataSourceInitialization() throws Exception {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
Configuration conf = new Configuration(); Configuration conf = new Configuration();
conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true"); conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true");
@ -146,28 +152,33 @@ public class JDBCRecordReaderTest {
} }
} }
@Test(expected = IllegalArgumentException.class) @Test
public void testInitConfigurationMissingParametersShouldFail() throws Exception { @DisplayName("Test Init Configuration Missing Parameters Should Fail")
void testInitConfigurationMissingParametersShouldFail() {
assertThrows(IllegalArgumentException.class, () -> {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
Configuration conf = new Configuration(); Configuration conf = new Configuration();
conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway"); conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway");
reader.initialize(conf, null); reader.initialize(conf, null);
} }
} });
@Test(expected = UnsupportedOperationException.class)
public void testRecordDataInputStreamShouldFail() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
reader.record(null, null);
}
} }
@Test @Test
public void testLoadFromMetaData() throws Exception { @DisplayName("Test Record Data Input Stream Should Fail")
void testRecordDataInputStreamShouldFail() {
assertThrows(UnsupportedOperationException.class, () -> {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), reader.record(null, null);
"SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass()); }
});
}
@Test
@DisplayName("Test Load From Meta Data")
void testLoadFromMetaData() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), "SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass());
Record res = reader.loadFromMetaData(rmd); Record res = reader.loadFromMetaData(rmd);
assertNotNull(res); assertNotNull(res);
assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0)); assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0));
@ -177,7 +188,8 @@ public class JDBCRecordReaderTest {
} }
@Test @Test
public void testNextRecord() throws Exception { @DisplayName("Test Next Record")
void testNextRecord() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
Record r = reader.nextRecord(); Record r = reader.nextRecord();
List<Writable> fields = r.getRecord(); List<Writable> fields = r.getRecord();
@ -193,7 +205,8 @@ public class JDBCRecordReaderTest {
} }
@Test @Test
public void testNextRecordAndRecover() throws Exception { @DisplayName("Test Next Record And Recover")
void testNextRecordAndRecover() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
Record r = reader.nextRecord(); Record r = reader.nextRecord();
List<Writable> fields = r.getRecord(); List<Writable> fields = r.getRecord();
@ -208,8 +221,10 @@ public class JDBCRecordReaderTest {
} }
// Resetting the record reader when initialized as forward only should fail // Resetting the record reader when initialized as forward only should fail
@Test(expected = RuntimeException.class) @Test
public void testResetForwardOnlyShouldFail() throws Exception { @DisplayName("Test Reset Forward Only Should Fail")
void testResetForwardOnlyShouldFail() {
assertThrows(RuntimeException.class, () -> {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) { try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) {
Configuration conf = new Configuration(); Configuration conf = new Configuration();
conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY); conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY);
@ -217,58 +232,78 @@ public class JDBCRecordReaderTest {
reader.next(); reader.next();
reader.reset(); reader.reset();
} }
});
} }
@Test @Test
public void testReadAllTypes() throws Exception { @DisplayName("Test Read All Types")
void testReadAllTypes() throws Exception {
TestDb.buildAllTypesTable(conn); TestDb.buildAllTypesTable(conn);
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM AllTypes", dataSource)) { try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM AllTypes", dataSource)) {
reader.initialize(null); reader.initialize(null);
List<Writable> item = reader.next(); List<Writable> item = reader.next();
assertEquals(item.size(), 15); assertEquals(item.size(), 15);
assertEquals(BooleanWritable.class, item.get(0).getClass()); // boolean to boolean // boolean to boolean
assertEquals(Text.class, item.get(1).getClass()); // date to text assertEquals(BooleanWritable.class, item.get(0).getClass());
assertEquals(Text.class, item.get(2).getClass()); // time to text // date to text
assertEquals(Text.class, item.get(3).getClass()); // timestamp to text assertEquals(Text.class, item.get(1).getClass());
assertEquals(Text.class, item.get(4).getClass()); // char to text // time to text
assertEquals(Text.class, item.get(5).getClass()); // long varchar to text assertEquals(Text.class, item.get(2).getClass());
assertEquals(Text.class, item.get(6).getClass()); // varchar to text // timestamp to text
assertEquals(DoubleWritable.class, assertEquals(Text.class, item.get(3).getClass());
item.get(7).getClass()); // float to double (derby's float is an alias of double by default) // char to text
assertEquals(FloatWritable.class, item.get(8).getClass()); // real to float assertEquals(Text.class, item.get(4).getClass());
assertEquals(DoubleWritable.class, item.get(9).getClass()); // decimal to double // long varchar to text
assertEquals(DoubleWritable.class, item.get(10).getClass()); // numeric to double assertEquals(Text.class, item.get(5).getClass());
assertEquals(DoubleWritable.class, item.get(11).getClass()); // double to double // varchar to text
assertEquals(IntWritable.class, item.get(12).getClass()); // integer to integer assertEquals(Text.class, item.get(6).getClass());
assertEquals(IntWritable.class, item.get(13).getClass()); // small int to integer assertEquals(DoubleWritable.class, // float to double (derby's float is an alias of double by default)
assertEquals(LongWritable.class, item.get(14).getClass()); // bigint to long item.get(7).getClass());
// real to float
assertEquals(FloatWritable.class, item.get(8).getClass());
// decimal to double
assertEquals(DoubleWritable.class, item.get(9).getClass());
// numeric to double
assertEquals(DoubleWritable.class, item.get(10).getClass());
// double to double
assertEquals(DoubleWritable.class, item.get(11).getClass());
// integer to integer
assertEquals(IntWritable.class, item.get(12).getClass());
// small int to integer
assertEquals(IntWritable.class, item.get(13).getClass());
// bigint to long
assertEquals(LongWritable.class, item.get(14).getClass());
} }
} }
@Test(expected = RuntimeException.class) @Test
public void testNextNoMoreShouldFail() throws Exception { @DisplayName("Test Next No More Should Fail")
void testNextNoMoreShouldFail() {
assertThrows(RuntimeException.class, () -> {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
while (reader.hasNext()) { while (reader.hasNext()) {
reader.next(); reader.next();
} }
reader.next(); reader.next();
} }
});
} }
@Test(expected = IllegalArgumentException.class) @Test
public void testInvalidMetadataShouldFail() throws Exception { @DisplayName("Test Invalid Metadata Should Fail")
void testInvalidMetadataShouldFail() {
assertThrows(IllegalArgumentException.class, () -> {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class); RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class);
reader.loadFromMetaData(md); reader.loadFromMetaData(md);
} }
});
} }
private JDBCRecordReader getInitializedReader(String query) throws Exception { private JDBCRecordReader getInitializedReader(String query) throws Exception {
int[] indices = {1}; // ProdNum column // ProdNum column
JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?", int[] indices = { 1 };
indices); JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?", indices);
reader.setTrimStrings(true); reader.setTrimStrings(true);
reader.initialize(null); reader.initialize(null);
return reader; return reader;

View File

@ -61,25 +61,18 @@
<artifactId>nd4j-common</artifactId> <artifactId>nd4j-common</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.nd4j</groupId>
<artifactId>datavec-geo</artifactId> <artifactId>python4j-numpy</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-python</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -36,14 +36,14 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class LocalTransformProcessRecordReaderTests { public class LocalTransformProcessRecordReaderTests {

View File

@ -29,9 +29,9 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.AnalyzeLocal; import org.datavec.local.transforms.AnalyzeLocal;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -39,12 +39,11 @@ import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestAnalyzeLocal { public class TestAnalyzeLocal {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testAnalysisBasic() throws Exception { public void testAnalysisBasic() throws Exception {
@ -72,7 +71,7 @@ public class TestAnalyzeLocal {
INDArray mean = arr.mean(0); INDArray mean = arr.mean(0);
INDArray std = arr.std(0); INDArray std = arr.std(0);
for( int i=0; i<5; i++ ){ for( int i = 0; i < 5; i++) {
double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean(); double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev(); double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
assertEquals(mean.getDouble(i), m, 1e-3); assertEquals(mean.getDouble(i), m, 1e-3);

View File

@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -36,8 +36,8 @@ import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLineRecordReaderFunction { public class TestLineRecordReaderFunction {

View File

@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction; import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +33,7 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayToWritablesFunction { public class TestNDArrayToWritablesFunction {

View File

@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction; import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToNDArrayFunction { public class TestWritablesToNDArrayFunction {

View File

@ -30,12 +30,12 @@ import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction; import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
import org.datavec.local.transforms.misc.WritablesToStringFunction; import org.datavec.local.transforms.misc.WritablesToStringFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToStringFunctions { public class TestWritablesToStringFunctions {

View File

@ -17,10 +17,8 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.local.transforms.transform; package org.datavec.local.transforms.transform;
import org.datavec.api.transform.MathFunction; import org.datavec.api.transform.MathFunction;
import org.datavec.api.transform.MathOp; import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.ReduceOp; import org.datavec.api.transform.ReduceOp;
@ -31,108 +29,85 @@ import org.datavec.api.transform.reduce.Reducer;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.python.PythonTransform;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;
import static org.junit.Assert.assertEquals; @DisplayName("Execution Test")
class ExecutionTest {
public class ExecutionTest {
@Test @Test
public void testExecutionNdarray() { @DisplayName("Test Execution Ndarray")
Schema schema = new Schema.Builder() void testExecutionNdarray() {
.addColumnNDArray("first",new long[]{1,32577}) Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
.addColumnNDArray("second",new long[]{1,32577}).build(); TransformProcess transformProcess = new TransformProcess.Builder(schema).ndArrayMathFunctionTransform("first", MathFunction.SIN).ndArrayMathFunctionTransform("second", MathFunction.COS).build();
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.ndArrayMathFunctionTransform("first", MathFunction.SIN)
.ndArrayMathFunctionTransform("second",MathFunction.COS)
.build();
List<List<Writable>> functions = new ArrayList<>(); List<List<Writable>> functions = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>(); List<Writable> firstRow = new ArrayList<>();
INDArray firstArr = Nd4j.linspace(1,4,4); INDArray firstArr = Nd4j.linspace(1, 4, 4);
INDArray secondArr = Nd4j.linspace(1,4,4); INDArray secondArr = Nd4j.linspace(1, 4, 4);
firstRow.add(new NDArrayWritable(firstArr)); firstRow.add(new NDArrayWritable(firstArr));
firstRow.add(new NDArrayWritable(secondArr)); firstRow.add(new NDArrayWritable(secondArr));
functions.add(firstRow); functions.add(firstRow);
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess); List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get(); INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get(); INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
INDArray expected = Transforms.sin(firstArr); INDArray expected = Transforms.sin(firstArr);
INDArray secondExpected = Transforms.cos(secondArr); INDArray secondExpected = Transforms.cos(secondArr);
assertEquals(expected,firstResult); assertEquals(expected, firstResult);
assertEquals(secondExpected,secondResult); assertEquals(secondExpected, secondResult);
} }
@Test @Test
public void testExecutionSimple() { @DisplayName("Test Execution Simple")
Schema schema = new Schema.Builder().addColumnInteger("col0") void testExecutionSimple() {
.addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2"). Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build();
addColumnFloat("col3").build(); TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
.doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
List<List<Writable>> inputData = new ArrayList<>(); List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f))); inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f))); inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f))); inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
List<List<Writable>> rdd = (inputData); List<List<Writable>> rdd = (inputData);
List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp)); List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
Collections.sort(out, new Comparator<List<Writable>>() { Collections.sort(out, new Comparator<List<Writable>>() {
@Override @Override
public int compare(List<Writable> o1, List<Writable> o2) { public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
} }
}); });
List<List<Writable>> expected = new ArrayList<>(); List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f))); expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f))); expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f))); expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f)));
assertEquals(expected, out); assertEquals(expected, out);
} }
@Test @Test
public void testFilter() { @DisplayName("Test Filter")
Schema filterSchema = new Schema.Builder() void testFilter() {
.addColumnDouble("col1").addColumnDouble("col2") Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build();
.addColumnDouble("col3").build();
List<List<Writable>> inputData = new ArrayList<>(); List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema) TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build();
.filter(new DoubleColumnCondition("col1",ConditionOp.LessThan,1)).build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess); List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
assertEquals(2,execute.size()); assertEquals(2, execute.size());
} }
@Test @Test
public void testExecutionSequence() { @DisplayName("Test Execution Sequence")
void testExecutionSequence() {
Schema schema = new SequenceSchema.Builder().addColumnInteger("col0") Schema schema = new SequenceSchema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
.addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
.doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<List<Writable>>> inputSequences = new ArrayList<>(); List<List<List<Writable>>> inputSequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>(); List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
@ -141,21 +116,17 @@ public class ExecutionTest {
List<List<Writable>> seq2 = new ArrayList<>(); List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
inputSequences.add(seq1); inputSequences.add(seq1);
inputSequences.add(seq2); inputSequences.add(seq2);
List<List<List<Writable>>> rdd = (inputSequences); List<List<List<Writable>>> rdd = (inputSequences);
List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp); List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
Collections.sort(out, new Comparator<List<List<Writable>>>() { Collections.sort(out, new Comparator<List<List<Writable>>>() {
@Override @Override
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) { public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
return -Integer.compare(o1.size(), o2.size()); return -Integer.compare(o1.size(), o2.size());
} }
}); });
List<List<List<Writable>>> expectedSequence = new ArrayList<>(); List<List<List<Writable>>> expectedSequence = new ArrayList<>();
List<List<Writable>> seq1e = new ArrayList<>(); List<List<Writable>> seq1e = new ArrayList<>();
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
@ -164,121 +135,37 @@ public class ExecutionTest {
List<List<Writable>> seq2e = new ArrayList<>(); List<List<Writable>> seq2e = new ArrayList<>();
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
expectedSequence.add(seq1e); expectedSequence.add(seq1e);
expectedSequence.add(seq2e); expectedSequence.add(seq2e);
assertEquals(expectedSequence, out); assertEquals(expectedSequence, out);
} }
@Test @Test
public void testReductionGlobal() { @DisplayName("Test Reduction Global")
void testReductionGlobal() {
List<List<Writable>> in = Arrays.asList( List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)),
Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))
);
List<List<Writable>> inData = in; List<List<Writable>> inData = in;
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
Schema s = new Schema.Builder() TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
.addColumnString("textCol")
.addColumnDouble("doubleCol")
.build();
TransformProcess tp = new TransformProcess.Builder(s)
.reduce(new Reducer.Builder(ReduceOp.TakeFirst)
.takeFirstColumns("textCol")
.meanColumns("doubleCol").build())
.build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp); List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd; List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0))); List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
assertEquals(expOut, out); assertEquals(expOut, out);
} }
@Test @Test
public void testReductionByKey(){ @DisplayName("Test Reduction By Key")
void testReductionByKey() {
List<List<Writable>> in = Arrays.asList( List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)),
Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)),
Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)),
Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))
);
List<List<Writable>> inData = in; List<List<Writable>> inData = in;
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
Schema s = new Schema.Builder() TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
.addColumnInteger("intCol")
.addColumnString("textCol")
.addColumnDouble("doubleCol")
.build();
TransformProcess tp = new TransformProcess.Builder(s)
.reduce(new Reducer.Builder(ReduceOp.TakeFirst)
.keyColumns("intCol")
.takeFirstColumns("textCol")
.meanColumns("doubleCol").build())
.build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp); List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd; List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
List<List<Writable>> expOut = Arrays.asList(
Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)),
Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out); out = new ArrayList<>(out);
Collections.sort( Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
}
);
assertEquals(expOut, out); assertEquals(expOut, out);
} }
@Test(timeout = 60000L)
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
public void testPythonExecutionNdarray()throws Exception{
Schema schema = new Schema.Builder()
.addColumnNDArray("first",new long[]{1,32577})
.addColumnNDArray("second",new long[]{1,32577}).build();
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.transform(
PythonTransform.builder().code(
"first = np.sin(first)\nsecond = np.cos(second)")
.outputSchema(schema).build())
.build();
List<List<Writable>> functions = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
INDArray firstArr = Nd4j.linspace(1,4,4);
INDArray secondArr = Nd4j.linspace(1,4,4);
firstRow.add(new NDArrayWritable(firstArr));
firstRow.add(new NDArrayWritable(secondArr));
functions.add(firstRow);
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
INDArray expected = Transforms.sin(firstArr);
INDArray secondExpected = Transforms.cos(secondArr);
assertEquals(expected,firstResult);
assertEquals(secondExpected,secondResult);
}
} }

View File

@ -1,154 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.geo.LocationType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
/**
* @author saudet
*/
public class TestGeoTransforms {
@BeforeClass
public static void beforeClass() throws Exception {
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
}
@AfterClass
public static void afterClass(){
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
}
@Test
public void testCoordinatesDistanceTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
.build();
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(4, out.numColumns());
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
out.getColumnTypes());
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
new DoubleWritable(Math.sqrt(160))),
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
new Text("10|5"))));
}
@Test
public void testIPAddressToCoordinatesTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
String in = "81.2.69.160";
double latitude = 51.5142;
double longitude = -0.0931;
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(transform);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ObjectInputStream ois = new ObjectInputStream(bais);
Transform deserialized = (Transform) ois.readObject();
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
//System.out.println(Arrays.toString(coordinates));
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
}
@Test
public void testIPAddressToLocationTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
LocationType[] locationTypes = LocationType.values();
String in = "81.2.69.160";
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
for (int i = 0; i < locationTypes.length; i++) {
LocationType locationType = locationTypes[i];
String location = locations[i];
Transform transform = new IPAddressToLocationTransform("column", locationType);
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
assertEquals(location, writables.get(0).toString());
//System.out.println(location);
}
}
}

View File

@ -1,379 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.schema.Schema;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.api.writable.*;
import org.datavec.python.PythonCondition;
import org.datavec.python.PythonTransform;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.datavec.api.transform.schema.Schema.Builder;
import static org.junit.Assert.*;
@NotThreadSafe
public class TestPythonTransformProcess {
@Test()
public void testStringConcat() throws Exception{
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnString("col1")
.addColumnString("col2");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnString("col3");
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
List<Writable> outputs = tp.execute(inputs);
assertEquals((outputs.get(0)).toString(), "Hello ");
assertEquals((outputs.get(1)).toString(), "World!");
assertEquals((outputs.get(2)).toString(), "Hello World!");
}
@Test(timeout = 60000L)
public void testMixedTypes() throws Exception{
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
.addColumnString("col3")
.addColumnDouble("col4");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnInteger("col5");
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.inputSchema(initialSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
new FloatWritable(3.5f),
new Text("5"),
new DoubleWritable(2.0)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
}
@Test(timeout = 60000L)
public void testNDArray() throws Exception{
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
INDArray expectedOutput = arr1.add(arr2);
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test(timeout = 60000L)
public void testNDArray2() throws Exception{
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
INDArray expectedOutput = arr1.add(arr2);
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test(timeout = 60000L)
public void testNDArrayMixed() throws Exception{
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test(timeout = 60000L)
public void testPythonFilter() {
Schema schema = new Builder().addColumnInteger("column").build();
Condition condition = new PythonCondition(
"f = lambda: column < 0"
);
condition.setInputSchema(schema);
Filter filter = new ConditionFilter(condition);
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
}
@Test(timeout = 60000L)
public void testPythonFilterAndTransform() throws Exception{
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
.addColumnString("col3")
.addColumnDouble("col4");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnString("col6");
Schema finalSchema = schemaBuilder.build();
Condition condition = new PythonCondition(
"f = lambda: col1 < 0 and col2 > 10.0"
);
condition.setInputSchema(initialSchema);
Filter filter = new ConditionFilter(condition);
String pythonCode = "col6 = str(col1 + col2)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).filter(
filter
).build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(-3),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(11.2f),
new Text("abcd"),
new DoubleWritable(2.1))
);
LocalTransformExecutor.execute(inputs,tp);
}
@Test
public void testPythonTransformNoOutputSpecified() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
Schema inputSchema = new Builder()
.addColumnInteger("a")
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertEquals(3,execute.get(0).get(0).toInt());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testNumpyTransform() {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertNotNull(execute.get(0).get(1));
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testWithSetupRun() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n")
.returnAllInputs(true)
.setupAndRun(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.addColumnNDArray("b", new long[]{1, 1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
}
}

View File

@ -28,11 +28,11 @@ import org.datavec.api.writable.*;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJoin { public class TestJoin {

View File

@ -31,13 +31,13 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCalculateSortedRank { public class TestCalculateSortedRank {

View File

@ -31,14 +31,14 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch; import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConvertToSequence { public class TestConvertToSequence {

View File

@ -41,6 +41,12 @@
</properties> </properties>
<dependencies> <dependencies>
<dependency>
<groupId>com.tdunning</groupId>
<artifactId>t-digest</artifactId>
<version>3.2</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.scala-lang</groupId> <groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId> <artifactId>scala-library</artifactId>
@ -122,10 +128,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

Some files were not shown because too many files have changed in this diff Show More