commit
c505a11ed6
|
@ -31,7 +31,7 @@ jobs:
|
|||
protoc --version
|
||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
windows-x86_64:
|
||||
runs-on: windows-2019
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
|||
run: |
|
||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
||||
|
||||
|
@ -60,5 +60,5 @@ jobs:
|
|||
run: |
|
||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ jobs:
|
|||
protoc --version
|
||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
windows-x86_64:
|
||||
runs-on: windows-2019
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
|||
run: |
|
||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
||||
|
||||
|
@ -60,5 +60,5 @@ jobs:
|
|||
run: |
|
||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
|
|
@ -1,11 +1,3 @@
|
|||
on:
|
||||
workflow_dispatch:
|
||||
jobs:
|
||||
# Wait for up to a minute for previous run to complete, abort if not done by then
|
||||
pre-ci:
|
||||
run
|
||||
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
jobs:
|
||||
|
@ -42,5 +34,5 @@ jobs:
|
|||
cmake --version
|
||||
protoc --version
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pdl4j-integration-tests -Pnd4j-tests-cpu clean test
|
||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pnd4j-tests-cpu clean test -rf :rl4j-core
|
||||
|
||||
|
|
|
@ -34,5 +34,5 @@ jobs:
|
|||
cmake --version
|
||||
protoc --version
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test
|
||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
<commons.dbutils.version>1.7</commons.dbutils.version>
|
||||
<lombok.version>1.18.8</lombok.version>
|
||||
<logback.version>1.1.7</logback.version>
|
||||
<junit.version>4.12</junit.version>
|
||||
<junit.version>5.8.0-M1</junit.version>
|
||||
<junit-jupiter.version>5.4.2</junit-jupiter.version>
|
||||
<java.version>1.8</java.version>
|
||||
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
|
||||
|
|
|
@ -17,13 +17,14 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.codegen.ir;
|
||||
|
||||
public class SerializationTest {
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
public static void main(String...args) {
|
||||
@DisplayName("Serialization Test")
|
||||
class SerializationTest {
|
||||
|
||||
public static void main(String... args) {
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,29 +17,23 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.codegen.dsl;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.codegen.impl.java.DocsGenerator;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
public class DocsGeneratorTest {
|
||||
@DisplayName("Docs Generator Test")
|
||||
class DocsGeneratorTest {
|
||||
|
||||
@Test
|
||||
public void testJDtoMDAdapter() {
|
||||
String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" +
|
||||
" eye:\n" +
|
||||
" [ 1, 0]\n" +
|
||||
" [ 0, 1]\n" +
|
||||
" [ 0, 0]}";
|
||||
String expected = "{ INDArray eye = eye(3,2)\n" +
|
||||
" eye:\n" +
|
||||
" [ 1, 0]\n" +
|
||||
" [ 0, 1]\n" +
|
||||
" [ 0, 0]}";
|
||||
@DisplayName("Test J Dto MD Adapter")
|
||||
void testJDtoMDAdapter() {
|
||||
String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
|
||||
String expected = "{ INDArray eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
|
||||
DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original);
|
||||
String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString();
|
||||
assertEquals(out, expected);
|
||||
|
|
|
@ -34,6 +34,14 @@
|
|||
<artifactId>datavec-api</artifactId>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-api</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.junit.vintage</groupId>
|
||||
<artifactId>junit-vintage-engine</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
|
@ -101,10 +109,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -26,47 +25,38 @@ import org.datavec.api.records.reader.impl.csv.CSVLineSequenceRecordReader;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import java.io.File;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
@DisplayName("Csv Line Sequence Record Reader Test")
|
||||
class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@TempDir
|
||||
public Path testDir;
|
||||
|
||||
@Test
|
||||
public void test() throws Exception {
|
||||
|
||||
File f = testDir.newFolder();
|
||||
@DisplayName("Test")
|
||||
void test(@TempDir Path testDir) throws Exception {
|
||||
File f = testDir.toFile();
|
||||
File source = new File(f, "temp.csv");
|
||||
String str = "a,b,c\n1,2,3,4";
|
||||
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
|
||||
|
||||
SequenceRecordReader rr = new CSVLineSequenceRecordReader();
|
||||
rr.initialize(new FileSplit(source));
|
||||
|
||||
List<List<Writable>> exp0 = Arrays.asList(
|
||||
Collections.<Writable>singletonList(new Text("a")),
|
||||
Collections.<Writable>singletonList(new Text("b")),
|
||||
Collections.<Writable>singletonList(new Text("c")));
|
||||
|
||||
List<List<Writable>> exp1 = Arrays.asList(
|
||||
Collections.<Writable>singletonList(new Text("1")),
|
||||
Collections.<Writable>singletonList(new Text("2")),
|
||||
Collections.<Writable>singletonList(new Text("3")),
|
||||
Collections.<Writable>singletonList(new Text("4")));
|
||||
|
||||
for( int i=0; i<3; i++ ) {
|
||||
List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
|
||||
List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
|
||||
for (int i = 0; i < 3; i++) {
|
||||
int count = 0;
|
||||
while (rr.hasNext()) {
|
||||
List<List<Writable>> next = rr.sequenceRecord();
|
||||
|
@ -76,9 +66,7 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
|||
assertEquals(exp1, next);
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(2, count);
|
||||
|
||||
rr.reset();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -26,33 +25,37 @@ import org.datavec.api.records.reader.impl.csv.CSVMultiSequenceRecordReader;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import java.io.File;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
@DisplayName("Csv Multi Sequence Record Reader Test")
|
||||
class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@TempDir
|
||||
public Path testDir;
|
||||
|
||||
@Test
|
||||
public void testConcatMode() throws Exception {
|
||||
for( int i=0; i<3; i++ ) {
|
||||
|
||||
@DisplayName("Test Concat Mode")
|
||||
@Disabled
|
||||
void testConcatMode() throws Exception {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
String seqSep;
|
||||
String seqSepRegex;
|
||||
switch (i){
|
||||
switch(i) {
|
||||
case 0:
|
||||
seqSep = "";
|
||||
seqSepRegex = "^$";
|
||||
|
@ -68,31 +71,23 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
|||
default:
|
||||
throw new RuntimeException();
|
||||
}
|
||||
|
||||
String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C";
|
||||
File f = testDir.newFile();
|
||||
File f = testDir.toFile();
|
||||
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
|
||||
|
||||
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT);
|
||||
seqRR.initialize(new FileSplit(f));
|
||||
|
||||
|
||||
List<List<Writable>> exp0 = new ArrayList<>();
|
||||
for (String s : "a,b,c,1,2,3,4,x,y".split(",")) {
|
||||
exp0.add(Collections.<Writable>singletonList(new Text(s)));
|
||||
}
|
||||
|
||||
List<List<Writable>> exp1 = new ArrayList<>();
|
||||
for (String s : "A,B,C".split(",")) {
|
||||
exp1.add(Collections.<Writable>singletonList(new Text(s)));
|
||||
}
|
||||
|
||||
assertEquals(exp0, seqRR.sequenceRecord());
|
||||
assertEquals(exp1, seqRR.sequenceRecord());
|
||||
assertFalse(seqRR.hasNext());
|
||||
|
||||
seqRR.reset();
|
||||
|
||||
assertEquals(exp0, seqRR.sequenceRecord());
|
||||
assertEquals(exp1, seqRR.sequenceRecord());
|
||||
assertFalse(seqRR.hasNext());
|
||||
|
@ -100,13 +95,13 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEqualLength() throws Exception {
|
||||
|
||||
for( int i=0; i<3; i++ ) {
|
||||
|
||||
@DisplayName("Test Equal Length")
|
||||
@Disabled
|
||||
void testEqualLength() throws Exception {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
String seqSep;
|
||||
String seqSepRegex;
|
||||
switch (i) {
|
||||
switch(i) {
|
||||
case 0:
|
||||
seqSep = "";
|
||||
seqSepRegex = "^$";
|
||||
|
@ -122,27 +117,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
|||
default:
|
||||
throw new RuntimeException();
|
||||
}
|
||||
|
||||
String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC";
|
||||
File f = testDir.newFile();
|
||||
File f = testDir.toFile();
|
||||
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
|
||||
|
||||
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH);
|
||||
seqRR.initialize(new FileSplit(f));
|
||||
|
||||
|
||||
List<List<Writable>> exp0 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
|
||||
Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
|
||||
|
||||
List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
|
||||
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
|
||||
|
||||
assertEquals(exp0, seqRR.sequenceRecord());
|
||||
assertEquals(exp1, seqRR.sequenceRecord());
|
||||
assertFalse(seqRR.hasNext());
|
||||
|
||||
seqRR.reset();
|
||||
|
||||
assertEquals(exp0, seqRR.sequenceRecord());
|
||||
assertEquals(exp1, seqRR.sequenceRecord());
|
||||
assertFalse(seqRR.hasNext());
|
||||
|
@ -150,13 +135,13 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testPadding() throws Exception {
|
||||
|
||||
for( int i=0; i<3; i++ ) {
|
||||
|
||||
@DisplayName("Test Padding")
|
||||
@Disabled
|
||||
void testPadding() throws Exception {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
String seqSep;
|
||||
String seqSepRegex;
|
||||
switch (i) {
|
||||
switch(i) {
|
||||
case 0:
|
||||
seqSep = "";
|
||||
seqSepRegex = "^$";
|
||||
|
@ -172,27 +157,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
|||
default:
|
||||
throw new RuntimeException();
|
||||
}
|
||||
|
||||
String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC";
|
||||
File f = testDir.newFile();
|
||||
File f = testDir.toFile();
|
||||
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
|
||||
|
||||
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD"));
|
||||
seqRR.initialize(new FileSplit(f));
|
||||
|
||||
|
||||
List<List<Writable>> exp0 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
|
||||
Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
|
||||
|
||||
List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
|
||||
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
|
||||
|
||||
assertEquals(exp0, seqRR.sequenceRecord());
|
||||
assertEquals(exp1, seqRR.sequenceRecord());
|
||||
assertFalse(seqRR.hasNext());
|
||||
|
||||
seqRR.reset();
|
||||
|
||||
assertEquals(exp0, seqRR.sequenceRecord());
|
||||
assertEquals(exp1, seqRR.sequenceRecord());
|
||||
assertFalse(seqRR.hasNext());
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.datavec.api.records.SequenceRecord;
|
||||
|
@ -27,61 +26,53 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
|
|||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
|
||||
@DisplayName("Csvn Lines Sequence Record Reader Test")
|
||||
class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testCSVNLinesSequenceRecordReader() throws Exception {
|
||||
@DisplayName("Test CSVN Lines Sequence Record Reader")
|
||||
void testCSVNLinesSequenceRecordReader() throws Exception {
|
||||
int nLinesPerSequence = 10;
|
||||
|
||||
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
|
||||
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
CSVRecordReader rr = new CSVRecordReader();
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
int count = 0;
|
||||
while (seqRR.hasNext()) {
|
||||
List<List<Writable>> next = seqRR.sequenceRecord();
|
||||
|
||||
List<List<Writable>> expected = new ArrayList<>();
|
||||
for (int i = 0; i < nLinesPerSequence; i++) {
|
||||
expected.add(rr.next());
|
||||
}
|
||||
|
||||
assertEquals(10, next.size());
|
||||
assertEquals(expected, next);
|
||||
|
||||
count++;
|
||||
}
|
||||
|
||||
assertEquals(150 / nLinesPerSequence, count);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
|
||||
@DisplayName("Test CSV Nlines Sequence Record Reader Meta Data")
|
||||
void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
|
||||
int nLinesPerSequence = 10;
|
||||
|
||||
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
|
||||
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
CSVRecordReader rr = new CSVRecordReader();
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
List<List<List<Writable>>> out = new ArrayList<>();
|
||||
while (seqRR.hasNext()) {
|
||||
List<List<Writable>> next = seqRR.sequenceRecord();
|
||||
out.add(next);
|
||||
}
|
||||
|
||||
seqRR.reset();
|
||||
List<List<List<Writable>>> out2 = new ArrayList<>();
|
||||
List<SequenceRecord> out3 = new ArrayList<>();
|
||||
|
@ -92,11 +83,8 @@ public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
|
|||
meta.add(seq.getMetaData());
|
||||
out3.add(seq);
|
||||
}
|
||||
|
||||
assertEquals(out, out2);
|
||||
|
||||
List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
|
||||
assertEquals(out3, out4);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -34,10 +33,10 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
|||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
|
@ -47,41 +46,44 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.NoSuchElementException;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
|
||||
@DisplayName("Csv Record Reader Test")
|
||||
class CSVRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
public class CSVRecordReaderTest extends BaseND4JTest {
|
||||
@Test
|
||||
public void testNext() throws Exception {
|
||||
@DisplayName("Test Next")
|
||||
void testNext() throws Exception {
|
||||
CSVRecordReader reader = new CSVRecordReader();
|
||||
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1"));
|
||||
while (reader.hasNext()) {
|
||||
List<Writable> vals = reader.next();
|
||||
List<Writable> arr = new ArrayList<>(vals);
|
||||
|
||||
assertEquals("Entry count", 23, vals.size());
|
||||
assertEquals(23, vals.size(), "Entry count");
|
||||
Text lastEntry = (Text) arr.get(arr.size() - 1);
|
||||
assertEquals("Last entry garbage", 1, lastEntry.getLength());
|
||||
assertEquals(1, lastEntry.getLength(), "Last entry garbage");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyEntries() throws Exception {
|
||||
@DisplayName("Test Empty Entries")
|
||||
void testEmptyEntries() throws Exception {
|
||||
CSVRecordReader reader = new CSVRecordReader();
|
||||
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,"));
|
||||
while (reader.hasNext()) {
|
||||
List<Writable> vals = reader.next();
|
||||
assertEquals("Entry count", 23, vals.size());
|
||||
assertEquals(23, vals.size(), "Entry count");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReset() throws Exception {
|
||||
@DisplayName("Test Reset")
|
||||
void testReset() throws Exception {
|
||||
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
int nResets = 5;
|
||||
for (int i = 0; i < nResets; i++) {
|
||||
|
||||
int lineCount = 0;
|
||||
while (rr.hasNext()) {
|
||||
List<Writable> line = rr.next();
|
||||
|
@ -95,7 +97,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testResetWithSkipLines() throws Exception {
|
||||
@DisplayName("Test Reset With Skip Lines")
|
||||
void testResetWithSkipLines() throws Exception {
|
||||
CSVRecordReader rr = new CSVRecordReader(10, ',');
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
int lineCount = 0;
|
||||
|
@ -114,7 +117,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testWrite() throws Exception {
|
||||
@DisplayName("Test Write")
|
||||
void testWrite() throws Exception {
|
||||
List<List<Writable>> list = new ArrayList<>();
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
|
@ -130,81 +134,72 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
list.add(temp);
|
||||
}
|
||||
|
||||
String expected = sb.toString();
|
||||
|
||||
Path p = Files.createTempFile("csvwritetest", "csv");
|
||||
p.toFile().deleteOnExit();
|
||||
|
||||
FileRecordWriter writer = new CSVRecordWriter();
|
||||
FileSplit fileSplit = new FileSplit(p.toFile());
|
||||
writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
|
||||
for (List<Writable> c : list) {
|
||||
writer.write(c);
|
||||
}
|
||||
writer.close();
|
||||
|
||||
//Read file back in; compare
|
||||
// Read file back in; compare
|
||||
String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name());
|
||||
|
||||
// System.out.println(expected);
|
||||
// System.out.println("----------");
|
||||
// System.out.println(fileContents);
|
||||
|
||||
assertEquals(expected, fileContents);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTabsAsSplit1() throws Exception {
|
||||
|
||||
@DisplayName("Test Tabs As Split 1")
|
||||
void testTabsAsSplit1() throws Exception {
|
||||
CSVRecordReader reader = new CSVRecordReader(0, '\t');
|
||||
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile()));
|
||||
while (reader.hasNext()) {
|
||||
List<Writable> list = new ArrayList<>(reader.next());
|
||||
|
||||
assertEquals(2, list.size());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPipesAsSplit() throws Exception {
|
||||
|
||||
@DisplayName("Test Pipes As Split")
|
||||
void testPipesAsSplit() throws Exception {
|
||||
CSVRecordReader reader = new CSVRecordReader(0, '|');
|
||||
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/issue414.csv").getFile()));
|
||||
int lineidx = 0;
|
||||
List<Integer> sixthColumn = Arrays.asList(13, 95, 15, 25);
|
||||
while (reader.hasNext()) {
|
||||
List<Writable> list = new ArrayList<>(reader.next());
|
||||
|
||||
assertEquals(10, list.size());
|
||||
assertEquals((long)sixthColumn.get(lineidx), list.get(5).toInt());
|
||||
assertEquals((long) sixthColumn.get(lineidx), list.get(5).toInt());
|
||||
lineidx++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testWithQuotes() throws Exception {
|
||||
@DisplayName("Test With Quotes")
|
||||
void testWithQuotes() throws Exception {
|
||||
CSVRecordReader reader = new CSVRecordReader(0, ',', '\"');
|
||||
reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\""));
|
||||
while (reader.hasNext()) {
|
||||
List<Writable> vals = reader.next();
|
||||
assertEquals("Entry count", 6, vals.size());
|
||||
assertEquals("1", vals.get(0).toString());
|
||||
assertEquals("0", vals.get(1).toString());
|
||||
assertEquals("3", vals.get(2).toString());
|
||||
assertEquals("Braund, Mr. Owen Harris", vals.get(3).toString());
|
||||
assertEquals("male", vals.get(4).toString());
|
||||
assertEquals("\"", vals.get(5).toString());
|
||||
assertEquals(6, vals.size(), "Entry count");
|
||||
assertEquals(vals.get(0).toString(), "1");
|
||||
assertEquals(vals.get(1).toString(), "0");
|
||||
assertEquals(vals.get(2).toString(), "3");
|
||||
assertEquals(vals.get(3).toString(), "Braund, Mr. Owen Harris");
|
||||
assertEquals(vals.get(4).toString(), "male");
|
||||
assertEquals(vals.get(5).toString(), "\"");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testMeta() throws Exception {
|
||||
@DisplayName("Test Meta")
|
||||
void testMeta() throws Exception {
|
||||
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
int lineCount = 0;
|
||||
List<RecordMetaData> metaList = new ArrayList<>();
|
||||
List<List<Writable>> writables = new ArrayList<>();
|
||||
|
@ -214,29 +209,24 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
|||
lineCount++;
|
||||
RecordMetaData meta = r.getMetaData();
|
||||
// System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
|
||||
|
||||
metaList.add(meta);
|
||||
writables.add(r.getRecord());
|
||||
}
|
||||
assertFalse(rr.hasNext());
|
||||
assertEquals(150, lineCount);
|
||||
rr.reset();
|
||||
|
||||
|
||||
System.out.println("\n\n\n--------------------------------");
|
||||
List<Record> contents = rr.loadFromMetaData(metaList);
|
||||
assertEquals(150, contents.size());
|
||||
// for(Record r : contents ){
|
||||
// System.out.println(r);
|
||||
// }
|
||||
|
||||
List<RecordMetaData> meta2 = new ArrayList<>();
|
||||
meta2.add(metaList.get(100));
|
||||
meta2.add(metaList.get(90));
|
||||
meta2.add(metaList.get(80));
|
||||
meta2.add(metaList.get(70));
|
||||
meta2.add(metaList.get(60));
|
||||
|
||||
List<Record> contents2 = rr.loadFromMetaData(meta2);
|
||||
assertEquals(writables.get(100), contents2.get(0).getRecord());
|
||||
assertEquals(writables.get(90), contents2.get(1).getRecord());
|
||||
|
@ -246,50 +236,49 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRegex() throws Exception {
|
||||
CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] {null, "(.+) (.+) (.+)"});
|
||||
@DisplayName("Test Regex")
|
||||
void testRegex() throws Exception {
|
||||
CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] { null, "(.+) (.+) (.+)" });
|
||||
reader.initialize(new StringSplit("normal,1.2.3.4 space separator"));
|
||||
while (reader.hasNext()) {
|
||||
List<Writable> vals = reader.next();
|
||||
assertEquals("Entry count", 4, vals.size());
|
||||
assertEquals("normal", vals.get(0).toString());
|
||||
assertEquals("1.2.3.4", vals.get(1).toString());
|
||||
assertEquals("space", vals.get(2).toString());
|
||||
assertEquals("separator", vals.get(3).toString());
|
||||
assertEquals(4, vals.size(), "Entry count");
|
||||
assertEquals(vals.get(0).toString(), "normal");
|
||||
assertEquals(vals.get(1).toString(), "1.2.3.4");
|
||||
assertEquals(vals.get(2).toString(), "space");
|
||||
assertEquals(vals.get(3).toString(), "separator");
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = NoSuchElementException.class)
|
||||
public void testCsvSkipAllLines() throws IOException, InterruptedException {
|
||||
@Test
|
||||
@DisplayName("Test Csv Skip All Lines")
|
||||
void testCsvSkipAllLines() {
|
||||
assertThrows(NoSuchElementException.class, () -> {
|
||||
final int numLines = 4;
|
||||
final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1),
|
||||
(Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
|
||||
final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
|
||||
String header = ",one,two,three";
|
||||
List<String> lines = new ArrayList<>();
|
||||
for (int i = 0; i < numLines; i++)
|
||||
lines.add(Integer.toString(i) + header);
|
||||
for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
|
||||
File tempFile = File.createTempFile("csvSkipLines", ".csv");
|
||||
FileUtils.writeLines(tempFile, lines);
|
||||
|
||||
CSVRecordReader rr = new CSVRecordReader(numLines, ',');
|
||||
rr.initialize(new FileSplit(tempFile));
|
||||
rr.reset();
|
||||
assertTrue(!rr.hasNext());
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Csv Skip All But One Line")
|
||||
void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
|
||||
final int numLines = 4;
|
||||
final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)),
|
||||
new Text("one"), new Text("two"), new Text("three"));
|
||||
final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)), new Text("one"), new Text("two"), new Text("three"));
|
||||
String header = ",one,two,three";
|
||||
List<String> lines = new ArrayList<>();
|
||||
for (int i = 0; i < numLines; i++)
|
||||
lines.add(Integer.toString(i) + header);
|
||||
for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
|
||||
File tempFile = File.createTempFile("csvSkipLines", ".csv");
|
||||
FileUtils.writeLines(tempFile, lines);
|
||||
|
||||
CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
|
||||
rr.initialize(new FileSplit(tempFile));
|
||||
rr.reset();
|
||||
|
@ -297,50 +286,45 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
|||
assertEquals(rr.next(), lineList);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testStreamReset() throws Exception {
|
||||
@DisplayName("Test Stream Reset")
|
||||
void testStreamReset() throws Exception {
|
||||
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
||||
rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));
|
||||
|
||||
int count = 0;
|
||||
while(rr.hasNext()){
|
||||
while (rr.hasNext()) {
|
||||
assertNotNull(rr.next());
|
||||
count++;
|
||||
}
|
||||
assertEquals(150, count);
|
||||
|
||||
assertFalse(rr.resetSupported());
|
||||
|
||||
try{
|
||||
try {
|
||||
rr.reset();
|
||||
fail("Expected exception");
|
||||
} catch (Exception e){
|
||||
} catch (Exception e) {
|
||||
String msg = e.getMessage();
|
||||
String msg2 = e.getCause().getMessage();
|
||||
assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
|
||||
assertTrue(msg2, msg2.contains("Reset not supported from streams"));
|
||||
// e.printStackTrace();
|
||||
assertTrue(msg.contains("Error during LineRecordReader reset"),msg);
|
||||
assertTrue(msg2.contains("Reset not supported from streams"),msg2);
|
||||
// e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUsefulExceptionNoInit(){
|
||||
|
||||
@DisplayName("Test Useful Exception No Init")
|
||||
void testUsefulExceptionNoInit() {
|
||||
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
||||
|
||||
try{
|
||||
try {
|
||||
rr.hasNext();
|
||||
fail("Expected exception");
|
||||
} catch (Exception e){
|
||||
assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
|
||||
} catch (Exception e) {
|
||||
assertTrue( e.getMessage().contains("initialized"),e.getMessage());
|
||||
}
|
||||
|
||||
try{
|
||||
try {
|
||||
rr.next();
|
||||
fail("Expected exception");
|
||||
} catch (Exception e){
|
||||
assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
|
||||
} catch (Exception e) {
|
||||
assertTrue(e.getMessage().contains("initialized"),e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.datavec.api.records.SequenceRecord;
|
||||
|
@ -27,12 +26,11 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
|||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.api.split.NumberedFileInputSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
|
@ -41,25 +39,27 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
@DisplayName("Csv Sequence Record Reader Test")
|
||||
class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder tempDir = new TemporaryFolder();
|
||||
@TempDir
|
||||
public Path tempDir;
|
||||
|
||||
@Test
|
||||
public void test() throws Exception {
|
||||
|
||||
@DisplayName("Test")
|
||||
void test() throws Exception {
|
||||
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
|
||||
seqReader.initialize(new TestInputSplit());
|
||||
|
||||
int sequenceCount = 0;
|
||||
while (seqReader.hasNext()) {
|
||||
List<List<Writable>> sequence = seqReader.sequenceRecord();
|
||||
assertEquals(4, sequence.size()); //4 lines, plus 1 header line
|
||||
|
||||
// 4 lines, plus 1 header line
|
||||
assertEquals(4, sequence.size());
|
||||
Iterator<List<Writable>> timeStepIter = sequence.iterator();
|
||||
int lineCount = 0;
|
||||
while (timeStepIter.hasNext()) {
|
||||
|
@ -80,19 +80,18 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReset() throws Exception {
|
||||
@DisplayName("Test Reset")
|
||||
void testReset() throws Exception {
|
||||
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
|
||||
seqReader.initialize(new TestInputSplit());
|
||||
|
||||
int nTests = 5;
|
||||
for (int i = 0; i < nTests; i++) {
|
||||
seqReader.reset();
|
||||
|
||||
int sequenceCount = 0;
|
||||
while (seqReader.hasNext()) {
|
||||
List<List<Writable>> sequence = seqReader.sequenceRecord();
|
||||
assertEquals(4, sequence.size()); //4 lines, plus 1 header line
|
||||
|
||||
// 4 lines, plus 1 header line
|
||||
assertEquals(4, sequence.size());
|
||||
Iterator<List<Writable>> timeStepIter = sequence.iterator();
|
||||
int lineCount = 0;
|
||||
while (timeStepIter.hasNext()) {
|
||||
|
@ -107,15 +106,15 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMetaData() throws Exception {
|
||||
@DisplayName("Test Meta Data")
|
||||
void testMetaData() throws Exception {
|
||||
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
|
||||
seqReader.initialize(new TestInputSplit());
|
||||
|
||||
List<List<List<Writable>>> l = new ArrayList<>();
|
||||
while (seqReader.hasNext()) {
|
||||
List<List<Writable>> sequence = seqReader.sequenceRecord();
|
||||
assertEquals(4, sequence.size()); //4 lines, plus 1 header line
|
||||
|
||||
// 4 lines, plus 1 header line
|
||||
assertEquals(4, sequence.size());
|
||||
Iterator<List<Writable>> timeStepIter = sequence.iterator();
|
||||
int lineCount = 0;
|
||||
while (timeStepIter.hasNext()) {
|
||||
|
@ -123,10 +122,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
|||
lineCount++;
|
||||
}
|
||||
assertEquals(4, lineCount);
|
||||
|
||||
l.add(sequence);
|
||||
}
|
||||
|
||||
List<SequenceRecord> l2 = new ArrayList<>();
|
||||
List<RecordMetaData> meta = new ArrayList<>();
|
||||
seqReader.reset();
|
||||
|
@ -136,7 +133,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
|||
meta.add(sr.getMetaData());
|
||||
}
|
||||
assertEquals(3, l2.size());
|
||||
|
||||
List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
assertEquals(l.get(i), l2.get(i).getSequenceRecord());
|
||||
|
@ -144,8 +140,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private static class
|
||||
TestInputSplit implements InputSplit {
|
||||
@DisplayName("Test Input Split")
|
||||
private static class TestInputSplit implements InputSplit {
|
||||
|
||||
@Override
|
||||
public boolean canWriteToLocation(URI location) {
|
||||
|
@ -164,7 +160,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
|||
|
||||
@Override
|
||||
public void updateSplitLocations(boolean reset) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -174,7 +169,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
|||
|
||||
@Override
|
||||
public void bootStrapForWrite() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -222,38 +216,30 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
|||
|
||||
@Override
|
||||
public void reset() {
|
||||
//No op
|
||||
// No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean resetSupported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testCsvSeqAndNumberedFileSplit() throws Exception {
|
||||
File baseDir = tempDir.newFolder();
|
||||
//Simple sanity check unit test
|
||||
@DisplayName("Test Csv Seq And Numbered File Split")
|
||||
void testCsvSeqAndNumberedFileSplit(@TempDir Path tempDir) throws Exception {
|
||||
File baseDir = tempDir.toFile();
|
||||
// Simple sanity check unit test
|
||||
for (int i = 0; i < 3; i++) {
|
||||
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
|
||||
}
|
||||
|
||||
//Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
|
||||
// Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
|
||||
ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
|
||||
String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();
|
||||
|
||||
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
|
||||
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
|
||||
|
||||
while(featureReader.hasNext()){
|
||||
while (featureReader.hasNext()) {
|
||||
featureReader.nextSequence();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.datavec.api.records.reader.SequenceRecordReader;
|
||||
|
@ -25,94 +24,87 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|||
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
|
||||
@DisplayName("Csv Variable Sliding Window Record Reader Test")
|
||||
class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testCSVVariableSlidingWindowRecordReader() throws Exception {
|
||||
@DisplayName("Test CSV Variable Sliding Window Record Reader")
|
||||
void testCSVVariableSlidingWindowRecordReader() throws Exception {
|
||||
int maxLinesPerSequence = 3;
|
||||
|
||||
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
|
||||
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
CSVRecordReader rr = new CSVRecordReader();
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
int count = 0;
|
||||
while (seqRR.hasNext()) {
|
||||
List<List<Writable>> next = seqRR.sequenceRecord();
|
||||
|
||||
if(count==maxLinesPerSequence-1) {
|
||||
if (count == maxLinesPerSequence - 1) {
|
||||
LinkedList<List<Writable>> expected = new LinkedList<>();
|
||||
for (int i = 0; i < maxLinesPerSequence; i++) {
|
||||
expected.addFirst(rr.next());
|
||||
}
|
||||
assertEquals(expected, next);
|
||||
|
||||
}
|
||||
if(count==maxLinesPerSequence) {
|
||||
if (count == maxLinesPerSequence) {
|
||||
assertEquals(maxLinesPerSequence, next.size());
|
||||
}
|
||||
if(count==0) { // first seq should be length 1
|
||||
if (count == 0) {
|
||||
// first seq should be length 1
|
||||
assertEquals(1, next.size());
|
||||
}
|
||||
if(count>151) { // last seq should be length 1
|
||||
if (count > 151) {
|
||||
// last seq should be length 1
|
||||
assertEquals(1, next.size());
|
||||
}
|
||||
|
||||
count++;
|
||||
}
|
||||
|
||||
assertEquals(152, count);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
|
||||
@DisplayName("Test CSV Variable Sliding Window Record Reader Stride")
|
||||
void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
|
||||
int maxLinesPerSequence = 3;
|
||||
int stride = 2;
|
||||
|
||||
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
|
||||
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
CSVRecordReader rr = new CSVRecordReader();
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
int count = 0;
|
||||
while (seqRR.hasNext()) {
|
||||
List<List<Writable>> next = seqRR.sequenceRecord();
|
||||
|
||||
if(count==maxLinesPerSequence-1) {
|
||||
if (count == maxLinesPerSequence - 1) {
|
||||
LinkedList<List<Writable>> expected = new LinkedList<>();
|
||||
for(int s = 0; s < stride; s++) {
|
||||
for (int s = 0; s < stride; s++) {
|
||||
expected = new LinkedList<>();
|
||||
for (int i = 0; i < maxLinesPerSequence; i++) {
|
||||
expected.addFirst(rr.next());
|
||||
}
|
||||
}
|
||||
assertEquals(expected, next);
|
||||
|
||||
}
|
||||
if(count==maxLinesPerSequence) {
|
||||
if (count == maxLinesPerSequence) {
|
||||
assertEquals(maxLinesPerSequence, next.size());
|
||||
}
|
||||
if(count==0) { // first seq should be length 2
|
||||
if (count == 0) {
|
||||
// first seq should be length 2
|
||||
assertEquals(2, next.size());
|
||||
}
|
||||
if(count>151) { // last seq should be length 1
|
||||
if (count > 151) {
|
||||
// last seq should be length 1
|
||||
assertEquals(1, next.size());
|
||||
}
|
||||
|
||||
count++;
|
||||
}
|
||||
|
||||
assertEquals(76, count);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -28,45 +27,44 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
|||
import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
|
||||
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.loader.FileBatch;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@DisplayName("File Batch Record Reader Test")
|
||||
public class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||
@TempDir Path testDir;
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test
|
||||
public void testCsv() throws Exception {
|
||||
|
||||
//This is an unrealistic use case - one line/record per CSV
|
||||
File baseDir = testDir.newFolder();
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@DisplayName("Test Csv")
|
||||
void testCsv(Nd4jBackend backend) throws Exception {
|
||||
// This is an unrealistic use case - one line/record per CSV
|
||||
File baseDir = testDir.toFile();
|
||||
List<File> fileList = new ArrayList<>();
|
||||
for( int i=0; i<10; i++ ){
|
||||
for (int i = 0; i < 10; i++) {
|
||||
String s = "file_" + i + "," + i + "," + i;
|
||||
File f = new File(baseDir, "origFile" + i + ".csv");
|
||||
FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8);
|
||||
fileList.add(f);
|
||||
}
|
||||
|
||||
FileBatch fb = FileBatch.forFiles(fileList);
|
||||
|
||||
RecordReader rr = new CSVRecordReader();
|
||||
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
|
||||
|
||||
|
||||
for( int test=0; test<3; test++) {
|
||||
for (int test = 0; test < 3; test++) {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
assertTrue(fbrr.hasNext());
|
||||
List<Writable> next = fbrr.next();
|
||||
|
@ -82,16 +80,17 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCsvSequence() throws Exception {
|
||||
//CSV sequence - 3 lines per file, 10 files
|
||||
File baseDir = testDir.newFolder();
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@DisplayName("Test Csv Sequence")
|
||||
void testCsvSequence(Nd4jBackend backend) throws Exception {
|
||||
// CSV sequence - 3 lines per file, 10 files
|
||||
File baseDir = testDir.toFile();
|
||||
List<File> fileList = new ArrayList<>();
|
||||
for( int i=0; i<10; i++ ){
|
||||
for (int i = 0; i < 10; i++) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for( int j=0; j<3; j++ ){
|
||||
if(j > 0)
|
||||
for (int j = 0; j < 3; j++) {
|
||||
if (j > 0)
|
||||
sb.append("\n");
|
||||
sb.append("file_" + i + "," + i + "," + j);
|
||||
}
|
||||
|
@ -99,19 +98,16 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
|
|||
FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8);
|
||||
fileList.add(f);
|
||||
}
|
||||
|
||||
FileBatch fb = FileBatch.forFiles(fileList);
|
||||
SequenceRecordReader rr = new CSVSequenceRecordReader();
|
||||
FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb);
|
||||
|
||||
|
||||
for( int test=0; test<3; test++) {
|
||||
for (int test = 0; test < 3; test++) {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
assertTrue(fbrr.hasNext());
|
||||
List<List<Writable>> next = fbrr.sequenceRecord();
|
||||
assertEquals(3, next.size());
|
||||
int count = 0;
|
||||
for(List<Writable> step : next ){
|
||||
for (List<Writable> step : next) {
|
||||
String s1 = "file_" + i;
|
||||
assertEquals(s1, step.get(0).toString());
|
||||
assertEquals(String.valueOf(i), step.get(1).toString());
|
||||
|
@ -123,5 +119,4 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
|
|||
fbrr.reset();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.datavec.api.records.Record;
|
||||
|
@ -26,28 +25,28 @@ import org.datavec.api.split.CollectionInputSplit;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
|
||||
public class FileRecordReaderTest extends BaseND4JTest {
|
||||
@DisplayName("File Record Reader Test")
|
||||
class FileRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testReset() throws Exception {
|
||||
@DisplayName("Test Reset")
|
||||
void testReset() throws Exception {
|
||||
FileRecordReader rr = new FileRecordReader();
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||
|
||||
int nResets = 5;
|
||||
for (int i = 0; i < nResets; i++) {
|
||||
|
||||
int lineCount = 0;
|
||||
while (rr.hasNext()) {
|
||||
List<Writable> line = rr.next();
|
||||
|
@ -61,25 +60,20 @@ public class FileRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMeta() throws Exception {
|
||||
@DisplayName("Test Meta")
|
||||
void testMeta() throws Exception {
|
||||
FileRecordReader rr = new FileRecordReader();
|
||||
|
||||
|
||||
URI[] arr = new URI[3];
|
||||
arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI();
|
||||
arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI();
|
||||
arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI();
|
||||
|
||||
InputSplit is = new CollectionInputSplit(Arrays.asList(arr));
|
||||
rr.initialize(is);
|
||||
|
||||
List<List<Writable>> out = new ArrayList<>();
|
||||
while (rr.hasNext()) {
|
||||
out.add(rr.next());
|
||||
}
|
||||
|
||||
assertEquals(3, out.size());
|
||||
|
||||
rr.reset();
|
||||
List<List<Writable>> out2 = new ArrayList<>();
|
||||
List<Record> out3 = new ArrayList<>();
|
||||
|
@ -90,13 +84,10 @@ public class FileRecordReaderTest extends BaseND4JTest {
|
|||
out2.add(r.getRecord());
|
||||
out3.add(r);
|
||||
meta.add(r.getMetaData());
|
||||
|
||||
assertEquals(arr[count++], r.getMetaData().getURI());
|
||||
}
|
||||
|
||||
assertEquals(out, out2);
|
||||
List<Record> fromMeta = rr.loadFromMetaData(meta);
|
||||
assertEquals(out3, fromMeta);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
|
@ -28,97 +27,81 @@ import org.datavec.api.split.CollectionInputSplit;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.File;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
@DisplayName("Jackson Line Record Reader Test")
|
||||
class JacksonLineRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
public class JacksonLineRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@TempDir
|
||||
public Path testDir;
|
||||
|
||||
public JacksonLineRecordReaderTest() {
|
||||
}
|
||||
|
||||
private static FieldSelection getFieldSelection() {
|
||||
return new FieldSelection.Builder().addField("value1").
|
||||
addField("value2").
|
||||
addField("value3").
|
||||
addField("value4").
|
||||
addField("value5").
|
||||
addField("value6").
|
||||
addField("value7").
|
||||
addField("value8").
|
||||
addField("value9").
|
||||
addField("value10").build();
|
||||
return new FieldSelection.Builder().addField("value1").addField("value2").addField("value3").addField("value4").addField("value5").addField("value6").addField("value7").addField("value8").addField("value9").addField("value10").build();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReadJSON() throws Exception {
|
||||
|
||||
@DisplayName("Test Read JSON")
|
||||
void testReadJSON() throws Exception {
|
||||
RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile()));
|
||||
|
||||
testJacksonRecordReader(rr);
|
||||
}
|
||||
|
||||
private static void testJacksonRecordReader(RecordReader rr) {
|
||||
while (rr.hasNext()) {
|
||||
List<Writable> json0 = rr.next();
|
||||
//System.out.println(json0);
|
||||
assert(json0.size() > 0);
|
||||
// System.out.println(json0);
|
||||
assert (json0.size() > 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testJacksonLineSequenceRecordReader() throws Exception {
|
||||
File dir = testDir.newFolder();
|
||||
@DisplayName("Test Jackson Line Sequence Record Reader")
|
||||
void testJacksonLineSequenceRecordReader(@TempDir Path testDir) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
|
||||
|
||||
FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
|
||||
.addField(new Text("MISSING_CX"), "c", "x").build();
|
||||
|
||||
FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
|
||||
JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
|
||||
File[] files = dir.listFiles();
|
||||
Arrays.sort(files);
|
||||
URI[] u = new URI[files.length];
|
||||
for( int i=0; i<files.length; i++ ){
|
||||
for (int i = 0; i < files.length; i++) {
|
||||
u[i] = files[i].toURI();
|
||||
}
|
||||
rr.initialize(new CollectionInputSplit(u));
|
||||
|
||||
List<List<Writable>> expSeq0 = new ArrayList<>();
|
||||
expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
|
||||
expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
|
||||
expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
|
||||
|
||||
List<List<Writable>> expSeq1 = new ArrayList<>();
|
||||
expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
|
||||
|
||||
|
||||
int count = 0;
|
||||
while(rr.hasNext()){
|
||||
while (rr.hasNext()) {
|
||||
List<List<Writable>> next = rr.sequenceRecord();
|
||||
if(count++ == 0){
|
||||
if (count++ == 0) {
|
||||
assertEquals(expSeq0, next);
|
||||
} else {
|
||||
assertEquals(expSeq1, next);
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(2, count);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.datavec.api.io.labels.PathLabelGenerator;
|
||||
|
@ -31,114 +30,95 @@ import org.datavec.api.split.NumberedFileInputSplit;
|
|||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
|
||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
@DisplayName("Jackson Record Reader Test")
|
||||
class JacksonRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
public class JacksonRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@TempDir
|
||||
public Path testDir;
|
||||
|
||||
@Test
|
||||
public void testReadingJson() throws Exception {
|
||||
//Load 3 values from 3 JSON files
|
||||
//stricture: a:value, b:value, c:x:value, c:y:value
|
||||
//And we want to load only a:value, b:value and c:x:value
|
||||
//For first JSON file: all values are present
|
||||
//For second JSON file: b:value is missing
|
||||
//For third JSON file: c:x:value is missing
|
||||
|
||||
@DisplayName("Test Reading Json")
|
||||
void testReadingJson(@TempDir Path testDir) throws Exception {
|
||||
// Load 3 values from 3 JSON files
|
||||
// stricture: a:value, b:value, c:x:value, c:y:value
|
||||
// And we want to load only a:value, b:value and c:x:value
|
||||
// For first JSON file: all values are present
|
||||
// For second JSON file: b:value is missing
|
||||
// For third JSON file: c:x:value is missing
|
||||
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
cpr.copyDirectory(f);
|
||||
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
|
||||
|
||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||
|
||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
|
||||
rr.initialize(is);
|
||||
|
||||
testJacksonRecordReader(rr);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReadingYaml() throws Exception {
|
||||
//Exact same information as JSON format, but in YAML format
|
||||
|
||||
@DisplayName("Test Reading Yaml")
|
||||
void testReadingYaml(@TempDir Path testDir) throws Exception {
|
||||
// Exact same information as JSON format, but in YAML format
|
||||
ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/");
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
cpr.copyDirectory(f);
|
||||
String path = new File(f, "yaml_test_%d.txt").getAbsolutePath();
|
||||
|
||||
|
||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||
|
||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory()));
|
||||
rr.initialize(is);
|
||||
|
||||
testJacksonRecordReader(rr);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReadingXml() throws Exception {
|
||||
//Exact same information as JSON format, but in XML format
|
||||
|
||||
@DisplayName("Test Reading Xml")
|
||||
void testReadingXml(@TempDir Path testDir) throws Exception {
|
||||
// Exact same information as JSON format, but in XML format
|
||||
ClassPathResource cpr = new ClassPathResource("datavec-api/xml/");
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
cpr.copyDirectory(f);
|
||||
String path = new File(f, "xml_test_%d.txt").getAbsolutePath();
|
||||
|
||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||
|
||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory()));
|
||||
rr.initialize(is);
|
||||
|
||||
testJacksonRecordReader(rr);
|
||||
}
|
||||
|
||||
|
||||
private static FieldSelection getFieldSelection() {
|
||||
return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
|
||||
.addField(new Text("MISSING_CX"), "c", "x").build();
|
||||
return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
|
||||
}
|
||||
|
||||
|
||||
|
||||
private static void testJacksonRecordReader(RecordReader rr) {
|
||||
|
||||
List<Writable> json0 = rr.next();
|
||||
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
|
||||
assertEquals(exp0, json0);
|
||||
|
||||
List<Writable> json1 = rr.next();
|
||||
List<Writable> exp1 =
|
||||
Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
|
||||
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
|
||||
assertEquals(exp1, json1);
|
||||
|
||||
List<Writable> json2 = rr.next();
|
||||
List<Writable> exp2 =
|
||||
Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
|
||||
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
|
||||
assertEquals(exp2, json2);
|
||||
|
||||
assertFalse(rr.hasNext());
|
||||
|
||||
//Test reset
|
||||
// Test reset
|
||||
rr.reset();
|
||||
assertEquals(exp0, rr.next());
|
||||
assertEquals(exp1, rr.next());
|
||||
|
@ -147,72 +127,50 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAppendingLabels() throws Exception {
|
||||
|
||||
@DisplayName("Test Appending Labels")
|
||||
void testAppendingLabels(@TempDir Path testDir) throws Exception {
|
||||
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
cpr.copyDirectory(f);
|
||||
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
|
||||
|
||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||
|
||||
//Insert at the end:
|
||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
|
||||
new LabelGen());
|
||||
// Insert at the end:
|
||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
|
||||
rr.initialize(is);
|
||||
|
||||
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"),
|
||||
new IntWritable(0));
|
||||
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0));
|
||||
assertEquals(exp0, rr.next());
|
||||
|
||||
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"),
|
||||
new IntWritable(1));
|
||||
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1));
|
||||
assertEquals(exp1, rr.next());
|
||||
|
||||
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"),
|
||||
new IntWritable(2));
|
||||
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2));
|
||||
assertEquals(exp2, rr.next());
|
||||
|
||||
//Insert at position 0:
|
||||
rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
|
||||
new LabelGen(), 0);
|
||||
// Insert at position 0:
|
||||
rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0);
|
||||
rr.initialize(is);
|
||||
|
||||
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"),
|
||||
new Text("cxValue0"));
|
||||
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
|
||||
assertEquals(exp0, rr.next());
|
||||
|
||||
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"),
|
||||
new Text("cxValue1"));
|
||||
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
|
||||
assertEquals(exp1, rr.next());
|
||||
|
||||
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"),
|
||||
new Text("MISSING_CX"));
|
||||
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
|
||||
assertEquals(exp2, rr.next());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAppendingLabelsMetaData() throws Exception {
|
||||
@DisplayName("Test Appending Labels Meta Data")
|
||||
void testAppendingLabelsMetaData(@TempDir Path testDir) throws Exception {
|
||||
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
cpr.copyDirectory(f);
|
||||
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
|
||||
|
||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||
|
||||
//Insert at the end:
|
||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
|
||||
new LabelGen());
|
||||
// Insert at the end:
|
||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
|
||||
rr.initialize(is);
|
||||
|
||||
List<List<Writable>> out = new ArrayList<>();
|
||||
while (rr.hasNext()) {
|
||||
out.add(rr.next());
|
||||
}
|
||||
assertEquals(3, out.size());
|
||||
|
||||
rr.reset();
|
||||
|
||||
List<List<Writable>> out2 = new ArrayList<>();
|
||||
List<Record> outRecord = new ArrayList<>();
|
||||
List<RecordMetaData> meta = new ArrayList<>();
|
||||
|
@ -222,14 +180,12 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
|
|||
outRecord.add(r);
|
||||
meta.add(r.getMetaData());
|
||||
}
|
||||
|
||||
assertEquals(out, out2);
|
||||
|
||||
List<Record> fromMeta = rr.loadFromMetaData(meta);
|
||||
assertEquals(outRecord, fromMeta);
|
||||
}
|
||||
|
||||
|
||||
@DisplayName("Label Gen")
|
||||
private static class LabelGen implements PathLabelGenerator {
|
||||
|
||||
@Override
|
||||
|
@ -252,5 +208,4 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
|
|||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
|
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
public class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||
@DisplayName("Lib Svm Record Reader Test")
|
||||
class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testBasicRecord() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Basic Record")
|
||||
void testBasicRecord() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 7 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5),
|
||||
new IntWritable(7)));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
|
||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO,
|
||||
new IntWritable(2)));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
|
||||
// 33
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
new IntWritable(33)));
|
||||
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -80,27 +66,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNoAppendLabel() throws IOException, InterruptedException {
|
||||
@DisplayName("Test No Append Label")
|
||||
void testNoAppendLabel() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 7 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5)));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
|
||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
|
||||
// 33
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO));
|
||||
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -117,33 +91,17 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNoLabel() throws IOException, InterruptedException {
|
||||
@DisplayName("Test No Label")
|
||||
void testNoLabel() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5)));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
|
||||
// qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
|
||||
// 1:1.0
|
||||
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO));
|
||||
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||
//
|
||||
correct.put(3, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO));
|
||||
|
||||
correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -160,33 +118,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMultioutputRecord() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Multioutput Record")
|
||||
void testMultioutputRecord() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5),
|
||||
new IntWritable(7), new DoubleWritable(2.45),
|
||||
new IntWritable(9)));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
|
||||
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO,
|
||||
new IntWritable(2), new IntWritable(3),
|
||||
new IntWritable(4)));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
|
||||
// 33,32.0,31.9
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
new IntWritable(33), new DoubleWritable(32.0),
|
||||
new DoubleWritable(31.9)));
|
||||
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -202,51 +142,20 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
|||
assertEquals(i, correct.size());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testMultilabelRecord() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Multilabel Record")
|
||||
void testMultilabelRecord() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 1,3 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5),
|
||||
LABEL_ONE, LABEL_ZERO,
|
||||
LABEL_ONE, LABEL_ZERO));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
|
||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO, LABEL_ONE,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
|
||||
// 1,2,4
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ONE, LABEL_ONE,
|
||||
LABEL_ZERO, LABEL_ONE));
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
|
||||
// 1:1.0
|
||||
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||
//
|
||||
correct.put(4, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
|
||||
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -265,63 +174,24 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testZeroBasedIndexing() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Zero Based Indexing")
|
||||
void testZeroBasedIndexing() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 1,3 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO,
|
||||
ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5),
|
||||
LABEL_ZERO,
|
||||
LABEL_ONE, LABEL_ZERO,
|
||||
LABEL_ONE, LABEL_ZERO));
|
||||
correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
|
||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(ZERO,
|
||||
new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ONE,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
|
||||
// 1,2,4
|
||||
correct.put(2, Arrays.asList(ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO,
|
||||
LABEL_ONE, LABEL_ONE,
|
||||
LABEL_ZERO, LABEL_ONE));
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
|
||||
// 1:1.0
|
||||
correct.put(3, Arrays.asList(ZERO,
|
||||
new DoubleWritable(1.0), ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||
//
|
||||
correct.put(4, Arrays.asList(ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
|
||||
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
// Zero-based indexing is default
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
|
||||
// NOT STANDARD!
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
|
||||
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
||||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
|
||||
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
||||
|
@ -336,58 +206,71 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
|||
assertEquals(i, correct.size());
|
||||
}
|
||||
|
||||
@Test(expected = NoSuchElementException.class)
|
||||
public void testNoSuchElementException() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test No Such Element Exception")
|
||||
void testNoSuchElementException() {
|
||||
assertThrows(NoSuchElementException.class, () -> {
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||
while (rr.hasNext())
|
||||
rr.next();
|
||||
while (rr.hasNext()) rr.next();
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = UnsupportedOperationException.class)
|
||||
public void failedToSetNumFeaturesException() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Failed To Set Num Features Exception")
|
||||
void failedToSetNumFeaturesException() {
|
||||
assertThrows(UnsupportedOperationException.class, () -> {
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||
while (rr.hasNext())
|
||||
rr.next();
|
||||
while (rr.hasNext()) rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = UnsupportedOperationException.class)
|
||||
public void testInconsistentNumLabelsException() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Inconsistent Num Labels Exception")
|
||||
void testInconsistentNumLabelsException() {
|
||||
assertThrows(UnsupportedOperationException.class, () -> {
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
|
||||
while (rr.hasNext())
|
||||
rr.next();
|
||||
while (rr.hasNext()) rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = UnsupportedOperationException.class)
|
||||
public void testInconsistentNumMultiabelsException() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Inconsistent Num Multiabels Exception")
|
||||
void testInconsistentNumMultiabelsException() {
|
||||
assertThrows(UnsupportedOperationException.class, () -> {
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
|
||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
|
||||
while (rr.hasNext())
|
||||
rr.next();
|
||||
while (rr.hasNext()) rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = IndexOutOfBoundsException.class)
|
||||
public void testFeatureIndexExceedsNumFeatures() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Feature Index Exceeds Num Features")
|
||||
void testFeatureIndexExceedsNumFeatures() {
|
||||
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = IndexOutOfBoundsException.class)
|
||||
public void testLabelIndexExceedsNumLabels() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Label Index Exceeds Num Labels")
|
||||
void testLabelIndexExceedsNumLabels() {
|
||||
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
||||
|
@ -395,10 +278,13 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
|||
config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = IndexOutOfBoundsException.class)
|
||||
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Zero Index Feature Without Using Zero Indexing")
|
||||
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
|
||||
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -406,10 +292,13 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
|||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = IndexOutOfBoundsException.class)
|
||||
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Zero Index Label Without Using Zero Indexing")
|
||||
void testZeroIndexLabelWithoutUsingZeroIndexing() {
|
||||
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
||||
|
@ -418,5 +307,6 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
|||
config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -30,11 +29,10 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.api.split.InputStreamInputSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
|
@ -45,34 +43,31 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.zip.GZIPInputStream;
|
||||
import java.util.zip.GZIPOutputStream;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
@DisplayName("Line Reader Test")
|
||||
class LineReaderTest extends BaseND4JTest {
|
||||
|
||||
public class LineReaderTest extends BaseND4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test
|
||||
public void testLineReader() throws Exception {
|
||||
File tmpdir = testDir.newFolder();
|
||||
@DisplayName("Test Line Reader")
|
||||
void testLineReader(@TempDir Path tmpDir) throws Exception {
|
||||
File tmpdir = tmpDir.toFile();
|
||||
if (tmpdir.exists())
|
||||
tmpdir.delete();
|
||||
tmpdir.mkdir();
|
||||
|
||||
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
|
||||
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
|
||||
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
|
||||
|
||||
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
|
||||
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
|
||||
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
|
||||
|
||||
InputSplit split = new FileSplit(tmpdir);
|
||||
|
||||
RecordReader reader = new LineRecordReader();
|
||||
reader.initialize(split);
|
||||
|
||||
int count = 0;
|
||||
List<List<Writable>> list = new ArrayList<>();
|
||||
while (reader.hasNext()) {
|
||||
|
@ -81,34 +76,27 @@ public class LineReaderTest extends BaseND4JTest {
|
|||
list.add(l);
|
||||
count++;
|
||||
}
|
||||
|
||||
assertEquals(9, count);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLineReaderMetaData() throws Exception {
|
||||
File tmpdir = testDir.newFolder();
|
||||
|
||||
@DisplayName("Test Line Reader Meta Data")
|
||||
void testLineReaderMetaData(@TempDir Path tmpDir) throws Exception {
|
||||
File tmpdir = tmpDir.toFile();
|
||||
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
|
||||
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
|
||||
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
|
||||
|
||||
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
|
||||
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
|
||||
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
|
||||
|
||||
InputSplit split = new FileSplit(tmpdir);
|
||||
|
||||
RecordReader reader = new LineRecordReader();
|
||||
reader.initialize(split);
|
||||
|
||||
List<List<Writable>> list = new ArrayList<>();
|
||||
while (reader.hasNext()) {
|
||||
list.add(reader.next());
|
||||
}
|
||||
assertEquals(9, list.size());
|
||||
|
||||
|
||||
List<List<Writable>> out2 = new ArrayList<>();
|
||||
List<Record> out3 = new ArrayList<>();
|
||||
List<RecordMetaData> meta = new ArrayList<>();
|
||||
|
@ -124,13 +112,10 @@ public class LineReaderTest extends BaseND4JTest {
|
|||
assertEquals(uri, split.locations()[fileIdx]);
|
||||
count++;
|
||||
}
|
||||
|
||||
assertEquals(list, out2);
|
||||
|
||||
List<Record> fromMeta = reader.loadFromMetaData(meta);
|
||||
assertEquals(out3, fromMeta);
|
||||
|
||||
//try: second line of second and third files only...
|
||||
// try: second line of second and third files only...
|
||||
List<RecordMetaData> subsetMeta = new ArrayList<>();
|
||||
subsetMeta.add(meta.get(4));
|
||||
subsetMeta.add(meta.get(7));
|
||||
|
@ -141,27 +126,22 @@ public class LineReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLineReaderWithInputStreamInputSplit() throws Exception {
|
||||
File tmpdir = testDir.newFolder();
|
||||
|
||||
@DisplayName("Test Line Reader With Input Stream Input Split")
|
||||
void testLineReaderWithInputStreamInputSplit(@TempDir Path testDir) throws Exception {
|
||||
File tmpdir = testDir.toFile();
|
||||
File tmp1 = new File(tmpdir, "tmp1.txt.gz");
|
||||
|
||||
OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
|
||||
IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
|
||||
os.flush();
|
||||
os.close();
|
||||
|
||||
InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
|
||||
|
||||
RecordReader reader = new LineRecordReader();
|
||||
reader.initialize(split);
|
||||
|
||||
int count = 0;
|
||||
while (reader.hasNext()) {
|
||||
assertEquals(1, reader.next().size());
|
||||
count++;
|
||||
}
|
||||
|
||||
assertEquals(9, count);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.datavec.api.records.Record;
|
||||
|
@ -33,44 +32,41 @@ import org.datavec.api.split.InputSplit;
|
|||
import org.datavec.api.split.NumberedFileInputSplit;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
@DisplayName("Regex Record Reader Test")
|
||||
class RegexRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
public class RegexRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@TempDir
|
||||
public Path testDir;
|
||||
|
||||
@Test
|
||||
public void testRegexLineRecordReader() throws Exception {
|
||||
@DisplayName("Test Regex Line Record Reader")
|
||||
void testRegexLineRecordReader() throws Exception {
|
||||
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
||||
|
||||
RecordReader rr = new RegexLineRecordReader(regex, 1);
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
|
||||
|
||||
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"),
|
||||
new Text("DEBUG"), new Text("First entry message!"));
|
||||
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"),
|
||||
new Text("INFO"), new Text("Second entry message!"));
|
||||
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"),
|
||||
new Text("WARN"), new Text("Third entry message!"));
|
||||
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"));
|
||||
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"));
|
||||
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"));
|
||||
assertEquals(exp0, rr.next());
|
||||
assertEquals(exp1, rr.next());
|
||||
assertEquals(exp2, rr.next());
|
||||
assertFalse(rr.hasNext());
|
||||
|
||||
//Test reset:
|
||||
// Test reset:
|
||||
rr.reset();
|
||||
assertEquals(exp0, rr.next());
|
||||
assertEquals(exp1, rr.next());
|
||||
|
@ -79,74 +75,57 @@ public class RegexRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRegexLineRecordReaderMeta() throws Exception {
|
||||
@DisplayName("Test Regex Line Record Reader Meta")
|
||||
void testRegexLineRecordReaderMeta() throws Exception {
|
||||
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
||||
|
||||
RecordReader rr = new RegexLineRecordReader(regex, 1);
|
||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
|
||||
|
||||
List<List<Writable>> list = new ArrayList<>();
|
||||
while (rr.hasNext()) {
|
||||
list.add(rr.next());
|
||||
}
|
||||
assertEquals(3, list.size());
|
||||
|
||||
List<Record> list2 = new ArrayList<>();
|
||||
List<List<Writable>> list3 = new ArrayList<>();
|
||||
List<RecordMetaData> meta = new ArrayList<>();
|
||||
rr.reset();
|
||||
int count = 1; //Start by skipping 1 line
|
||||
// Start by skipping 1 line
|
||||
int count = 1;
|
||||
while (rr.hasNext()) {
|
||||
Record r = rr.nextRecord();
|
||||
list2.add(r);
|
||||
list3.add(r.getRecord());
|
||||
meta.add(r.getMetaData());
|
||||
|
||||
assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber());
|
||||
}
|
||||
|
||||
List<Record> fromMeta = rr.loadFromMetaData(meta);
|
||||
|
||||
assertEquals(list, list3);
|
||||
assertEquals(list2, fromMeta);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRegexSequenceRecordReader() throws Exception {
|
||||
@DisplayName("Test Regex Sequence Record Reader")
|
||||
void testRegexSequenceRecordReader(@TempDir Path testDir) throws Exception {
|
||||
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
||||
|
||||
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
cpr.copyDirectory(f);
|
||||
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
|
||||
|
||||
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
|
||||
|
||||
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
|
||||
rr.initialize(is);
|
||||
|
||||
List<List<Writable>> exp0 = new ArrayList<>();
|
||||
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"),
|
||||
new Text("First entry message!")));
|
||||
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"),
|
||||
new Text("Second entry message!")));
|
||||
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"),
|
||||
new Text("Third entry message!")));
|
||||
|
||||
|
||||
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!")));
|
||||
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!")));
|
||||
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!")));
|
||||
List<List<Writable>> exp1 = new ArrayList<>();
|
||||
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"),
|
||||
new Text("First entry message!")));
|
||||
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"),
|
||||
new Text("Second entry message!")));
|
||||
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"),
|
||||
new Text("Third entry message!")));
|
||||
|
||||
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!")));
|
||||
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!")));
|
||||
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!")));
|
||||
assertEquals(exp0, rr.sequenceRecord());
|
||||
assertEquals(exp1, rr.sequenceRecord());
|
||||
assertFalse(rr.hasNext());
|
||||
|
||||
//Test resetting:
|
||||
// Test resetting:
|
||||
rr.reset();
|
||||
assertEquals(exp0, rr.sequenceRecord());
|
||||
assertEquals(exp1, rr.sequenceRecord());
|
||||
|
@ -154,24 +133,20 @@ public class RegexRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRegexSequenceRecordReaderMeta() throws Exception {
|
||||
@DisplayName("Test Regex Sequence Record Reader Meta")
|
||||
void testRegexSequenceRecordReaderMeta(@TempDir Path testDir) throws Exception {
|
||||
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
||||
|
||||
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
cpr.copyDirectory(f);
|
||||
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
|
||||
|
||||
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
|
||||
|
||||
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
|
||||
rr.initialize(is);
|
||||
|
||||
List<List<List<Writable>>> out = new ArrayList<>();
|
||||
while (rr.hasNext()) {
|
||||
out.add(rr.sequenceRecord());
|
||||
}
|
||||
|
||||
assertEquals(2, out.size());
|
||||
List<List<List<Writable>>> out2 = new ArrayList<>();
|
||||
List<SequenceRecord> out3 = new ArrayList<>();
|
||||
|
@ -183,11 +158,8 @@ public class RegexRecordReaderTest extends BaseND4JTest {
|
|||
out3.add(seqr);
|
||||
meta.add(seqr.getMetaData());
|
||||
}
|
||||
|
||||
List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);
|
||||
|
||||
assertEquals(out, out2);
|
||||
assertEquals(out3, fromMeta);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import org.datavec.api.conf.Configuration;
|
||||
|
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
public class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||
@DisplayName("Svm Light Record Reader Test")
|
||||
class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testBasicRecord() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Basic Record")
|
||||
void testBasicRecord() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 7 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5),
|
||||
new IntWritable(7)));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
|
||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO,
|
||||
new IntWritable(2)));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
|
||||
// 33
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
new IntWritable(33)));
|
||||
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -79,27 +65,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNoAppendLabel() throws IOException, InterruptedException {
|
||||
@DisplayName("Test No Append Label")
|
||||
void testNoAppendLabel() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 7 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5)));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
|
||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
|
||||
// 33
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO));
|
||||
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -116,33 +90,17 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNoLabel() throws IOException, InterruptedException {
|
||||
@DisplayName("Test No Label")
|
||||
void testNoLabel() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5)));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
|
||||
// qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
|
||||
// 1:1.0
|
||||
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO));
|
||||
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||
//
|
||||
correct.put(3, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO));
|
||||
|
||||
correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -159,33 +117,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMultioutputRecord() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Multioutput Record")
|
||||
void testMultioutputRecord() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5),
|
||||
new IntWritable(7), new DoubleWritable(2.45),
|
||||
new IntWritable(9)));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
|
||||
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO,
|
||||
new IntWritable(2), new IntWritable(3),
|
||||
new IntWritable(4)));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
|
||||
// 33,32.0,31.9
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
new IntWritable(33), new DoubleWritable(32.0),
|
||||
new DoubleWritable(31.9)));
|
||||
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -200,51 +140,20 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
|||
assertEquals(i, correct.size());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testMultilabelRecord() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Multilabel Record")
|
||||
void testMultilabelRecord() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 1,3 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5),
|
||||
LABEL_ONE, LABEL_ZERO,
|
||||
LABEL_ONE, LABEL_ZERO));
|
||||
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
|
||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO, LABEL_ONE,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
|
||||
// 1,2,4
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ONE, LABEL_ONE,
|
||||
LABEL_ZERO, LABEL_ONE));
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
|
||||
// 1:1.0
|
||||
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||
//
|
||||
correct.put(4, Arrays.asList(ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
|
||||
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
@ -262,63 +171,24 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testZeroBasedIndexing() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Zero Based Indexing")
|
||||
void testZeroBasedIndexing() throws IOException, InterruptedException {
|
||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||
// 1,3 2:1 4:2 6:3 8:4 10:5
|
||||
correct.put(0, Arrays.asList(ZERO,
|
||||
ZERO, ONE,
|
||||
ZERO, new DoubleWritable(2),
|
||||
ZERO, new DoubleWritable(3),
|
||||
ZERO, new DoubleWritable(4),
|
||||
ZERO, new DoubleWritable(5),
|
||||
LABEL_ZERO,
|
||||
LABEL_ONE, LABEL_ZERO,
|
||||
LABEL_ONE, LABEL_ZERO));
|
||||
correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
|
||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||
correct.put(1, Arrays.asList(ZERO,
|
||||
new DoubleWritable(0.1), new DoubleWritable(2),
|
||||
ZERO, ZERO,
|
||||
ZERO, new DoubleWritable(6.6),
|
||||
ZERO, new DoubleWritable(80),
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ONE,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
|
||||
// 1,2,4
|
||||
correct.put(2, Arrays.asList(ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO,
|
||||
LABEL_ONE, LABEL_ONE,
|
||||
LABEL_ZERO, LABEL_ONE));
|
||||
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
|
||||
// 1:1.0
|
||||
correct.put(3, Arrays.asList(ZERO,
|
||||
new DoubleWritable(1.0), ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||
//
|
||||
correct.put(4, Arrays.asList(ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
ZERO, ZERO,
|
||||
LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO,
|
||||
LABEL_ZERO, LABEL_ZERO));
|
||||
|
||||
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
// Zero-based indexing is default
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
|
||||
// NOT STANDARD!
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
|
||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
|
||||
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
||||
config.setInt(SVMLightRecordReader.NUM_LABELS, 5);
|
||||
|
@ -333,20 +203,19 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNextRecord() throws IOException, InterruptedException {
|
||||
@DisplayName("Test Next Record")
|
||||
void testNextRecord() throws IOException, InterruptedException {
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||
config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||
|
||||
Record record = rr.nextRecord();
|
||||
List<Writable> recordList = record.getRecord();
|
||||
assertEquals(new DoubleWritable(1.0), recordList.get(1));
|
||||
assertEquals(new DoubleWritable(3.0), recordList.get(5));
|
||||
assertEquals(new DoubleWritable(4.0), recordList.get(7));
|
||||
|
||||
record = rr.nextRecord();
|
||||
recordList = record.getRecord();
|
||||
assertEquals(new DoubleWritable(0.1), recordList.get(0));
|
||||
|
@ -354,76 +223,95 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
|||
assertEquals(new DoubleWritable(80.0), recordList.get(7));
|
||||
}
|
||||
|
||||
@Test(expected = NoSuchElementException.class)
|
||||
public void testNoSuchElementException() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test No Such Element Exception")
|
||||
void testNoSuchElementException() {
|
||||
assertThrows(NoSuchElementException.class, () -> {
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||
while (rr.hasNext())
|
||||
rr.next();
|
||||
while (rr.hasNext()) rr.next();
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = UnsupportedOperationException.class)
|
||||
public void failedToSetNumFeaturesException() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Failed To Set Num Features Exception")
|
||||
void failedToSetNumFeaturesException() {
|
||||
assertThrows(UnsupportedOperationException.class, () -> {
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||
while (rr.hasNext())
|
||||
rr.next();
|
||||
while (rr.hasNext()) rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = UnsupportedOperationException.class)
|
||||
public void testInconsistentNumLabelsException() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Inconsistent Num Labels Exception")
|
||||
void testInconsistentNumLabelsException() {
|
||||
assertThrows(UnsupportedOperationException.class, () -> {
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
|
||||
while (rr.hasNext())
|
||||
rr.next();
|
||||
while (rr.hasNext()) rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = UnsupportedOperationException.class)
|
||||
public void failedToSetNumMultiabelsException() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Failed To Set Num Multiabels Exception")
|
||||
void failedToSetNumMultiabelsException() {
|
||||
assertThrows(UnsupportedOperationException.class, () -> {
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
|
||||
while (rr.hasNext())
|
||||
rr.next();
|
||||
while (rr.hasNext()) rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = IndexOutOfBoundsException.class)
|
||||
public void testFeatureIndexExceedsNumFeatures() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Feature Index Exceeds Num Features")
|
||||
void testFeatureIndexExceedsNumFeatures() {
|
||||
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = IndexOutOfBoundsException.class)
|
||||
public void testLabelIndexExceedsNumLabels() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Label Index Exceeds Num Labels")
|
||||
void testLabelIndexExceedsNumLabels() {
|
||||
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||
config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = IndexOutOfBoundsException.class)
|
||||
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Zero Index Feature Without Using Zero Indexing")
|
||||
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
|
||||
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = IndexOutOfBoundsException.class)
|
||||
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Zero Index Label Without Using Zero Indexing")
|
||||
void testZeroIndexLabelWithoutUsingZeroIndexing() {
|
||||
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
Configuration config = new Configuration();
|
||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||
|
@ -431,5 +319,6 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
|||
config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
|
||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
|
||||
rr.next();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,14 +26,14 @@ import org.datavec.api.records.reader.SequenceRecordReader;
|
|||
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
|
||||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestCollectionRecordReaders extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -23,11 +23,11 @@ package org.datavec.api.records.reader.impl;
|
|||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestConcatenatingRecordReader extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ import org.datavec.api.transform.TransformProcess;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||
|
@ -47,7 +47,7 @@ import java.io.*;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestSerialization extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.LongWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
|
@ -38,8 +38,8 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TransformProcessRecordReaderTests extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.writer.impl;
|
||||
|
||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
|
@ -26,44 +25,42 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class CSVRecordWriterTest extends BaseND4JTest {
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
@DisplayName("Csv Record Writer Test")
|
||||
class CSVRecordWriterTest extends BaseND4JTest {
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWrite() throws Exception {
|
||||
@DisplayName("Test Write")
|
||||
void testWrite() throws Exception {
|
||||
File tempFile = File.createTempFile("datavec", "writer");
|
||||
tempFile.deleteOnExit();
|
||||
FileSplit fileSplit = new FileSplit(tempFile);
|
||||
CSVRecordWriter writer = new CSVRecordWriter();
|
||||
writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
|
||||
List<Writable> collection = new ArrayList<>();
|
||||
collection.add(new Text("12"));
|
||||
collection.add(new Text("13"));
|
||||
collection.add(new Text("14"));
|
||||
|
||||
writer.write(collection);
|
||||
|
||||
CSVRecordReader reader = new CSVRecordReader(0);
|
||||
reader.initialize(new FileSplit(tempFile));
|
||||
int cnt = 0;
|
||||
while (reader.hasNext()) {
|
||||
List<Writable> line = new ArrayList<>(reader.next());
|
||||
assertEquals(3, line.size());
|
||||
|
||||
assertEquals(12, line.get(0).toInt());
|
||||
assertEquals(13, line.get(1).toInt());
|
||||
assertEquals(14, line.get(2).toInt());
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.writer.impl;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -30,93 +29,90 @@ import org.datavec.api.writable.DoubleWritable;
|
|||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||
@DisplayName("Lib Svm Record Writer Test")
|
||||
class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testBasic() throws Exception {
|
||||
@DisplayName("Test Basic")
|
||||
void testBasic() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoLabel() throws Exception {
|
||||
@DisplayName("Test No Label")
|
||||
void testNoLabel() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultioutputRecord() throws Exception {
|
||||
@DisplayName("Test Multioutput Record")
|
||||
void testMultioutputRecord() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultilabelRecord() throws Exception {
|
||||
@DisplayName("Test Multilabel Record")
|
||||
void testMultilabelRecord() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
||||
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4);
|
||||
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testZeroBasedIndexing() throws Exception {
|
||||
@DisplayName("Test Zero Based Indexing")
|
||||
void testZeroBasedIndexing() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10);
|
||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
|
||||
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
||||
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
@ -127,10 +123,9 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
|||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||
rr.initialize(configReader, new FileSplit(inputFile));
|
||||
while (rr.hasNext()) {
|
||||
|
@ -138,7 +133,6 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
|||
writer.write(record);
|
||||
}
|
||||
}
|
||||
|
||||
Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX));
|
||||
List<String> linesOriginal = new ArrayList<>();
|
||||
for (String line : FileUtils.readLines(inputFile)) {
|
||||
|
@ -159,7 +153,8 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNDArrayWritables() throws Exception {
|
||||
@DisplayName("Test ND Array Writables")
|
||||
void testNDArrayWritables() throws Exception {
|
||||
INDArray arr2 = Nd4j.zeros(2);
|
||||
arr2.putScalar(0, 11);
|
||||
arr2.putScalar(1, 12);
|
||||
|
@ -167,35 +162,28 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
|||
arr3.putScalar(0, 13);
|
||||
arr3.putScalar(1, 14);
|
||||
arr3.putScalar(2, 15);
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
||||
new NDArrayWritable(arr2),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(3),
|
||||
new NDArrayWritable(arr3),
|
||||
new IntWritable(4));
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
|
||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
||||
|
||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
|
||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||
assertEquals(lineOriginal, lineNew);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNDArrayWritablesMultilabel() throws Exception {
|
||||
@DisplayName("Test ND Array Writables Multilabel")
|
||||
void testNDArrayWritablesMultilabel() throws Exception {
|
||||
INDArray arr2 = Nd4j.zeros(2);
|
||||
arr2.putScalar(0, 11);
|
||||
arr2.putScalar(1, 12);
|
||||
|
@ -203,36 +191,29 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
|||
arr3.putScalar(0, 0);
|
||||
arr3.putScalar(1, 1);
|
||||
arr3.putScalar(2, 0);
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
||||
new NDArrayWritable(arr2),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(3),
|
||||
new NDArrayWritable(arr3),
|
||||
new DoubleWritable(1));
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
|
||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
||||
|
||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
|
||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||
assertEquals(lineOriginal, lineNew);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNDArrayWritablesZeroIndex() throws Exception {
|
||||
@DisplayName("Test ND Array Writables Zero Index")
|
||||
void testNDArrayWritablesZeroIndex() throws Exception {
|
||||
INDArray arr2 = Nd4j.zeros(2);
|
||||
arr2.putScalar(0, 11);
|
||||
arr2.putScalar(1, 12);
|
||||
|
@ -240,99 +221,91 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
|||
arr3.putScalar(0, 0);
|
||||
arr3.putScalar(1, 1);
|
||||
arr3.putScalar(2, 0);
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
||||
new NDArrayWritable(arr2),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(3),
|
||||
new NDArrayWritable(arr3),
|
||||
new DoubleWritable(1));
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
|
||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
|
||||
|
||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
|
||||
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
|
||||
// NOT STANDARD!
|
||||
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
|
||||
// NOT STANDARD!
|
||||
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
|
||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
|
||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||
assertEquals(lineOriginal, lineNew);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonIntegerButValidMultilabel() throws Exception {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(1.0));
|
||||
@DisplayName("Test Non Integer But Valid Multilabel")
|
||||
void testNonIntegerButValidMultilabel() throws Exception {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
|
||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = NumberFormatException.class)
|
||||
public void nonIntegerMultilabel() throws Exception {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(1.2));
|
||||
@Test
|
||||
@DisplayName("Non Integer Multilabel")
|
||||
void nonIntegerMultilabel() {
|
||||
assertThrows(NumberFormatException.class, () -> {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
|
||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = NumberFormatException.class)
|
||||
public void nonBinaryMultilabel() throws Exception {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(0),
|
||||
new IntWritable(1),
|
||||
new IntWritable(2));
|
||||
@Test
|
||||
@DisplayName("Non Binary Multilabel")
|
||||
void nonBinaryMultilabel() {
|
||||
assertThrows(NumberFormatException.class, () -> {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
|
||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN,0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN,1);
|
||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL,true);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.writer.impl;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -27,93 +26,90 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||
@DisplayName("Svm Light Record Writer Test")
|
||||
class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testBasic() throws Exception {
|
||||
@DisplayName("Test Basic")
|
||||
void testBasic() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoLabel() throws Exception {
|
||||
@DisplayName("Test No Label")
|
||||
void testNoLabel() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultioutputRecord() throws Exception {
|
||||
@DisplayName("Test Multioutput Record")
|
||||
void testMultioutputRecord() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultilabelRecord() throws Exception {
|
||||
@DisplayName("Test Multilabel Record")
|
||||
void testMultilabelRecord() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
||||
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4);
|
||||
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testZeroBasedIndexing() throws Exception {
|
||||
@DisplayName("Test Zero Based Indexing")
|
||||
void testZeroBasedIndexing() throws Exception {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10);
|
||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||
|
||||
Configuration configReader = new Configuration();
|
||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
|
||||
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
||||
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5);
|
||||
|
||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
||||
executeTest(configWriter, configReader, inputFile);
|
||||
}
|
||||
|
@ -124,10 +120,9 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
|||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||
rr.initialize(configReader, new FileSplit(inputFile));
|
||||
while (rr.hasNext()) {
|
||||
|
@ -135,7 +130,6 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
|||
writer.write(record);
|
||||
}
|
||||
}
|
||||
|
||||
Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX));
|
||||
List<String> linesOriginal = new ArrayList<>();
|
||||
for (String line : FileUtils.readLines(inputFile)) {
|
||||
|
@ -156,7 +150,8 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNDArrayWritables() throws Exception {
|
||||
@DisplayName("Test ND Array Writables")
|
||||
void testNDArrayWritables() throws Exception {
|
||||
INDArray arr2 = Nd4j.zeros(2);
|
||||
arr2.putScalar(0, 11);
|
||||
arr2.putScalar(1, 12);
|
||||
|
@ -164,35 +159,28 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
|||
arr3.putScalar(0, 13);
|
||||
arr3.putScalar(1, 14);
|
||||
arr3.putScalar(2, 15);
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
||||
new NDArrayWritable(arr2),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(3),
|
||||
new NDArrayWritable(arr3),
|
||||
new IntWritable(4));
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
|
||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
||||
|
||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
|
||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||
assertEquals(lineOriginal, lineNew);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNDArrayWritablesMultilabel() throws Exception {
|
||||
@DisplayName("Test ND Array Writables Multilabel")
|
||||
void testNDArrayWritablesMultilabel() throws Exception {
|
||||
INDArray arr2 = Nd4j.zeros(2);
|
||||
arr2.putScalar(0, 11);
|
||||
arr2.putScalar(1, 12);
|
||||
|
@ -200,36 +188,29 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
|||
arr3.putScalar(0, 0);
|
||||
arr3.putScalar(1, 1);
|
||||
arr3.putScalar(2, 0);
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
||||
new NDArrayWritable(arr2),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(3),
|
||||
new NDArrayWritable(arr3),
|
||||
new DoubleWritable(1));
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
|
||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
||||
|
||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
|
||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||
assertEquals(lineOriginal, lineNew);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNDArrayWritablesZeroIndex() throws Exception {
|
||||
@DisplayName("Test ND Array Writables Zero Index")
|
||||
void testNDArrayWritablesZeroIndex() throws Exception {
|
||||
INDArray arr2 = Nd4j.zeros(2);
|
||||
arr2.putScalar(0, 11);
|
||||
arr2.putScalar(1, 12);
|
||||
|
@ -237,99 +218,91 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
|||
arr3.putScalar(0, 0);
|
||||
arr3.putScalar(1, 1);
|
||||
arr3.putScalar(2, 0);
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
||||
new NDArrayWritable(arr2),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(3),
|
||||
new NDArrayWritable(arr3),
|
||||
new DoubleWritable(1));
|
||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
|
||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
|
||||
|
||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
|
||||
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
|
||||
// NOT STANDARD!
|
||||
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
|
||||
// NOT STANDARD!
|
||||
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
|
||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
|
||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||
assertEquals(lineOriginal, lineNew);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonIntegerButValidMultilabel() throws Exception {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(1.0));
|
||||
@DisplayName("Test Non Integer But Valid Multilabel")
|
||||
void testNonIntegerButValidMultilabel() throws Exception {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
|
||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = NumberFormatException.class)
|
||||
public void nonIntegerMultilabel() throws Exception {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
|
||||
new IntWritable(2),
|
||||
new DoubleWritable(1.2));
|
||||
@Test
|
||||
@DisplayName("Non Integer Multilabel")
|
||||
void nonIntegerMultilabel() {
|
||||
assertThrows(NumberFormatException.class, () -> {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
|
||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = NumberFormatException.class)
|
||||
public void nonBinaryMultilabel() throws Exception {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(0),
|
||||
new IntWritable(1),
|
||||
new IntWritable(2));
|
||||
@Test
|
||||
@DisplayName("Non Binary Multilabel")
|
||||
void nonBinaryMultilabel() {
|
||||
assertThrows(NumberFormatException.class, () -> {
|
||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
|
||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||
tempFile.setWritable(true);
|
||||
tempFile.deleteOnExit();
|
||||
if (tempFile.exists())
|
||||
tempFile.delete();
|
||||
|
||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||
Configuration configWriter = new Configuration();
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||
FileSplit outputSplit = new FileSplit(tempFile);
|
||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
||||
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||
writer.write(record);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.datavec.api.io.filters.BalancedPathFilter;
|
|||
import org.datavec.api.io.filters.RandomPathFilter;
|
||||
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
||||
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.*;
|
||||
import java.net.URI;
|
||||
|
@ -34,8 +34,9 @@ import java.net.URISyntaxException;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Random;
|
||||
|
||||
import static junit.framework.TestCase.assertTrue;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
|
@ -20,13 +20,12 @@
|
|||
|
||||
package org.datavec.api.split;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.net.URI;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class NumberedFileInputSplitTests extends BaseND4JTest {
|
||||
@Test
|
||||
|
@ -69,60 +68,81 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
|
|||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testNumberedFileInputSplitWithLeadingSpaces() {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
String baseString = "/path/to/files/prefix-%5d.suffix";
|
||||
int minIdx = 0;
|
||||
int maxIdx = 10;
|
||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() {
|
||||
assertThrows(IllegalArgumentException.class, () -> {
|
||||
String baseString = "/path/to/files/prefix%5d.suffix";
|
||||
int minIdx = 0;
|
||||
int maxIdx = 10;
|
||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testNumberedFileInputSplitWithLeadingPlusInPadding() {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
String baseString = "/path/to/files/prefix%+5d.suffix";
|
||||
int minIdx = 0;
|
||||
int maxIdx = 10;
|
||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testNumberedFileInputSplitWithLeadingMinusInPadding() {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
String baseString = "/path/to/files/prefix%-5d.suffix";
|
||||
int minIdx = 0;
|
||||
int maxIdx = 10;
|
||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testNumberedFileInputSplitWithTwoDigitsInPadding() {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
String baseString = "/path/to/files/prefix%011d.suffix";
|
||||
int minIdx = 0;
|
||||
int maxIdx = 10;
|
||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testNumberedFileInputSplitWithInnerZerosInPadding() {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
String baseString = "/path/to/files/prefix%101d.suffix";
|
||||
int minIdx = 0;
|
||||
int maxIdx = 10;
|
||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
String baseString = "/path/to/files/prefix%0505d.suffix";
|
||||
int minIdx = 0;
|
||||
int maxIdx = 10;
|
||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -135,7 +155,7 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
|
|||
String path = locs[j++].getPath();
|
||||
String exp = String.format(baseString, i);
|
||||
String msg = exp + " vs " + path;
|
||||
assertTrue(msg, path.endsWith(exp)); //Note: on Windows, Java can prepend drive to path - "/C:/"
|
||||
assertTrue(path.endsWith(exp),msg); //Note: on Windows, Java can prepend drive to path - "/C:/"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,9 +25,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|||
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.function.Function;
|
||||
|
||||
|
@ -37,22 +38,22 @@ import java.io.IOException;
|
|||
import java.io.InputStream;
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
|
||||
public class TestStreamInputSplit extends BaseND4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
|
||||
@Test
|
||||
public void testCsvSimple() throws Exception {
|
||||
File dir = testDir.newFolder();
|
||||
public void testCsvSimple(@TempDir Path testDir) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
File f1 = new File(dir, "file1.txt");
|
||||
File f2 = new File(dir, "file2.txt");
|
||||
|
||||
|
@ -93,9 +94,9 @@ public class TestStreamInputSplit extends BaseND4JTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testCsvSequenceSimple() throws Exception {
|
||||
public void testCsvSequenceSimple(@TempDir Path testDir) throws Exception {
|
||||
|
||||
File dir = testDir.newFolder();
|
||||
File dir = testDir.toFile();
|
||||
File f1 = new File(dir, "file1.txt");
|
||||
File f2 = new File(dir, "file2.txt");
|
||||
|
||||
|
@ -137,8 +138,8 @@ public class TestStreamInputSplit extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testShuffle() throws Exception {
|
||||
File dir = testDir.newFolder();
|
||||
public void testShuffle(@TempDir Path testDir) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
File f1 = new File(dir, "file1.txt");
|
||||
File f2 = new File(dir, "file2.txt");
|
||||
File f3 = new File(dir, "file3.txt");
|
||||
|
|
|
@ -17,44 +17,43 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.split;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.Collection;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
/**
|
||||
* @author Ede Meijer
|
||||
*/
|
||||
public class TransformSplitTest extends BaseND4JTest {
|
||||
@Test
|
||||
public void testTransform() throws URISyntaxException {
|
||||
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
|
||||
@DisplayName("Transform Split Test")
|
||||
class TransformSplitTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
@DisplayName("Test Transform")
|
||||
void testTransform() throws URISyntaxException {
|
||||
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
|
||||
InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() {
|
||||
|
||||
@Override
|
||||
public URI apply(URI uri) throws URISyntaxException {
|
||||
return uri.normalize();
|
||||
}
|
||||
});
|
||||
|
||||
assertArrayEquals(new URI[] {new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv")}, SUT.locations());
|
||||
assertArrayEquals(new URI[] { new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv") }, SUT.locations());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSearchReplace() throws URISyntaxException {
|
||||
@DisplayName("Test Search Replace")
|
||||
void testSearchReplace() throws URISyntaxException {
|
||||
Collection<URI> inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv"));
|
||||
|
||||
InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv");
|
||||
|
||||
assertArrayEquals(new URI[] {new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv")},
|
||||
SUT.locations());
|
||||
assertArrayEquals(new URI[] { new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv") }, SUT.locations());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,14 +27,12 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||
import org.datavec.api.split.partition.PartitionMetaData;
|
||||
import org.datavec.api.split.partition.Partitioner;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.OutputStream;
|
||||
|
||||
import static junit.framework.TestCase.assertTrue;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class PartitionerTests extends BaseND4JTest {
|
||||
@Test
|
||||
|
|
|
@ -29,12 +29,12 @@ import org.datavec.api.writable.DoubleWritable;
|
|||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestTransformProcess extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -27,13 +27,13 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.transform.transform.TestTransforms;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TestConditions extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.transform.schema.Schema;
|
|||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -36,8 +36,8 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TestFilters extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -26,19 +26,22 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.NullWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
public class TestJoin extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testJoin() {
|
||||
public void testJoin(@TempDir Path testDir) {
|
||||
|
||||
Schema firstSchema =
|
||||
new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build();
|
||||
|
@ -46,20 +49,20 @@ public class TestJoin extends BaseND4JTest {
|
|||
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build();
|
||||
|
||||
List<List<Writable>> first = new ArrayList<>();
|
||||
first.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1)));
|
||||
first.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11)));
|
||||
first.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1)));
|
||||
first.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11)));
|
||||
|
||||
List<List<Writable>> second = new ArrayList<>();
|
||||
second.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(100)));
|
||||
second.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(110)));
|
||||
second.add(Arrays.asList(new Text("key0"), new IntWritable(100)));
|
||||
second.add(Arrays.asList(new Text("key1"), new IntWritable(110)));
|
||||
|
||||
Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn")
|
||||
.setSchemas(firstSchema, secondSchema).build();
|
||||
|
||||
List<List<Writable>> expected = new ArrayList<>();
|
||||
expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1),
|
||||
expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1),
|
||||
new IntWritable(100)));
|
||||
expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11),
|
||||
expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11),
|
||||
new IntWritable(110)));
|
||||
|
||||
|
||||
|
@ -94,9 +97,9 @@ public class TestJoin extends BaseND4JTest {
|
|||
}
|
||||
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testJoinValidation() {
|
||||
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
||||
.build();
|
||||
|
||||
|
@ -104,11 +107,13 @@ public class TestJoin extends BaseND4JTest {
|
|||
|
||||
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
|
||||
.setSchemas(firstSchema, secondSchema).build();
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testJoinValidation2() {
|
||||
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
||||
.build();
|
||||
|
||||
|
@ -116,5 +121,7 @@ public class TestJoin extends BaseND4JTest {
|
|||
|
||||
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
|
||||
.build();
|
||||
});
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,32 +17,25 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.transform.ops;
|
||||
|
||||
import com.tngtech.archunit.core.importer.ImportOption;
|
||||
import com.tngtech.archunit.junit.AnalyzeClasses;
|
||||
import com.tngtech.archunit.junit.ArchTest;
|
||||
import com.tngtech.archunit.junit.ArchUnitRunner;
|
||||
import com.tngtech.archunit.lang.ArchRule;
|
||||
import com.tngtech.archunit.lang.extension.ArchUnitExtension;
|
||||
import com.tngtech.archunit.lang.extension.ArchUnitExtensions;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@RunWith(ArchUnitRunner.class)
|
||||
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = {ImportOption.DoNotIncludeTests.class})
|
||||
public class AggregableMultiOpArchTest extends BaseND4JTest {
|
||||
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = { ImportOption.DoNotIncludeTests.class })
|
||||
@DisplayName("Aggregable Multi Op Arch Test")
|
||||
class AggregableMultiOpArchTest extends BaseND4JTest {
|
||||
|
||||
@ArchTest
|
||||
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes()
|
||||
.that().resideInAPackage("org.datavec.api.transform.ops")
|
||||
.and().doNotHaveSimpleName("AggregatorImpls")
|
||||
.and().doNotHaveSimpleName("IAggregableReduceOp")
|
||||
.and().doNotHaveSimpleName("StringAggregatorImpls")
|
||||
.and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1")
|
||||
.should().implement(Serializable.class)
|
||||
.because("All aggregate ops must be serializable.");
|
||||
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes().that().resideInAPackage("org.datavec.api.transform.ops").and().doNotHaveSimpleName("AggregatorImpls").and().doNotHaveSimpleName("IAggregableReduceOp").and().doNotHaveSimpleName("StringAggregatorImpls").and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1").should().implement(Serializable.class).because("All aggregate ops must be serializable.");
|
||||
}
|
|
@ -17,52 +17,46 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.transform.ops;
|
||||
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.*;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class AggregableMultiOpTest extends BaseND4JTest {
|
||||
@DisplayName("Aggregable Multi Op Test")
|
||||
class AggregableMultiOpTest extends BaseND4JTest {
|
||||
|
||||
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
|
||||
@Test
|
||||
public void testMulti() throws Exception {
|
||||
@DisplayName("Test Multi")
|
||||
void testMulti() throws Exception {
|
||||
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
|
||||
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
|
||||
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
|
||||
|
||||
assertTrue(multi.getOperations().size() == 2);
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
multi.accept(intList.get(i));
|
||||
}
|
||||
|
||||
// mutablility
|
||||
assertTrue(as.get().toDouble() == 45D);
|
||||
assertTrue(af.get().toInt() == 1);
|
||||
|
||||
List<Writable> res = multi.get();
|
||||
assertTrue(res.get(1).toDouble() == 45D);
|
||||
assertTrue(res.get(0).toInt() == 1);
|
||||
|
||||
AggregatorImpls.AggregableFirst<Integer> rf = new AggregatorImpls.AggregableFirst<>();
|
||||
AggregatorImpls.AggregableSum<Integer> rs = new AggregatorImpls.AggregableSum<>();
|
||||
AggregableMultiOp<Integer> reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs));
|
||||
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
}
|
||||
|
||||
List<Writable> revRes = reverse.get();
|
||||
assertTrue(revRes.get(1).toDouble() == 45D);
|
||||
assertTrue(revRes.get(0).toInt() == 9);
|
||||
|
||||
multi.combine(reverse);
|
||||
List<Writable> combinedRes = multi.get();
|
||||
assertTrue(combinedRes.get(1).toDouble() == 90D);
|
||||
|
|
|
@ -17,41 +17,39 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.transform.ops;
|
||||
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
|
||||
public class AggregatorImplsTest extends BaseND4JTest {
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@DisplayName("Aggregator Impls Test")
|
||||
class AggregatorImplsTest extends BaseND4JTest {
|
||||
|
||||
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
|
||||
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
|
||||
|
||||
@Test
|
||||
public void aggregableFirstTest() {
|
||||
@DisplayName("Aggregable First Test")
|
||||
void aggregableFirstTest() {
|
||||
AggregatorImpls.AggregableFirst<Integer> first = new AggregatorImpls.AggregableFirst<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
first.accept(intList.get(i));
|
||||
}
|
||||
assertEquals(1, first.get().toInt());
|
||||
|
||||
AggregatorImpls.AggregableFirst<String> firstS = new AggregatorImpls.AggregableFirst<>();
|
||||
for (int i = 0; i < stringList.size(); i++) {
|
||||
firstS.accept(stringList.get(i));
|
||||
}
|
||||
assertTrue(firstS.get().toString().equals("arakoa"));
|
||||
|
||||
|
||||
AggregatorImpls.AggregableFirst<Integer> reverse = new AggregatorImpls.AggregableFirst<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
|
@ -60,22 +58,19 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
assertEquals(1, first.get().toInt());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void aggregableLastTest() {
|
||||
@DisplayName("Aggregable Last Test")
|
||||
void aggregableLastTest() {
|
||||
AggregatorImpls.AggregableLast<Integer> last = new AggregatorImpls.AggregableLast<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
last.accept(intList.get(i));
|
||||
}
|
||||
assertEquals(9, last.get().toInt());
|
||||
|
||||
AggregatorImpls.AggregableLast<String> lastS = new AggregatorImpls.AggregableLast<>();
|
||||
for (int i = 0; i < stringList.size(); i++) {
|
||||
lastS.accept(stringList.get(i));
|
||||
}
|
||||
assertTrue(lastS.get().toString().equals("acceptance"));
|
||||
|
||||
|
||||
AggregatorImpls.AggregableLast<Integer> reverse = new AggregatorImpls.AggregableLast<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
|
@ -85,20 +80,18 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void aggregableCountTest() {
|
||||
@DisplayName("Aggregable Count Test")
|
||||
void aggregableCountTest() {
|
||||
AggregatorImpls.AggregableCount<Integer> cnt = new AggregatorImpls.AggregableCount<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
cnt.accept(intList.get(i));
|
||||
}
|
||||
assertEquals(9, cnt.get().toInt());
|
||||
|
||||
AggregatorImpls.AggregableCount<String> lastS = new AggregatorImpls.AggregableCount<>();
|
||||
for (int i = 0; i < stringList.size(); i++) {
|
||||
lastS.accept(stringList.get(i));
|
||||
}
|
||||
assertEquals(4, lastS.get().toInt());
|
||||
|
||||
|
||||
AggregatorImpls.AggregableCount<Integer> reverse = new AggregatorImpls.AggregableCount<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
|
@ -108,14 +101,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void aggregableMaxTest() {
|
||||
@DisplayName("Aggregable Max Test")
|
||||
void aggregableMaxTest() {
|
||||
AggregatorImpls.AggregableMax<Integer> mx = new AggregatorImpls.AggregableMax<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
mx.accept(intList.get(i));
|
||||
}
|
||||
assertEquals(9, mx.get().toInt());
|
||||
|
||||
|
||||
AggregatorImpls.AggregableMax<Integer> reverse = new AggregatorImpls.AggregableMax<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
|
@ -124,16 +116,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
assertEquals(9, mx.get().toInt());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void aggregableRangeTest() {
|
||||
@DisplayName("Aggregable Range Test")
|
||||
void aggregableRangeTest() {
|
||||
AggregatorImpls.AggregableRange<Integer> mx = new AggregatorImpls.AggregableRange<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
mx.accept(intList.get(i));
|
||||
}
|
||||
assertEquals(8, mx.get().toInt());
|
||||
|
||||
|
||||
AggregatorImpls.AggregableRange<Integer> reverse = new AggregatorImpls.AggregableRange<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1) + 9);
|
||||
|
@ -143,14 +133,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void aggregableMinTest() {
|
||||
@DisplayName("Aggregable Min Test")
|
||||
void aggregableMinTest() {
|
||||
AggregatorImpls.AggregableMin<Integer> mn = new AggregatorImpls.AggregableMin<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
mn.accept(intList.get(i));
|
||||
}
|
||||
assertEquals(1, mn.get().toInt());
|
||||
|
||||
|
||||
AggregatorImpls.AggregableMin<Integer> reverse = new AggregatorImpls.AggregableMin<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
|
@ -160,14 +149,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void aggregableSumTest() {
|
||||
@DisplayName("Aggregable Sum Test")
|
||||
void aggregableSumTest() {
|
||||
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
sm.accept(intList.get(i));
|
||||
}
|
||||
assertEquals(45, sm.get().toInt());
|
||||
|
||||
|
||||
AggregatorImpls.AggregableSum<Integer> reverse = new AggregatorImpls.AggregableSum<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
|
@ -176,17 +164,15 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
assertEquals(90, sm.get().toInt());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void aggregableMeanTest() {
|
||||
@DisplayName("Aggregable Mean Test")
|
||||
void aggregableMeanTest() {
|
||||
AggregatorImpls.AggregableMean<Integer> mn = new AggregatorImpls.AggregableMean<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
mn.accept(intList.get(i));
|
||||
}
|
||||
assertEquals(9l, (long) mn.getCount());
|
||||
assertEquals(5D, mn.get().toDouble(), 0.001);
|
||||
|
||||
|
||||
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
|
@ -197,80 +183,73 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void aggregableStdDevTest() {
|
||||
@DisplayName("Aggregable Std Dev Test")
|
||||
void aggregableStdDevTest() {
|
||||
AggregatorImpls.AggregableStdDev<Integer> sd = new AggregatorImpls.AggregableStdDev<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
sd.accept(intList.get(i));
|
||||
}
|
||||
assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001);
|
||||
|
||||
|
||||
AggregatorImpls.AggregableStdDev<Integer> reverse = new AggregatorImpls.AggregableStdDev<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
}
|
||||
sd.combine(reverse);
|
||||
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8787) < 0.0001);
|
||||
assertTrue(Math.abs(sd.get().toDouble() - 1.8787) < 0.0001,"" + sd.get().toDouble());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void aggregableVariance() {
|
||||
@DisplayName("Aggregable Variance")
|
||||
void aggregableVariance() {
|
||||
AggregatorImpls.AggregableVariance<Integer> sd = new AggregatorImpls.AggregableVariance<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
sd.accept(intList.get(i));
|
||||
}
|
||||
assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001);
|
||||
|
||||
|
||||
AggregatorImpls.AggregableVariance<Integer> reverse = new AggregatorImpls.AggregableVariance<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
}
|
||||
sd.combine(reverse);
|
||||
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 3.5294) < 0.0001);
|
||||
assertTrue(Math.abs(sd.get().toDouble() - 3.5294) < 0.0001,"" + sd.get().toDouble());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void aggregableUncorrectedStdDevTest() {
|
||||
@DisplayName("Aggregable Uncorrected Std Dev Test")
|
||||
void aggregableUncorrectedStdDevTest() {
|
||||
AggregatorImpls.AggregableUncorrectedStdDev<Integer> sd = new AggregatorImpls.AggregableUncorrectedStdDev<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
sd.accept(intList.get(i));
|
||||
}
|
||||
assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001);
|
||||
|
||||
|
||||
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse =
|
||||
new AggregatorImpls.AggregableUncorrectedStdDev<>();
|
||||
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse = new AggregatorImpls.AggregableUncorrectedStdDev<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
}
|
||||
sd.combine(reverse);
|
||||
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8257) < 0.0001);
|
||||
assertTrue(Math.abs(sd.get().toDouble() - 1.8257) < 0.0001,"" + sd.get().toDouble());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void aggregablePopulationVariance() {
|
||||
@DisplayName("Aggregable Population Variance")
|
||||
void aggregablePopulationVariance() {
|
||||
AggregatorImpls.AggregablePopulationVariance<Integer> sd = new AggregatorImpls.AggregablePopulationVariance<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
sd.accept(intList.get(i));
|
||||
}
|
||||
assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001);
|
||||
|
||||
|
||||
AggregatorImpls.AggregablePopulationVariance<Integer> reverse =
|
||||
new AggregatorImpls.AggregablePopulationVariance<>();
|
||||
AggregatorImpls.AggregablePopulationVariance<Integer> reverse = new AggregatorImpls.AggregablePopulationVariance<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
}
|
||||
sd.combine(reverse);
|
||||
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001);
|
||||
assertTrue(Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001,"" + sd.get().toDouble());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void aggregableCountUniqueTest() {
|
||||
@DisplayName("Aggregable Count Unique Test")
|
||||
void aggregableCountUniqueTest() {
|
||||
// at this low range, it's linear counting
|
||||
|
||||
AggregatorImpls.AggregableCountUnique<Integer> cu = new AggregatorImpls.AggregableCountUnique<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
cu.accept(intList.get(i));
|
||||
|
@ -278,7 +257,6 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
assertEquals(9, cu.get().toInt());
|
||||
cu.accept(1);
|
||||
assertEquals(9, cu.get().toInt());
|
||||
|
||||
AggregatorImpls.AggregableCountUnique<Integer> reverse = new AggregatorImpls.AggregableCountUnique<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
|
@ -287,26 +265,25 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
|||
assertEquals(9, cu.get().toInt());
|
||||
}
|
||||
|
||||
@Rule
|
||||
public final ExpectedException exception = ExpectedException.none();
|
||||
|
||||
|
||||
@Test
|
||||
public void incompatibleAggregatorTest() {
|
||||
@DisplayName("Incompatible Aggregator Test")
|
||||
void incompatibleAggregatorTest() {
|
||||
assertThrows(UnsupportedOperationException.class,() -> {
|
||||
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
sm.accept(intList.get(i));
|
||||
}
|
||||
assertEquals(45, sm.get().toInt());
|
||||
|
||||
|
||||
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
reverse.accept(intList.get(intList.size() - i - 1));
|
||||
}
|
||||
exception.expect(UnsupportedOperationException.class);
|
||||
|
||||
sm.combine(reverse);
|
||||
assertEquals(45, sm.get().toInt());
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,77 +17,65 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.transform.ops;
|
||||
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class DispatchOpTest extends BaseND4JTest {
|
||||
@DisplayName("Dispatch Op Test")
|
||||
class DispatchOpTest extends BaseND4JTest {
|
||||
|
||||
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
|
||||
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
|
||||
|
||||
@Test
|
||||
public void testDispatchSimple() {
|
||||
@DisplayName("Test Dispatch Simple")
|
||||
void testDispatchSimple() {
|
||||
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
|
||||
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
|
||||
AggregableMultiOp<Integer> multiaf =
|
||||
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
|
||||
AggregableMultiOp<Integer> multias =
|
||||
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
|
||||
|
||||
DispatchOp<Integer, Writable> parallel =
|
||||
new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
|
||||
|
||||
AggregableMultiOp<Integer> multiaf = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
|
||||
AggregableMultiOp<Integer> multias = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
|
||||
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
|
||||
assertTrue(multiaf.getOperations().size() == 1);
|
||||
assertTrue(multias.getOperations().size() == 1);
|
||||
assertTrue(parallel.getOperations().size() == 2);
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
|
||||
}
|
||||
|
||||
List<Writable> res = parallel.get();
|
||||
assertTrue(res.get(1).toDouble() == 45D);
|
||||
assertTrue(res.get(0).toInt() == 1);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDispatchFlatMap() {
|
||||
@DisplayName("Test Dispatch Flat Map")
|
||||
void testDispatchFlatMap() {
|
||||
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
|
||||
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
|
||||
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
|
||||
|
||||
AggregatorImpls.AggregableLast<Integer> al = new AggregatorImpls.AggregableLast<>();
|
||||
AggregatorImpls.AggregableMax<Integer> amax = new AggregatorImpls.AggregableMax<>();
|
||||
AggregableMultiOp<Integer> otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax));
|
||||
|
||||
|
||||
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(
|
||||
Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
|
||||
|
||||
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
|
||||
assertTrue(multi.getOperations().size() == 2);
|
||||
assertTrue(otherMulti.getOperations().size() == 2);
|
||||
assertTrue(parallel.getOperations().size() == 2);
|
||||
for (int i = 0; i < intList.size(); i++) {
|
||||
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
|
||||
}
|
||||
|
||||
List<Writable> res = parallel.get();
|
||||
assertTrue(res.get(1).toDouble() == 45D);
|
||||
assertTrue(res.get(0).toInt() == 1);
|
||||
assertTrue(res.get(3).toDouble() == 9);
|
||||
assertTrue(res.get(2).toInt() == 9);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -32,13 +32,14 @@ import org.datavec.api.transform.ops.AggregableMultiOp;
|
|||
import org.datavec.api.transform.ops.IAggregableReduceOp;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
public class TestMultiOpReduce extends BaseND4JTest {
|
||||
|
||||
|
@ -46,10 +47,10 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
|||
public void testMultiOpReducerDouble() {
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(0)));
|
||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(1)));
|
||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
|
||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
|
||||
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(0)));
|
||||
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(1)));
|
||||
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
|
||||
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
|
||||
|
||||
Map<ReduceOp, Double> exp = new LinkedHashMap<>();
|
||||
exp.put(ReduceOp.Min, 0.0);
|
||||
|
@ -82,7 +83,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
|||
assertEquals(out.get(0), new Text("someKey"));
|
||||
|
||||
String msg = op.toString();
|
||||
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
|
||||
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -126,19 +127,20 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
|||
assertEquals(out.get(0), new Text("someKey"));
|
||||
|
||||
String msg = op.toString();
|
||||
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
|
||||
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testReduceString() {
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1")));
|
||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2")));
|
||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3")));
|
||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4")));
|
||||
inputs.add(Arrays.asList(new Text("someKey"), new Text("1")));
|
||||
inputs.add(Arrays.asList(new Text("someKey"), new Text("2")));
|
||||
inputs.add(Arrays.asList(new Text("someKey"), new Text("3")));
|
||||
inputs.add(Arrays.asList(new Text("someKey"), new Text("4")));
|
||||
|
||||
Map<ReduceOp, String> exp = new LinkedHashMap<>();
|
||||
exp.put(ReduceOp.Append, "1234");
|
||||
|
@ -210,7 +212,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
|||
assertEquals(out.get(0), new Text("someKey"));
|
||||
|
||||
String msg = op.toString();
|
||||
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
|
||||
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
|
||||
}
|
||||
|
||||
for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean,
|
||||
|
|
|
@ -24,13 +24,13 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
|
|||
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestReductions extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -22,10 +22,10 @@ package org.datavec.api.transform.schema;
|
|||
|
||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestJsonYaml extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -21,10 +21,10 @@
|
|||
package org.datavec.api.transform.schema;
|
||||
|
||||
import org.datavec.api.transform.ColumnType;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestSchemaMethods extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ import org.datavec.api.writable.LongWritable;
|
|||
import org.datavec.api.writable.NullWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -41,7 +41,7 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.writable.LongWritable;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -35,7 +35,7 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestSequenceSplit extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.LongWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -37,7 +37,7 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestWindowFunctions extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -26,10 +26,10 @@ import org.datavec.api.transform.schema.Schema;
|
|||
import org.datavec.api.transform.serde.testClasses.CustomCondition;
|
||||
import org.datavec.api.transform.serde.testClasses.CustomFilter;
|
||||
import org.datavec.api.transform.serde.testClasses.CustomTransform;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestCustomTransformJsonYaml extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -64,13 +64,13 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
|
|||
import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
||||
import org.joda.time.DateTimeFieldType;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestYamlJsonSerde extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -24,12 +24,12 @@ import org.datavec.api.transform.StringReduceOp;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestReduce extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ import org.datavec.api.writable.Text;
|
|||
import org.datavec.api.writable.comparator.LongWritableComparator;
|
||||
import org.joda.time.DateTimeFieldType;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
|
@ -61,7 +61,7 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class RegressionTestJson extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -50,13 +50,13 @@ import org.datavec.api.writable.Text;
|
|||
import org.datavec.api.writable.comparator.LongWritableComparator;
|
||||
import org.joda.time.DateTimeFieldType;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestJsonYaml extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -58,8 +58,8 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
|
|||
import org.datavec.api.writable.*;
|
||||
import org.joda.time.DateTimeFieldType;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -71,8 +71,8 @@ import java.io.ObjectOutputStream;
|
|||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static junit.framework.TestCase.assertEquals;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestTransforms extends BaseND4JTest {
|
||||
|
||||
|
@ -277,22 +277,22 @@ public class TestTransforms extends BaseND4JTest {
|
|||
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
||||
outputColumns.add(NEW_COLUMN);
|
||||
Schema newSchema = transform.transform(schema);
|
||||
Assert.assertEquals(outputColumns, newSchema.getColumnNames());
|
||||
assertEquals(outputColumns, newSchema.getColumnNames());
|
||||
|
||||
List<Writable> input = new ArrayList<>();
|
||||
input.addAll(COLUMN_VALUES);
|
||||
|
||||
transform.setInputSchema(schema);
|
||||
List<Writable> transformed = transform.map(input);
|
||||
Assert.assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
|
||||
assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
|
||||
|
||||
List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES);
|
||||
outputColumnValues.add(new Text(NEW_COLUMN_VALUE));
|
||||
Assert.assertEquals(outputColumnValues, transformed);
|
||||
assertEquals(outputColumnValues, transformed);
|
||||
|
||||
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
||||
Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class);
|
||||
Assert.assertEquals(transform, transform2);
|
||||
assertEquals(transform, transform2);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -309,7 +309,7 @@ public class TestTransforms extends BaseND4JTest {
|
|||
transform.setInputSchema(schema);
|
||||
Schema newSchema = transform.transform(schema);
|
||||
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
||||
Assert.assertEquals(outputColumns, newSchema.getColumnNames());
|
||||
assertEquals(outputColumns, newSchema.getColumnNames());
|
||||
|
||||
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER);
|
||||
transform.setInputSchema(schema);
|
||||
|
@ -320,8 +320,8 @@ public class TestTransforms extends BaseND4JTest {
|
|||
output.add(new Text(TEXT_LOWER_CASE));
|
||||
output.add(new Text(TEXT_MIXED_CASE));
|
||||
List<Writable> transformed = transform.map(input);
|
||||
Assert.assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
|
||||
Assert.assertEquals(transformed, output);
|
||||
assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
|
||||
assertEquals(transformed, output);
|
||||
|
||||
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER);
|
||||
transform.setInputSchema(schema);
|
||||
|
@ -329,12 +329,12 @@ public class TestTransforms extends BaseND4JTest {
|
|||
output.add(new Text(TEXT_UPPER_CASE));
|
||||
output.add(new Text(TEXT_MIXED_CASE));
|
||||
transformed = transform.map(input);
|
||||
Assert.assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
|
||||
Assert.assertEquals(transformed, output);
|
||||
assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
|
||||
assertEquals(transformed, output);
|
||||
|
||||
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
||||
Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class);
|
||||
Assert.assertEquals(transform, transform2);
|
||||
assertEquals(transform, transform2);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -1530,7 +1530,7 @@ public class TestTransforms extends BaseND4JTest {
|
|||
|
||||
String json = JsonMappers.getMapper().writeValueAsString(t);
|
||||
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);
|
||||
Assert.assertEquals(t, transform2);
|
||||
assertEquals(t, transform2);
|
||||
}
|
||||
|
||||
|
||||
|
@ -1551,7 +1551,7 @@ public class TestTransforms extends BaseND4JTest {
|
|||
|
||||
String json = JsonMappers.getMapper().writeValueAsString(t);
|
||||
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class);
|
||||
Assert.assertEquals(t, transform2);
|
||||
assertEquals(t, transform2);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.datavec.api.writable.DoubleWritable;
|
|||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -39,7 +39,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestNDArrayWritableTransforms extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -30,13 +30,13 @@ import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.transform.serde.JsonSerializer;
|
||||
import org.datavec.api.transform.serde.YamlSerializer;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestYamlJsonSerde extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -17,29 +17,29 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.transform.transform.parse;
|
||||
|
||||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
@DisplayName("Parse Double Transform Test")
|
||||
class ParseDoubleTransformTest extends BaseND4JTest {
|
||||
|
||||
public class ParseDoubleTransformTest extends BaseND4JTest {
|
||||
@Test
|
||||
public void testDoubleTransform() {
|
||||
@DisplayName("Test Double Transform")
|
||||
void testDoubleTransform() {
|
||||
List<Writable> record = new ArrayList<>();
|
||||
record.add(new Text("0.0"));
|
||||
List<Writable> transformed = Arrays.<Writable>asList(new DoubleWritable(0.0));
|
||||
assertEquals(transformed, new ParseDoubleTransform().map(record));
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -35,26 +35,26 @@ import org.datavec.api.writable.DoubleWritable;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestUI extends BaseND4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test
|
||||
public void testUI() throws Exception {
|
||||
public void testUI(@TempDir Path testDir) throws Exception {
|
||||
Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn")
|
||||
.addColumnInteger("IntColumn2").addColumnInteger("IntColumn3")
|
||||
.addColumnTime("TimeColumn", DateTimeZone.UTC).build();
|
||||
|
@ -92,7 +92,7 @@ public class TestUI extends BaseND4JTest {
|
|||
|
||||
DataAnalysis da = new DataAnalysis(schema, list);
|
||||
|
||||
File fDir = testDir.newFolder();
|
||||
File fDir = testDir.toFile();
|
||||
String tempDir = fDir.getAbsolutePath();
|
||||
String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html");
|
||||
System.out.println(outPath);
|
||||
|
@ -143,7 +143,7 @@ public class TestUI extends BaseND4JTest {
|
|||
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
@Disabled
|
||||
public void testSequencePlot() throws Exception {
|
||||
|
||||
Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx")
|
||||
|
|
|
@ -17,30 +17,31 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.util;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.File;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.hamcrest.core.AnyOf.anyOf;
|
||||
import static org.hamcrest.core.IsEqual.equalTo;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
public class ClassPathResourceTest extends BaseND4JTest {
|
||||
@DisplayName("Class Path Resource Test")
|
||||
class ClassPathResourceTest extends BaseND4JTest {
|
||||
|
||||
private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows
|
||||
// File sizes are reported slightly different on Linux vs. Windows
|
||||
private boolean isWindows = false;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
String osname = System.getProperty("os.name");
|
||||
if (osname != null && osname.toLowerCase().contains("win")) {
|
||||
isWindows = true;
|
||||
|
@ -48,9 +49,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetFile1() throws Exception {
|
||||
@DisplayName("Test Get File 1")
|
||||
void testGetFile1() throws Exception {
|
||||
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
|
||||
|
||||
assertTrue(intFile.exists());
|
||||
if (isWindows) {
|
||||
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
|
||||
|
@ -60,9 +61,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetFileSlash1() throws Exception {
|
||||
@DisplayName("Test Get File Slash 1")
|
||||
void testGetFileSlash1() throws Exception {
|
||||
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
|
||||
|
||||
assertTrue(intFile.exists());
|
||||
if (isWindows) {
|
||||
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
|
||||
|
@ -72,11 +73,10 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetFileWithSpace1() throws Exception {
|
||||
@DisplayName("Test Get File With Space 1")
|
||||
void testGetFileWithSpace1() throws Exception {
|
||||
File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile();
|
||||
|
||||
assertTrue(intFile.exists());
|
||||
|
||||
if (isWindows) {
|
||||
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
|
||||
} else {
|
||||
|
@ -85,16 +85,15 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testInputStream() throws Exception {
|
||||
@DisplayName("Test Input Stream")
|
||||
void testInputStream() throws Exception {
|
||||
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
|
||||
File intFile = resource.getFile();
|
||||
|
||||
if (isWindows) {
|
||||
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
|
||||
} else {
|
||||
assertEquals(60, intFile.length());
|
||||
}
|
||||
|
||||
InputStream stream = resource.getInputStream();
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
||||
String line = "";
|
||||
|
@ -102,21 +101,19 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
|||
while ((line = reader.readLine()) != null) {
|
||||
cnt++;
|
||||
}
|
||||
|
||||
assertEquals(5, cnt);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInputStreamSlash() throws Exception {
|
||||
@DisplayName("Test Input Stream Slash")
|
||||
void testInputStreamSlash() throws Exception {
|
||||
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
|
||||
File intFile = resource.getFile();
|
||||
|
||||
if (isWindows) {
|
||||
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
|
||||
} else {
|
||||
assertEquals(60, intFile.length());
|
||||
}
|
||||
|
||||
InputStream stream = resource.getInputStream();
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
||||
String line = "";
|
||||
|
@ -124,7 +121,6 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
|||
while ((line = reader.readLine()) != null) {
|
||||
cnt++;
|
||||
}
|
||||
|
||||
assertEquals(5, cnt);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,44 +17,41 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.util;
|
||||
|
||||
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
|
||||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
public class TimeSeriesUtilsTest extends BaseND4JTest {
|
||||
@DisplayName("Time Series Utils Test")
|
||||
class TimeSeriesUtilsTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testTimeSeriesCreation() {
|
||||
@DisplayName("Test Time Series Creation")
|
||||
void testTimeSeriesCreation() {
|
||||
List<List<List<Writable>>> test = new ArrayList<>();
|
||||
List<List<Writable>> timeStep = new ArrayList<>();
|
||||
for(int i = 0; i < 5; i++) {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
timeStep.add(getRecord(5));
|
||||
}
|
||||
|
||||
test.add(timeStep);
|
||||
|
||||
INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst();
|
||||
assertArrayEquals(new long[]{1,5,5},arr.shape());
|
||||
assertArrayEquals(new long[] { 1, 5, 5 }, arr.shape());
|
||||
}
|
||||
|
||||
private List<Writable> getRecord(int length) {
|
||||
List<Writable> ret = new ArrayList<>();
|
||||
for(int i = 0; i < length; i++) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
ret.add(new DoubleWritable(1.0));
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,52 +17,50 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.writable;
|
||||
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.shade.guava.collect.Lists;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.util.ndarray.RecordConverter;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.TimeZone;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
@DisplayName("Record Converter Test")
|
||||
class RecordConverterTest extends BaseND4JTest {
|
||||
|
||||
public class RecordConverterTest extends BaseND4JTest {
|
||||
@Test
|
||||
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
|
||||
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
|
||||
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT);
|
||||
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT);
|
||||
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT);
|
||||
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)),
|
||||
Nd4j.vstack(Lists.newArrayList(label1, label2)));
|
||||
|
||||
@DisplayName("To Records _ Pass In Classification Data Set _ Expect ND Array And Int Writables")
|
||||
void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
|
||||
INDArray feature1 = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||
INDArray feature2 = Nd4j.create(new double[] { 11, .7, -1.3, 4 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||
INDArray label1 = Nd4j.create(new double[] { 0, 0, 1, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||
INDArray label2 = Nd4j.create(new double[] { 0, 1, 0, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), Nd4j.vstack(Lists.newArrayList(label1, label2)));
|
||||
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
|
||||
|
||||
assertEquals(2, writableList.size());
|
||||
testClassificationWritables(feature1, 2, writableList.get(0));
|
||||
testClassificationWritables(feature2, 1, writableList.get(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
|
||||
INDArray feature = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
|
||||
INDArray label = Nd4j.create(new double[]{.5, 2, 3, .5}, new long[]{1, 4}, DataType.FLOAT);
|
||||
@DisplayName("To Records _ Pass In Regression Data Set _ Expect ND Array And Double Writables")
|
||||
void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
|
||||
INDArray feature = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||
INDArray label = Nd4j.create(new double[] { .5, 2, 3, .5 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||
DataSet dataSet = new DataSet(feature, label);
|
||||
|
||||
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
|
||||
List<Writable> results = writableList.get(0);
|
||||
NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0);
|
||||
|
||||
assertEquals(1, writableList.size());
|
||||
assertEquals(5, results.size());
|
||||
assertEquals(feature, ndArrayWritable.get());
|
||||
|
@ -72,62 +70,39 @@ public class RecordConverterTest extends BaseND4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex,
|
||||
List<Writable> writables) {
|
||||
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, List<Writable> writables) {
|
||||
NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0);
|
||||
IntWritable intWritable = (IntWritable) writables.get(1);
|
||||
|
||||
assertEquals(2, writables.size());
|
||||
assertEquals(expectedFeatureVector, ndArrayWritable.get());
|
||||
assertEquals(expectLabelIndex, intWritable.get());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testNDArrayWritableConcat() {
|
||||
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1),
|
||||
new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1, 3}, DataType.FLOAT)), new DoubleWritable(5),
|
||||
new NDArrayWritable(Nd4j.create(new double[]{6, 7, 8}, new long[]{1, 3}, DataType.FLOAT)), new IntWritable(9),
|
||||
new IntWritable(1));
|
||||
|
||||
INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT);
|
||||
@DisplayName("Test ND Array Writable Concat")
|
||||
void testNDArrayWritableConcat() {
|
||||
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 }, new long[] { 1, 3 }, DataType.FLOAT)), new IntWritable(9), new IntWritable(1));
|
||||
INDArray exp = Nd4j.create(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1 }, new long[] { 1, 10 }, DataType.FLOAT);
|
||||
INDArray act = RecordConverter.toArray(DataType.FLOAT, l);
|
||||
|
||||
assertEquals(exp, act);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNDArrayWritableConcatToMatrix(){
|
||||
|
||||
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5));
|
||||
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[]{7, 8, 9}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(10));
|
||||
|
||||
INDArray exp = Nd4j.create(new double[][]{
|
||||
{1,2,3,4,5},
|
||||
{6,7,8,9,10}}).castTo(DataType.FLOAT);
|
||||
|
||||
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2));
|
||||
|
||||
@DisplayName("Test ND Array Writable Concat To Matrix")
|
||||
void testNDArrayWritableConcatToMatrix() {
|
||||
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5));
|
||||
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(10));
|
||||
INDArray exp = Nd4j.create(new double[][] { { 1, 2, 3, 4, 5 }, { 6, 7, 8, 9, 10 } }).castTo(DataType.FLOAT);
|
||||
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1, l2));
|
||||
assertEquals(exp, act);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToRecordWithListOfObject(){
|
||||
final List<Object> list = Arrays.asList((Object)3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
|
||||
final Schema schema = new Schema.Builder()
|
||||
.addColumnInteger("a")
|
||||
.addColumnFloat("b")
|
||||
.addColumnString("c")
|
||||
.addColumnCategorical("d", "Bar", "Baz")
|
||||
.addColumnDouble("e")
|
||||
.addColumnFloat("f")
|
||||
.addColumnLong("g")
|
||||
.addColumnInteger("h")
|
||||
.addColumnTime("i", TimeZone.getDefault())
|
||||
.build();
|
||||
|
||||
@DisplayName("Test To Record With List Of Object")
|
||||
void testToRecordWithListOfObject() {
|
||||
final List<Object> list = Arrays.asList((Object) 3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
|
||||
final Schema schema = new Schema.Builder().addColumnInteger("a").addColumnFloat("b").addColumnString("c").addColumnCategorical("d", "Bar", "Baz").addColumnDouble("e").addColumnFloat("f").addColumnLong("g").addColumnInteger("h").addColumnTime("i", TimeZone.getDefault()).build();
|
||||
final List<Writable> record = RecordConverter.toRecord(schema, list);
|
||||
|
||||
assertEquals(record.get(0).toInt(), 3);
|
||||
assertEquals(record.get(1).toFloat(), 7f, 1e-6);
|
||||
assertEquals(record.get(2).toString(), "Foo");
|
||||
|
@ -137,7 +112,5 @@ public class RecordConverterTest extends BaseND4JTest {
|
|||
assertEquals(record.get(6).toLong(), 3L);
|
||||
assertEquals(record.get(7).toInt(), 7);
|
||||
assertEquals(record.get(8).toLong(), 0);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,14 +21,14 @@
|
|||
package org.datavec.api.writable;
|
||||
|
||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.*;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
|
||||
|
||||
|
|
|
@ -17,38 +17,38 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.writable;
|
||||
|
||||
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.nio.Buffer;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class WritableTest extends BaseND4JTest {
|
||||
@DisplayName("Writable Test")
|
||||
class WritableTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testWritableEqualityReflexive() {
|
||||
@DisplayName("Test Writable Equality Reflexive")
|
||||
void testWritableEqualityReflexive() {
|
||||
assertEquals(new IntWritable(1), new IntWritable(1));
|
||||
assertEquals(new LongWritable(1), new LongWritable(1));
|
||||
assertEquals(new DoubleWritable(1), new DoubleWritable(1));
|
||||
assertEquals(new FloatWritable(1), new FloatWritable(1));
|
||||
assertEquals(new Text("Hello"), new Text("Hello"));
|
||||
assertEquals(new BytesWritable("Hello".getBytes()),new BytesWritable("Hello".getBytes()));
|
||||
INDArray ndArray = Nd4j.rand(new int[]{1, 100});
|
||||
|
||||
assertEquals(new BytesWritable("Hello".getBytes()), new BytesWritable("Hello".getBytes()));
|
||||
INDArray ndArray = Nd4j.rand(new int[] { 1, 100 });
|
||||
assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray));
|
||||
assertEquals(new NullWritable(), new NullWritable());
|
||||
assertEquals(new BooleanWritable(true), new BooleanWritable(true));
|
||||
|
@ -56,9 +56,9 @@ public class WritableTest extends BaseND4JTest {
|
|||
assertEquals(new ByteWritable(b), new ByteWritable(b));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testBytesWritableIndexing() {
|
||||
@DisplayName("Test Bytes Writable Indexing")
|
||||
void testBytesWritableIndexing() {
|
||||
byte[] doubleWrite = new byte[16];
|
||||
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
|
||||
Buffer buffer = (Buffer) wrapped;
|
||||
|
@ -66,53 +66,51 @@ public class WritableTest extends BaseND4JTest {
|
|||
wrapped.putDouble(2.0);
|
||||
buffer.rewind();
|
||||
BytesWritable byteWritable = new BytesWritable(doubleWrite);
|
||||
assertEquals(2,byteWritable.getDouble(1),1e-1);
|
||||
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2});
|
||||
assertEquals(2, byteWritable.getDouble(1), 1e-1);
|
||||
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] { 1, 2 });
|
||||
double[] d1 = dataBuffer.asDouble();
|
||||
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE,8).asDouble();
|
||||
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE, 8).asDouble();
|
||||
assertArrayEquals(d1, d2, 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testByteWritable() {
|
||||
@DisplayName("Test Byte Writable")
|
||||
void testByteWritable() {
|
||||
byte b = 0xfffffffe;
|
||||
assertEquals(new IntWritable(-2), new ByteWritable(b));
|
||||
assertEquals(new LongWritable(-2), new ByteWritable(b));
|
||||
assertEquals(new ByteWritable(b), new IntWritable(-2));
|
||||
assertEquals(new ByteWritable(b), new LongWritable(-2));
|
||||
|
||||
// those would cast to the same Int
|
||||
byte minus126 = 0xffffff82;
|
||||
assertNotEquals(new ByteWritable(minus126), new IntWritable(130));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntLongWritable() {
|
||||
@DisplayName("Test Int Long Writable")
|
||||
void testIntLongWritable() {
|
||||
assertEquals(new IntWritable(1), new LongWritable(1l));
|
||||
assertEquals(new LongWritable(2l), new IntWritable(2));
|
||||
|
||||
long l = 1L << 34;
|
||||
// those would cast to the same Int
|
||||
assertNotEquals(new LongWritable(l), new IntWritable(4));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testDoubleFloatWritable() {
|
||||
@DisplayName("Test Double Float Writable")
|
||||
void testDoubleFloatWritable() {
|
||||
assertEquals(new DoubleWritable(1d), new FloatWritable(1f));
|
||||
assertEquals(new FloatWritable(2f), new DoubleWritable(2d));
|
||||
|
||||
// we defer to Java equality for Floats
|
||||
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f));
|
||||
// same idea as above
|
||||
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float)1.1d));
|
||||
|
||||
assertNotEquals(new DoubleWritable((double)Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
|
||||
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float) 1.1d));
|
||||
assertNotEquals(new DoubleWritable((double) Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testFuzzies() {
|
||||
@DisplayName("Test Fuzzies")
|
||||
void testFuzzies() {
|
||||
assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d));
|
||||
assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d));
|
||||
byte b = 0xfffffffe;
|
||||
|
@ -122,62 +120,57 @@ public class WritableTest extends BaseND4JTest {
|
|||
assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testNDArrayRecordBatch(){
|
||||
@DisplayName("Test ND Array Record Batch")
|
||||
void testNDArrayRecordBatch() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
List<List<INDArray>> orig = new ArrayList<>(); //Outer list over writables/columns, inner list over examples
|
||||
for( int i=0; i<3; i++ ){
|
||||
// Outer list over writables/columns, inner list over examples
|
||||
List<List<INDArray>> orig = new ArrayList<>();
|
||||
for (int i = 0; i < 3; i++) {
|
||||
orig.add(new ArrayList<INDArray>());
|
||||
}
|
||||
|
||||
for( int i=0; i<5; i++ ){
|
||||
orig.get(0).add(Nd4j.rand(1,10));
|
||||
orig.get(1).add(Nd4j.rand(new int[]{1,5,6}));
|
||||
orig.get(2).add(Nd4j.rand(new int[]{1,3,4,5}));
|
||||
for (int i = 0; i < 5; i++) {
|
||||
orig.get(0).add(Nd4j.rand(1, 10));
|
||||
orig.get(1).add(Nd4j.rand(new int[] { 1, 5, 6 }));
|
||||
orig.get(2).add(Nd4j.rand(new int[] { 1, 3, 4, 5 }));
|
||||
}
|
||||
|
||||
List<List<INDArray>> origByExample = new ArrayList<>(); //Outer list over examples, inner list over writables
|
||||
for( int i=0; i<5; i++ ){
|
||||
// Outer list over examples, inner list over writables
|
||||
List<List<INDArray>> origByExample = new ArrayList<>();
|
||||
for (int i = 0; i < 5; i++) {
|
||||
origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i)));
|
||||
}
|
||||
|
||||
List<INDArray> batched = new ArrayList<>();
|
||||
for(List<INDArray> l : orig){
|
||||
for (List<INDArray> l : orig) {
|
||||
batched.add(Nd4j.concat(0, l.toArray(new INDArray[5])));
|
||||
}
|
||||
|
||||
NDArrayRecordBatch batch = new NDArrayRecordBatch(batched);
|
||||
assertEquals(5, batch.size());
|
||||
for( int i=0; i<5; i++ ){
|
||||
for (int i = 0; i < 5; i++) {
|
||||
List<Writable> act = batch.get(i);
|
||||
List<INDArray> unboxed = new ArrayList<>();
|
||||
for(Writable w : act){
|
||||
unboxed.add(((NDArrayWritable)w).get());
|
||||
for (Writable w : act) {
|
||||
unboxed.add(((NDArrayWritable) w).get());
|
||||
}
|
||||
List<INDArray> exp = origByExample.get(i);
|
||||
assertEquals(exp.size(), unboxed.size());
|
||||
for( int j=0; j<exp.size(); j++ ){
|
||||
for (int j = 0; j < exp.size(); j++) {
|
||||
assertEquals(exp.get(j), unboxed.get(j));
|
||||
}
|
||||
}
|
||||
|
||||
Iterator<List<Writable>> iter = batch.iterator();
|
||||
int count = 0;
|
||||
while(iter.hasNext()){
|
||||
while (iter.hasNext()) {
|
||||
List<Writable> next = iter.next();
|
||||
List<INDArray> unboxed = new ArrayList<>();
|
||||
for(Writable w : next){
|
||||
unboxed.add(((NDArrayWritable)w).get());
|
||||
for (Writable w : next) {
|
||||
unboxed.add(((NDArrayWritable) w).get());
|
||||
}
|
||||
List<INDArray> exp = origByExample.get(count++);
|
||||
assertEquals(exp.size(), unboxed.size());
|
||||
for( int j=0; j<exp.size(); j++ ){
|
||||
for (int j = 0; j < exp.size(); j++) {
|
||||
assertEquals(exp.get(j), unboxed.get(j));
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(5, count);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,10 +60,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.arrow;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -42,461 +41,397 @@ import org.datavec.api.transform.schema.Schema;
|
|||
import org.datavec.api.writable.*;
|
||||
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
||||
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
import static java.nio.channels.Channels.newChannel;
|
||||
import static junit.framework.TestCase.assertTrue;
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@Slf4j
|
||||
public class ArrowConverterTest extends BaseND4JTest {
|
||||
@DisplayName("Arrow Converter Test")
|
||||
class ArrowConverterTest extends BaseND4JTest {
|
||||
|
||||
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
|
||||
@TempDir
|
||||
public Path testDir;
|
||||
|
||||
@Test
|
||||
public void testToArrayFromINDArray() {
|
||||
@DisplayName("Test To Array From IND Array")
|
||||
void testToArrayFromINDArray() {
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
schemaBuilder.addColumnNDArray("outputArray",new long[]{1,4});
|
||||
schemaBuilder.addColumnNDArray("outputArray", new long[] { 1, 4 });
|
||||
Schema schema = schemaBuilder.build();
|
||||
int numRows = 4;
|
||||
List<List<Writable>> ret = new ArrayList<>(numRows);
|
||||
for(int i = 0; i < numRows; i++) {
|
||||
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1,4,4).reshape(1, 4))));
|
||||
for (int i = 0; i < numRows; i++) {
|
||||
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4))));
|
||||
}
|
||||
|
||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret);
|
||||
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors,schema);
|
||||
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors, schema);
|
||||
INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch);
|
||||
assertArrayEquals(new long[]{4,4},array.shape());
|
||||
|
||||
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1,4,4),4).reshape(4,4);
|
||||
assertEquals(assertion,array);
|
||||
assertArrayEquals(new long[] { 4, 4 }, array.shape());
|
||||
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1, 4, 4), 4).reshape(4, 4);
|
||||
assertEquals(assertion, array);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testArrowColumnINDArray() {
|
||||
@DisplayName("Test Arrow Column IND Array")
|
||||
void testArrowColumnINDArray() {
|
||||
Schema.Builder schema = new Schema.Builder();
|
||||
List<String> single = new ArrayList<>();
|
||||
int numCols = 2;
|
||||
INDArray arr = Nd4j.linspace(1,4,4);
|
||||
for(int i = 0; i < numCols; i++) {
|
||||
schema.addColumnNDArray(String.valueOf(i),new long[]{1,4});
|
||||
INDArray arr = Nd4j.linspace(1, 4, 4);
|
||||
for (int i = 0; i < numCols; i++) {
|
||||
schema.addColumnNDArray(String.valueOf(i), new long[] { 1, 4 });
|
||||
single.add(String.valueOf(i));
|
||||
}
|
||||
|
||||
Schema buildSchema = schema.build();
|
||||
List<List<Writable>> list = new ArrayList<>();
|
||||
List<Writable> firstRow = new ArrayList<>();
|
||||
for(int i = 0 ; i < numCols; i++) {
|
||||
for (int i = 0; i < numCols; i++) {
|
||||
firstRow.add(new NDArrayWritable(arr));
|
||||
}
|
||||
|
||||
list.add(firstRow);
|
||||
|
||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list);
|
||||
assertEquals(numCols,fieldVectors.size());
|
||||
assertEquals(1,fieldVectors.get(0).getValueCount());
|
||||
assertEquals(numCols, fieldVectors.size());
|
||||
assertEquals(1, fieldVectors.get(0).getValueCount());
|
||||
assertFalse(fieldVectors.get(0).isNull(0));
|
||||
|
||||
ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema);
|
||||
assertEquals(1,arrowWritableRecordBatch.size());
|
||||
|
||||
assertEquals(1, arrowWritableRecordBatch.size());
|
||||
Writable writable = arrowWritableRecordBatch.get(0).get(0);
|
||||
assertTrue(writable instanceof NDArrayWritable);
|
||||
NDArrayWritable ndArrayWritable = (NDArrayWritable) writable;
|
||||
assertEquals(arr,ndArrayWritable.get());
|
||||
|
||||
assertEquals(arr, ndArrayWritable.get());
|
||||
Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray);
|
||||
NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1;
|
||||
System.out.println(ndArrayWritablewritable1.get());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testArrowColumnString() {
|
||||
@DisplayName("Test Arrow Column String")
|
||||
void testArrowColumnString() {
|
||||
Schema.Builder schema = new Schema.Builder();
|
||||
List<String> single = new ArrayList<>();
|
||||
for(int i = 0; i < 2; i++) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
schema.addColumnInteger(String.valueOf(i));
|
||||
single.add(String.valueOf(i));
|
||||
}
|
||||
|
||||
|
||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single);
|
||||
List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
|
||||
List<List<Writable>> assertion = new ArrayList<>();
|
||||
assertion.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)));
|
||||
assertEquals(assertion,records);
|
||||
|
||||
assertion.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)));
|
||||
assertEquals(assertion, records);
|
||||
List<List<String>> batch = new ArrayList<>();
|
||||
for(int i = 0; i < 2; i++) {
|
||||
batch.add(Arrays.asList(String.valueOf(i),String.valueOf(i)));
|
||||
for (int i = 0; i < 2; i++) {
|
||||
batch.add(Arrays.asList(String.valueOf(i), String.valueOf(i)));
|
||||
}
|
||||
|
||||
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
|
||||
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
|
||||
|
||||
List<List<Writable>> assertionBatch = new ArrayList<>();
|
||||
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(0)));
|
||||
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1),new IntWritable(1)));
|
||||
assertEquals(assertionBatch,batchRecords);
|
||||
|
||||
|
||||
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0)));
|
||||
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1)));
|
||||
assertEquals(assertionBatch, batchRecords);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testArrowBatchSetTime() {
|
||||
@DisplayName("Test Arrow Batch Set Time")
|
||||
void testArrowBatchSetTime() {
|
||||
Schema.Builder schema = new Schema.Builder();
|
||||
List<String> single = new ArrayList<>();
|
||||
for(int i = 0; i < 2; i++) {
|
||||
schema.addColumnTime(String.valueOf(i),TimeZone.getDefault());
|
||||
for (int i = 0; i < 2; i++) {
|
||||
schema.addColumnTime(String.valueOf(i), TimeZone.getDefault());
|
||||
single.add(String.valueOf(i));
|
||||
}
|
||||
|
||||
List<List<Writable>> input = Arrays.asList(
|
||||
Arrays.<Writable>asList(new LongWritable(0),new LongWritable(1)),
|
||||
Arrays.<Writable>asList(new LongWritable(2),new LongWritable(3))
|
||||
);
|
||||
|
||||
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
|
||||
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
|
||||
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
|
||||
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
|
||||
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
|
||||
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5));
|
||||
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4),new LongWritable(5)));
|
||||
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)));
|
||||
List<Writable> recordTest = writableRecordBatch.get(1);
|
||||
assertEquals(assertion,recordTest);
|
||||
assertEquals(assertion, recordTest);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testArrowBatchSet() {
|
||||
@DisplayName("Test Arrow Batch Set")
|
||||
void testArrowBatchSet() {
|
||||
Schema.Builder schema = new Schema.Builder();
|
||||
List<String> single = new ArrayList<>();
|
||||
for(int i = 0; i < 2; i++) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
schema.addColumnInteger(String.valueOf(i));
|
||||
single.add(String.valueOf(i));
|
||||
}
|
||||
|
||||
List<List<Writable>> input = Arrays.asList(
|
||||
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)),
|
||||
Arrays.<Writable>asList(new IntWritable(2),new IntWritable(3))
|
||||
);
|
||||
|
||||
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
|
||||
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
|
||||
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
|
||||
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
|
||||
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
|
||||
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
|
||||
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5)));
|
||||
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)));
|
||||
List<Writable> recordTest = writableRecordBatch.get(1);
|
||||
assertEquals(assertion,recordTest);
|
||||
assertEquals(assertion, recordTest);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testArrowColumnsStringTimeSeries() {
|
||||
@DisplayName("Test Arrow Columns String Time Series")
|
||||
void testArrowColumnsStringTimeSeries() {
|
||||
Schema.Builder schema = new Schema.Builder();
|
||||
List<List<List<String>>> entries = new ArrayList<>();
|
||||
for(int i = 0; i < 3; i++) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
schema.addColumnInteger(String.valueOf(i));
|
||||
}
|
||||
|
||||
for(int i = 0; i < 5; i++) {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
|
||||
entries.add(arr);
|
||||
}
|
||||
|
||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
|
||||
assertEquals(3,fieldVectors.size());
|
||||
assertEquals(5,fieldVectors.get(0).getValueCount());
|
||||
|
||||
|
||||
assertEquals(3, fieldVectors.size());
|
||||
assertEquals(5, fieldVectors.get(0).getValueCount());
|
||||
INDArray exp = Nd4j.create(5, 3);
|
||||
for( int i = 0; i < 5; i++) {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
exp.getRow(i).assign(i);
|
||||
}
|
||||
//Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
|
||||
// Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
|
||||
ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
|
||||
INDArray arr = ArrowConverter.toArray(wri);
|
||||
assertArrayEquals(new long[] {5,3}, arr.shape());
|
||||
|
||||
|
||||
assertArrayEquals(new long[] { 5, 3 }, arr.shape());
|
||||
assertEquals(exp, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConvertVector() {
|
||||
@DisplayName("Test Convert Vector")
|
||||
void testConvertVector() {
|
||||
Schema.Builder schema = new Schema.Builder();
|
||||
List<List<List<String>>> entries = new ArrayList<>();
|
||||
for(int i = 0; i < 3; i++) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
schema.addColumnInteger(String.valueOf(i));
|
||||
}
|
||||
|
||||
for(int i = 0; i < 5; i++) {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
|
||||
entries.add(arr);
|
||||
}
|
||||
|
||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
|
||||
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0),schema.build().getType(0));
|
||||
assertEquals(5,arr.length());
|
||||
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0), schema.build().getType(0));
|
||||
assertEquals(5, arr.length());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateNDArray() throws Exception {
|
||||
@DisplayName("Test Create ND Array")
|
||||
void testCreateNDArray() throws Exception {
|
||||
val recordsToWrite = recordToWrite();
|
||||
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
|
||||
|
||||
File f = testDir.newFolder();
|
||||
|
||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
|
||||
File f = testDir.toFile();
|
||||
File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrorw");
|
||||
FileOutputStream outputStream = new FileOutputStream(tmpFile);
|
||||
tmpFile.deleteOnExit();
|
||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),outputStream);
|
||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), outputStream);
|
||||
outputStream.flush();
|
||||
outputStream.close();
|
||||
|
||||
Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile);
|
||||
assertEquals(recordsToWrite.getFirst(),schemaArrowWritableRecordBatchPair.getFirst());
|
||||
assertEquals(recordsToWrite.getRight(),schemaArrowWritableRecordBatchPair.getRight().toArrayList());
|
||||
|
||||
assertEquals(recordsToWrite.getFirst(), schemaArrowWritableRecordBatchPair.getFirst());
|
||||
assertEquals(recordsToWrite.getRight(), schemaArrowWritableRecordBatchPair.getRight().toArrayList());
|
||||
byte[] arr = byteArrayOutputStream.toByteArray();
|
||||
val read = ArrowConverter.readFromBytes(arr);
|
||||
assertEquals(recordsToWrite,read);
|
||||
|
||||
//send file
|
||||
assertEquals(recordsToWrite, read);
|
||||
// send file
|
||||
File tmp = tmpDataFile(recordsToWrite);
|
||||
ArrowRecordReader recordReader = new ArrowRecordReader();
|
||||
|
||||
recordReader.initialize(new FileSplit(tmp));
|
||||
|
||||
recordReader.next();
|
||||
ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch();
|
||||
INDArray arr2 = ArrowConverter.toArray(currentBatch);
|
||||
assertEquals(2,arr2.rows());
|
||||
assertEquals(2,arr2.columns());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testConvertToArrowVectors() {
|
||||
INDArray matrix = Nd4j.linspace(1,4,4).reshape(2,2);
|
||||
val vectors = ArrowConverter.convertToArrowVector(matrix,Arrays.asList("test","test2"), ColumnType.Double,bufferAllocator);
|
||||
assertEquals(matrix.rows(),vectors.size());
|
||||
|
||||
INDArray vector = Nd4j.linspace(1,4,4);
|
||||
val vectors2 = ArrowConverter.convertToArrowVector(vector,Arrays.asList("test"), ColumnType.Double,bufferAllocator);
|
||||
assertEquals(1,vectors2.size());
|
||||
assertEquals(matrix.length(),vectors2.get(0).getValueCount());
|
||||
|
||||
assertEquals(2, arr2.rows());
|
||||
assertEquals(2, arr2.columns());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSchemaConversionBasic() {
|
||||
@DisplayName("Test Convert To Arrow Vectors")
|
||||
void testConvertToArrowVectors() {
|
||||
INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
|
||||
val vectors = ArrowConverter.convertToArrowVector(matrix, Arrays.asList("test", "test2"), ColumnType.Double, bufferAllocator);
|
||||
assertEquals(matrix.rows(), vectors.size());
|
||||
INDArray vector = Nd4j.linspace(1, 4, 4);
|
||||
val vectors2 = ArrowConverter.convertToArrowVector(vector, Arrays.asList("test"), ColumnType.Double, bufferAllocator);
|
||||
assertEquals(1, vectors2.size());
|
||||
assertEquals(matrix.length(), vectors2.get(0).getValueCount());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Test Schema Conversion Basic")
|
||||
void testSchemaConversionBasic() {
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
for(int i = 0; i < 2; i++) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
schemaBuilder.addColumnDouble("test-" + i);
|
||||
schemaBuilder.addColumnInteger("testi-" + i);
|
||||
schemaBuilder.addColumnLong("testl-" + i);
|
||||
schemaBuilder.addColumnFloat("testf-" + i);
|
||||
}
|
||||
|
||||
|
||||
Schema schema = schemaBuilder.build();
|
||||
val schema2 = ArrowConverter.toArrowSchema(schema);
|
||||
assertEquals(8,schema2.getFields().size());
|
||||
assertEquals(8, schema2.getFields().size());
|
||||
val convertedSchema = ArrowConverter.toDatavecSchema(schema2);
|
||||
assertEquals(schema,convertedSchema);
|
||||
assertEquals(schema, convertedSchema);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReadSchemaAndRecordsFromByteArray() throws Exception {
|
||||
@DisplayName("Test Read Schema And Records From Byte Array")
|
||||
void testReadSchemaAndRecordsFromByteArray() throws Exception {
|
||||
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
|
||||
|
||||
int valueCount = 3;
|
||||
List<Field> fields = new ArrayList<>();
|
||||
fields.add(ArrowConverter.field("field1",new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
|
||||
fields.add(ArrowConverter.field("field1", new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
|
||||
fields.add(ArrowConverter.intField("field2"));
|
||||
|
||||
List<FieldVector> fieldVectors = new ArrayList<>();
|
||||
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field1",new float[] {1,2,3}));
|
||||
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field2",new int[] {1,2,3}));
|
||||
|
||||
|
||||
fieldVectors.add(ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 }));
|
||||
fieldVectors.add(ArrowConverter.vectorFor(allocator, "field2", new int[] { 1, 2, 3 }));
|
||||
org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields);
|
||||
|
||||
VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount);
|
||||
VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
|
||||
vectorUnloader.getRecordBatch();
|
||||
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
||||
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
|
||||
try (ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1, null, newChannel(byteArrayOutputStream))) {
|
||||
arrowFileWriter.writeBatch();
|
||||
} catch (IOException e) {
|
||||
log.error("",e);
|
||||
log.error("", e);
|
||||
}
|
||||
|
||||
byte[] arr = byteArrayOutputStream.toByteArray();
|
||||
val arr2 = ArrowConverter.readFromBytes(arr);
|
||||
assertEquals(2,arr2.getFirst().numColumns());
|
||||
assertEquals(3,arr2.getRight().size());
|
||||
|
||||
val arrowCols = ArrowConverter.toArrowColumns(allocator,arr2.getFirst(),arr2.getRight());
|
||||
assertEquals(2,arrowCols.size());
|
||||
assertEquals(valueCount,arrowCols.get(0).getValueCount());
|
||||
assertEquals(2, arr2.getFirst().numColumns());
|
||||
assertEquals(3, arr2.getRight().size());
|
||||
val arrowCols = ArrowConverter.toArrowColumns(allocator, arr2.getFirst(), arr2.getRight());
|
||||
assertEquals(2, arrowCols.size());
|
||||
assertEquals(valueCount, arrowCols.get(0).getValueCount());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testVectorForEdgeCases() {
|
||||
@DisplayName("Test Vector For Edge Cases")
|
||||
void testVectorForEdgeCases() {
|
||||
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
|
||||
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{Float.MIN_VALUE,Float.MAX_VALUE});
|
||||
assertEquals(Float.MIN_VALUE,vector.get(0),1e-2);
|
||||
assertEquals(Float.MAX_VALUE,vector.get(1),1e-2);
|
||||
|
||||
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{Integer.MIN_VALUE,Integer.MAX_VALUE});
|
||||
assertEquals(Integer.MIN_VALUE,vectorInt.get(0),1e-2);
|
||||
assertEquals(Integer.MAX_VALUE,vectorInt.get(1),1e-2);
|
||||
|
||||
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { Float.MIN_VALUE, Float.MAX_VALUE });
|
||||
assertEquals(Float.MIN_VALUE, vector.get(0), 1e-2);
|
||||
assertEquals(Float.MAX_VALUE, vector.get(1), 1e-2);
|
||||
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { Integer.MIN_VALUE, Integer.MAX_VALUE });
|
||||
assertEquals(Integer.MIN_VALUE, vectorInt.get(0), 1e-2);
|
||||
assertEquals(Integer.MAX_VALUE, vectorInt.get(1), 1e-2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorFor() {
|
||||
@DisplayName("Test Vector For")
|
||||
void testVectorFor() {
|
||||
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
|
||||
|
||||
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{1,2,3});
|
||||
assertEquals(3,vector.getValueCount());
|
||||
assertEquals(1,vector.get(0),1e-2);
|
||||
assertEquals(2,vector.get(1),1e-2);
|
||||
assertEquals(3,vector.get(2),1e-2);
|
||||
|
||||
val vectorLong = ArrowConverter.vectorFor(allocator,"field1",new long[]{1,2,3});
|
||||
assertEquals(3,vectorLong.getValueCount());
|
||||
assertEquals(1,vectorLong.get(0),1e-2);
|
||||
assertEquals(2,vectorLong.get(1),1e-2);
|
||||
assertEquals(3,vectorLong.get(2),1e-2);
|
||||
|
||||
|
||||
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{1,2,3});
|
||||
assertEquals(3,vectorInt.getValueCount());
|
||||
assertEquals(1,vectorInt.get(0),1e-2);
|
||||
assertEquals(2,vectorInt.get(1),1e-2);
|
||||
assertEquals(3,vectorInt.get(2),1e-2);
|
||||
|
||||
val vectorDouble = ArrowConverter.vectorFor(allocator,"field1",new double[]{1,2,3});
|
||||
assertEquals(3,vectorDouble.getValueCount());
|
||||
assertEquals(1,vectorDouble.get(0),1e-2);
|
||||
assertEquals(2,vectorDouble.get(1),1e-2);
|
||||
assertEquals(3,vectorDouble.get(2),1e-2);
|
||||
|
||||
|
||||
val vectorBool = ArrowConverter.vectorFor(allocator,"field1",new boolean[]{true,true,false});
|
||||
assertEquals(3,vectorBool.getValueCount());
|
||||
assertEquals(1,vectorBool.get(0),1e-2);
|
||||
assertEquals(1,vectorBool.get(1),1e-2);
|
||||
assertEquals(0,vectorBool.get(2),1e-2);
|
||||
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 });
|
||||
assertEquals(3, vector.getValueCount());
|
||||
assertEquals(1, vector.get(0), 1e-2);
|
||||
assertEquals(2, vector.get(1), 1e-2);
|
||||
assertEquals(3, vector.get(2), 1e-2);
|
||||
val vectorLong = ArrowConverter.vectorFor(allocator, "field1", new long[] { 1, 2, 3 });
|
||||
assertEquals(3, vectorLong.getValueCount());
|
||||
assertEquals(1, vectorLong.get(0), 1e-2);
|
||||
assertEquals(2, vectorLong.get(1), 1e-2);
|
||||
assertEquals(3, vectorLong.get(2), 1e-2);
|
||||
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { 1, 2, 3 });
|
||||
assertEquals(3, vectorInt.getValueCount());
|
||||
assertEquals(1, vectorInt.get(0), 1e-2);
|
||||
assertEquals(2, vectorInt.get(1), 1e-2);
|
||||
assertEquals(3, vectorInt.get(2), 1e-2);
|
||||
val vectorDouble = ArrowConverter.vectorFor(allocator, "field1", new double[] { 1, 2, 3 });
|
||||
assertEquals(3, vectorDouble.getValueCount());
|
||||
assertEquals(1, vectorDouble.get(0), 1e-2);
|
||||
assertEquals(2, vectorDouble.get(1), 1e-2);
|
||||
assertEquals(3, vectorDouble.get(2), 1e-2);
|
||||
val vectorBool = ArrowConverter.vectorFor(allocator, "field1", new boolean[] { true, true, false });
|
||||
assertEquals(3, vectorBool.getValueCount());
|
||||
assertEquals(1, vectorBool.get(0), 1e-2);
|
||||
assertEquals(1, vectorBool.get(1), 1e-2);
|
||||
assertEquals(0, vectorBool.get(2), 1e-2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRecordReaderAndWriteFile() throws Exception {
|
||||
@DisplayName("Test Record Reader And Write File")
|
||||
void testRecordReaderAndWriteFile() throws Exception {
|
||||
val recordsToWrite = recordToWrite();
|
||||
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
|
||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
|
||||
byte[] arr = byteArrayOutputStream.toByteArray();
|
||||
val read = ArrowConverter.readFromBytes(arr);
|
||||
assertEquals(recordsToWrite,read);
|
||||
|
||||
//send file
|
||||
assertEquals(recordsToWrite, read);
|
||||
// send file
|
||||
File tmp = tmpDataFile(recordsToWrite);
|
||||
RecordReader recordReader = new ArrowRecordReader();
|
||||
|
||||
recordReader.initialize(new FileSplit(tmp));
|
||||
|
||||
List<Writable> record = recordReader.next();
|
||||
assertEquals(2,record.size());
|
||||
|
||||
assertEquals(2, record.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRecordReaderMetaDataList() throws Exception {
|
||||
@DisplayName("Test Record Reader Meta Data List")
|
||||
void testRecordReaderMetaDataList() throws Exception {
|
||||
val recordsToWrite = recordToWrite();
|
||||
//send file
|
||||
// send file
|
||||
File tmp = tmpDataFile(recordsToWrite);
|
||||
RecordReader recordReader = new ArrowRecordReader();
|
||||
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
|
||||
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
|
||||
recordReader.loadFromMetaData(Arrays.<RecordMetaData>asList(recordMetaDataIndex));
|
||||
|
||||
Record record = recordReader.nextRecord();
|
||||
assertEquals(2,record.getRecord().size());
|
||||
|
||||
assertEquals(2, record.getRecord().size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDates() {
|
||||
@DisplayName("Test Dates")
|
||||
void testDates() {
|
||||
Date now = new Date();
|
||||
BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
||||
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[]{now});
|
||||
assertEquals(now.getTime(),timeStampMilliVector.get(0));
|
||||
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[] { now });
|
||||
assertEquals(now.getTime(), timeStampMilliVector.get(0));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testRecordReaderMetaData() throws Exception {
|
||||
@DisplayName("Test Record Reader Meta Data")
|
||||
void testRecordReaderMetaData() throws Exception {
|
||||
val recordsToWrite = recordToWrite();
|
||||
//send file
|
||||
// send file
|
||||
File tmp = tmpDataFile(recordsToWrite);
|
||||
RecordReader recordReader = new ArrowRecordReader();
|
||||
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
|
||||
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
|
||||
recordReader.loadFromMetaData(recordMetaDataIndex);
|
||||
|
||||
Record record = recordReader.nextRecord();
|
||||
assertEquals(2,record.getRecord().size());
|
||||
assertEquals(2, record.getRecord().size());
|
||||
}
|
||||
|
||||
private File tmpDataFile(Pair<Schema,List<List<Writable>>> recordsToWrite) throws IOException {
|
||||
|
||||
File f = testDir.newFolder();
|
||||
|
||||
//send file
|
||||
File tmp = new File(f,"tmp-file-" + UUID.randomUUID().toString());
|
||||
private File tmpDataFile(Pair<Schema, List<List<Writable>>> recordsToWrite) throws IOException {
|
||||
File f = testDir.toFile();
|
||||
// send file
|
||||
File tmp = new File(f, "tmp-file-" + UUID.randomUUID().toString());
|
||||
tmp.mkdirs();
|
||||
File tmpFile = new File(tmp,"data.arrow");
|
||||
File tmpFile = new File(tmp, "data.arrow");
|
||||
tmpFile.deleteOnExit();
|
||||
FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile);
|
||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),bufferedOutputStream);
|
||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), bufferedOutputStream);
|
||||
bufferedOutputStream.flush();
|
||||
bufferedOutputStream.close();
|
||||
return tmp;
|
||||
}
|
||||
|
||||
private Pair<Schema,List<List<Writable>>> recordToWrite() {
|
||||
private Pair<Schema, List<List<Writable>>> recordToWrite() {
|
||||
List<List<Writable>> records = new ArrayList<>();
|
||||
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
|
||||
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
|
||||
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
|
||||
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
for(int i = 0; i < 2; i++) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
schemaBuilder.addColumnFloat("col-" + i);
|
||||
}
|
||||
|
||||
return Pair.of(schemaBuilder.build(),records);
|
||||
return Pair.of(schemaBuilder.build(), records);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.arrow;
|
||||
|
||||
import lombok.val;
|
||||
|
@ -34,132 +33,98 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
||||
import org.datavec.arrow.recordreader.ArrowRecordWriter;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class RecordMapperTest extends BaseND4JTest {
|
||||
@DisplayName("Record Mapper Test")
|
||||
class RecordMapperTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testMultiWrite() throws Exception {
|
||||
@DisplayName("Test Multi Write")
|
||||
void testMultiWrite() throws Exception {
|
||||
val recordsPair = records();
|
||||
|
||||
Path p = Files.createTempFile("arrowwritetest", ".arrow");
|
||||
FileUtils.write(p.toFile(),recordsPair.getFirst());
|
||||
FileUtils.write(p.toFile(), recordsPair.getFirst());
|
||||
p.toFile().deleteOnExit();
|
||||
|
||||
int numReaders = 2;
|
||||
RecordReader[] readers = new RecordReader[numReaders];
|
||||
InputSplit[] splits = new InputSplit[numReaders];
|
||||
for(int i = 0; i < readers.length; i++) {
|
||||
for (int i = 0; i < readers.length; i++) {
|
||||
FileSplit split = new FileSplit(p.toFile());
|
||||
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
|
||||
readers[i] = arrowRecordReader;
|
||||
splits[i] = split;
|
||||
}
|
||||
|
||||
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
|
||||
FileSplit split = new FileSplit(p.toFile());
|
||||
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner());
|
||||
arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
|
||||
arrowRecordWriter.writeBatch(recordsPair.getRight());
|
||||
|
||||
|
||||
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
|
||||
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
|
||||
FileUtils.write(p2.toFile(),recordsPair.getFirst());
|
||||
FileUtils.write(p2.toFile(), recordsPair.getFirst());
|
||||
p.toFile().deleteOnExit();
|
||||
FileSplit outputCsv = new FileSplit(p2.toFile());
|
||||
|
||||
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
|
||||
.outputUrl(outputCsv)
|
||||
.partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers)
|
||||
.splitPerReader(splits)
|
||||
.recordWriter(csvRecordWriter)
|
||||
.build();
|
||||
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers).splitPerReader(splits).recordWriter(csvRecordWriter).build();
|
||||
mapper.copy();
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testCopyFromArrowToCsv() throws Exception {
|
||||
@DisplayName("Test Copy From Arrow To Csv")
|
||||
void testCopyFromArrowToCsv() throws Exception {
|
||||
val recordsPair = records();
|
||||
|
||||
Path p = Files.createTempFile("arrowwritetest", ".arrow");
|
||||
FileUtils.write(p.toFile(),recordsPair.getFirst());
|
||||
FileUtils.write(p.toFile(), recordsPair.getFirst());
|
||||
p.toFile().deleteOnExit();
|
||||
|
||||
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
|
||||
FileSplit split = new FileSplit(p.toFile());
|
||||
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner());
|
||||
arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
|
||||
arrowRecordWriter.writeBatch(recordsPair.getRight());
|
||||
|
||||
|
||||
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
|
||||
arrowRecordReader.initialize(split);
|
||||
|
||||
|
||||
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
|
||||
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
|
||||
FileUtils.write(p2.toFile(),recordsPair.getFirst());
|
||||
FileUtils.write(p2.toFile(), recordsPair.getFirst());
|
||||
p.toFile().deleteOnExit();
|
||||
FileSplit outputCsv = new FileSplit(p2.toFile());
|
||||
|
||||
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
|
||||
.outputUrl(outputCsv)
|
||||
.partitioner(new NumberOfRecordsPartitioner())
|
||||
.recordReader(arrowRecordReader).recordWriter(csvRecordWriter)
|
||||
.build();
|
||||
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).recordReader(arrowRecordReader).recordWriter(csvRecordWriter).build();
|
||||
mapper.copy();
|
||||
|
||||
CSVRecordReader recordReader = new CSVRecordReader();
|
||||
recordReader.initialize(outputCsv);
|
||||
|
||||
|
||||
List<List<Writable>> loadedCSvRecords = recordReader.next(10);
|
||||
assertEquals(10,loadedCSvRecords.size());
|
||||
assertEquals(10, loadedCSvRecords.size());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testCopyFromCsvToArrow() throws Exception {
|
||||
@DisplayName("Test Copy From Csv To Arrow")
|
||||
void testCopyFromCsvToArrow() throws Exception {
|
||||
val recordsPair = records();
|
||||
|
||||
Path p = Files.createTempFile("csvwritetest", ".csv");
|
||||
FileUtils.write(p.toFile(),recordsPair.getFirst());
|
||||
FileUtils.write(p.toFile(), recordsPair.getFirst());
|
||||
p.toFile().deleteOnExit();
|
||||
|
||||
|
||||
CSVRecordReader recordReader = new CSVRecordReader();
|
||||
FileSplit fileSplit = new FileSplit(p.toFile());
|
||||
|
||||
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
|
||||
File outputFile = Files.createTempFile("outputarrow","arrow").toFile();
|
||||
File outputFile = Files.createTempFile("outputarrow", "arrow").toFile();
|
||||
FileSplit outputFileSplit = new FileSplit(outputFile);
|
||||
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit)
|
||||
.outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner())
|
||||
.recordReader(recordReader).recordWriter(arrowRecordWriter)
|
||||
.build();
|
||||
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit).outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner()).recordReader(recordReader).recordWriter(arrowRecordWriter).build();
|
||||
mapper.copy();
|
||||
|
||||
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
|
||||
arrowRecordReader.initialize(outputFileSplit);
|
||||
List<List<Writable>> next = arrowRecordReader.next(10);
|
||||
System.out.println(next);
|
||||
assertEquals(10,next.size());
|
||||
|
||||
assertEquals(10, next.size());
|
||||
}
|
||||
|
||||
private Triple<String,Schema,List<List<Writable>>> records() {
|
||||
private Triple<String, Schema, List<List<Writable>>> records() {
|
||||
List<List<Writable>> list = new ArrayList<>();
|
||||
StringBuilder sb = new StringBuilder();
|
||||
int numColumns = 3;
|
||||
|
@ -176,15 +141,10 @@ public class RecordMapperTest extends BaseND4JTest {
|
|||
}
|
||||
list.add(temp);
|
||||
}
|
||||
|
||||
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
for(int i = 0; i < numColumns; i++) {
|
||||
for (int i = 0; i < numColumns; i++) {
|
||||
schemaBuilder.addColumnInteger(String.valueOf(i));
|
||||
}
|
||||
|
||||
return Triple.of(sb.toString(),schemaBuilder.build(),list);
|
||||
return Triple.of(sb.toString(), schemaBuilder.build(), list);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -29,16 +29,16 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.arrow.ArrowConverter;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
|
||||
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||
|
||||
|
@ -46,6 +46,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
|||
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testBasicIndexing() {
|
||||
Schema.Builder schema = new Schema.Builder();
|
||||
for(int i = 0; i < 3; i++) {
|
||||
|
@ -54,9 +55,9 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
|||
|
||||
|
||||
List<List<Writable>> timeStep = Arrays.asList(
|
||||
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)),
|
||||
Arrays.<Writable>asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)),
|
||||
Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5),new IntWritable(6))
|
||||
Arrays.asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)),
|
||||
Arrays.asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)),
|
||||
Arrays.asList(new IntWritable(4),new IntWritable(5),new IntWritable(6))
|
||||
);
|
||||
|
||||
int numTimeSteps = 5;
|
||||
|
@ -69,7 +70,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
|||
assertEquals(3,fieldVectors.size());
|
||||
for(FieldVector fieldVector : fieldVectors) {
|
||||
for(int i = 0; i < fieldVector.getValueCount(); i++) {
|
||||
assertFalse("Index " + i + " was null for field vector " + fieldVector, fieldVector.isNull(i));
|
||||
assertFalse( fieldVector.isNull(i),"Index " + i + " was null for field vector " + fieldVector);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,7 +80,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
|||
|
||||
@Test
|
||||
//not worried about this till after next release
|
||||
@Ignore
|
||||
@Disabled
|
||||
public void testVariableLengthTS() {
|
||||
Schema.Builder schema = new Schema.Builder()
|
||||
.addColumnString("str")
|
||||
|
|
|
@ -119,10 +119,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -17,41 +17,39 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.image;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.image.recordreader.ImageRecordReader;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import java.io.File;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
@DisplayName("Label Generator Test")
|
||||
class LabelGeneratorTest {
|
||||
|
||||
public class LabelGeneratorTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test
|
||||
public void testParentPathLabelGenerator() throws Exception {
|
||||
//https://github.com/deeplearning4j/DataVec/issues/273
|
||||
@DisplayName("Test Parent Path Label Generator")
|
||||
@Disabled
|
||||
void testParentPathLabelGenerator(@TempDir Path testDir) throws Exception {
|
||||
File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
|
||||
|
||||
for(String dirPrefix : new String[]{"m.", "m"}) {
|
||||
File f = testDir.newFolder();
|
||||
|
||||
for (String dirPrefix : new String[] { "m.", "m" }) {
|
||||
File f = testDir.toFile();
|
||||
int numDirs = 3;
|
||||
int filesPerDir = 4;
|
||||
|
||||
for (int i = 0; i < numDirs; i++) {
|
||||
File currentLabelDir = new File(f, dirPrefix + i);
|
||||
currentLabelDir.mkdirs();
|
||||
|
@ -61,14 +59,11 @@ public class LabelGeneratorTest {
|
|||
assertTrue(f3.exists());
|
||||
}
|
||||
}
|
||||
|
||||
ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
|
||||
rr.initialize(new FileSplit(f));
|
||||
|
||||
List<String> labelsAct = rr.getLabels();
|
||||
List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2");
|
||||
assertEquals(labelsExp, labelsAct);
|
||||
|
||||
int expCount = numDirs * filesPerDir;
|
||||
int actCount = 0;
|
||||
while (rr.hasNext()) {
|
||||
|
|
|
@ -22,8 +22,8 @@ package org.datavec.image.loader;
|
|||
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -32,9 +32,9 @@ import java.io.InputStream;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -182,7 +182,7 @@ public class LoaderTests {
|
|||
}
|
||||
|
||||
|
||||
@Ignore // Use when confirming data is getting stored
|
||||
@Disabled // Use when confirming data is getting stored
|
||||
@Test
|
||||
public void testProcessCifar() {
|
||||
int row = 32;
|
||||
|
@ -208,15 +208,15 @@ public class LoaderTests {
|
|||
int minibatch = 100;
|
||||
int nMinibatches = 50000 / minibatch;
|
||||
|
||||
for( int i=0; i<nMinibatches; i++ ){
|
||||
for( int i=0; i < nMinibatches; i++) {
|
||||
DataSet ds = loader.next(minibatch);
|
||||
String s = String.valueOf(i);
|
||||
assertNotNull(s, ds.getFeatures());
|
||||
assertNotNull(s, ds.getLabels());
|
||||
assertNotNull(ds.getFeatures(),s);
|
||||
assertNotNull(ds.getLabels(),s);
|
||||
|
||||
assertEquals(s, minibatch, ds.getFeatures().size(0));
|
||||
assertEquals(s, minibatch, ds.getLabels().size(0));
|
||||
assertEquals(s, 10, ds.getLabels().size(1));
|
||||
assertEquals(minibatch, ds.getFeatures().size(0),s);
|
||||
assertEquals(minibatch, ds.getLabels().size(0),s);
|
||||
assertEquals(10, ds.getLabels().size(1),s);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
package org.datavec.image.loader;
|
||||
|
||||
import org.datavec.image.data.Image;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
|
@ -32,7 +32,7 @@ import java.io.FileInputStream;
|
|||
import java.io.InputStream;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
|
||||
public class TestImageLoader {
|
||||
|
|
|
@ -30,9 +30,10 @@ import org.bytedeco.javacv.Java2DFrameConverter;
|
|||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.datavec.image.data.Image;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -42,16 +43,17 @@ import org.nd4j.common.io.ClassPathResource;
|
|||
import java.awt.image.BufferedImage;
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Random;
|
||||
|
||||
import org.bytedeco.leptonica.*;
|
||||
import org.bytedeco.opencv.opencv_core.*;
|
||||
import static org.bytedeco.leptonica.global.lept.*;
|
||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -62,8 +64,6 @@ public class TestNativeImageLoader {
|
|||
static final long seed = 10;
|
||||
static final Random rng = new Random(seed);
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test
|
||||
public void testConvertPix() throws Exception {
|
||||
|
@ -566,8 +566,8 @@ public class TestNativeImageLoader {
|
|||
|
||||
|
||||
@Test
|
||||
public void testNativeImageLoaderEmptyStreams() throws Exception {
|
||||
File dir = testDir.newFolder();
|
||||
public void testNativeImageLoaderEmptyStreams(@TempDir Path testDir) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
File f = new File(dir, "myFile.jpg");
|
||||
f.createNewFile();
|
||||
|
||||
|
@ -578,7 +578,7 @@ public class TestNativeImageLoader {
|
|||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
assertTrue(msg.contains("decode image"),msg);
|
||||
}
|
||||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
|
@ -586,7 +586,7 @@ public class TestNativeImageLoader {
|
|||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
assertTrue(msg.contains("decode image"),msg);
|
||||
}
|
||||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
|
@ -594,7 +594,7 @@ public class TestNativeImageLoader {
|
|||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
assertTrue(msg.contains("decode image"),msg);
|
||||
}
|
||||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
|
@ -603,7 +603,7 @@ public class TestNativeImageLoader {
|
|||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
assertTrue( msg.contains("decode image"),msg);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.image.recordreader;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -28,61 +27,56 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.loader.FileBatch;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.*;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
@DisplayName("File Batch Record Reader Test")
|
||||
class FileBatchRecordReaderTest {
|
||||
|
||||
public class FileBatchRecordReaderTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@TempDir
|
||||
public Path testDir;
|
||||
|
||||
@Test
|
||||
public void testCsv() throws Exception {
|
||||
File extractedSourceDir = testDir.newFolder();
|
||||
@DisplayName("Test Csv")
|
||||
void testCsv(@TempDir Path testDir,@TempDir Path baseDirPath) throws Exception {
|
||||
File extractedSourceDir = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir);
|
||||
File baseDir = testDir.newFolder();
|
||||
|
||||
|
||||
File baseDir = baseDirPath.toFile();
|
||||
List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true));
|
||||
assertEquals(6, c.size());
|
||||
|
||||
Collections.sort(c, new Comparator<File>() {
|
||||
|
||||
@Override
|
||||
public int compare(File o1, File o2) {
|
||||
return o1.getPath().compareTo(o2.getPath());
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
FileBatch fb = FileBatch.forFiles(c);
|
||||
File saveFile = new File(baseDir, "saved.zip");
|
||||
fb.writeAsZip(saveFile);
|
||||
fb = FileBatch.readFromZip(saveFile);
|
||||
|
||||
PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
||||
ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker);
|
||||
rr.setLabels(Arrays.asList("class0", "class1"));
|
||||
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
|
||||
|
||||
|
||||
NativeImageLoader il = new NativeImageLoader(32, 32, 1);
|
||||
for( int test=0; test<3; test++) {
|
||||
for (int test = 0; test < 3; test++) {
|
||||
for (int i = 0; i < 6; i++) {
|
||||
assertTrue(fbrr.hasNext());
|
||||
List<Writable> next = fbrr.next();
|
||||
assertEquals(2, next.size());
|
||||
|
||||
INDArray exp;
|
||||
switch (i){
|
||||
switch(i) {
|
||||
case 0:
|
||||
exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg"));
|
||||
break;
|
||||
|
@ -105,8 +99,7 @@ public class FileBatchRecordReaderTest {
|
|||
throw new RuntimeException();
|
||||
}
|
||||
Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1));
|
||||
|
||||
assertEquals(((NDArrayWritable)next.get(0)).get(), exp);
|
||||
assertEquals(((NDArrayWritable) next.get(0)).get(), exp);
|
||||
assertEquals(expLabel, next.get(1));
|
||||
}
|
||||
assertFalse(fbrr.hasNext());
|
||||
|
@ -114,5 +107,4 @@ public class FileBatchRecordReaderTest {
|
|||
fbrr.reset();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -36,9 +36,10 @@ import org.datavec.api.writable.DoubleWritable;
|
|||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -46,28 +47,30 @@ import org.nd4j.common.io.ClassPathResource;
|
|||
|
||||
import java.io.*;
|
||||
import java.net.URI;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestImageRecordReader {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
@Test()
|
||||
public void testEmptySplit() throws IOException {
|
||||
InputSplit data = new CollectionInputSplit(new ArrayList<URI>());
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
InputSplit data = new CollectionInputSplit(new ArrayList<>());
|
||||
new ImageRecordReader().initialize(data, null);
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMetaData() throws IOException {
|
||||
public void testMetaData(@TempDir Path testDir) throws IOException {
|
||||
|
||||
File parentDir = testDir.newFolder();
|
||||
File parentDir = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir);
|
||||
// System.out.println(f.getAbsolutePath());
|
||||
// System.out.println(f.getParentFile().getParentFile().getAbsolutePath());
|
||||
|
@ -104,11 +107,11 @@ public class TestImageRecordReader {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testImageRecordReaderLabelsOrder() throws Exception {
|
||||
public void testImageRecordReaderLabelsOrder(@TempDir Path testDir) throws Exception {
|
||||
//Labels order should be consistent, regardless of file iteration order
|
||||
|
||||
//Idea: labels order should be consistent regardless of input file order
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
||||
File f0 = new File(f, "/class0/0.jpg");
|
||||
File f1 = new File(f, "/class1/A.jpg");
|
||||
|
@ -135,11 +138,11 @@ public class TestImageRecordReader {
|
|||
|
||||
|
||||
@Test
|
||||
public void testImageRecordReaderRandomization() throws Exception {
|
||||
public void testImageRecordReaderRandomization(@TempDir Path testDir) throws Exception {
|
||||
//Order of FileSplit+ImageRecordReader should be different after reset
|
||||
|
||||
//Idea: labels order should be consistent regardless of input file order
|
||||
File f0 = testDir.newFolder();
|
||||
File f0 = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||
|
||||
FileSplit fs = new FileSplit(f0, new Random(12345));
|
||||
|
@ -189,13 +192,13 @@ public class TestImageRecordReader {
|
|||
|
||||
|
||||
@Test
|
||||
public void testImageRecordReaderRegression() throws Exception {
|
||||
public void testImageRecordReaderRegression(@TempDir Path testDir) throws Exception {
|
||||
|
||||
PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen();
|
||||
|
||||
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen);
|
||||
|
||||
File rootDir = testDir.newFolder();
|
||||
File rootDir = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
||||
FileSplit fs = new FileSplit(rootDir);
|
||||
rr.initialize(fs);
|
||||
|
@ -244,10 +247,10 @@ public class TestImageRecordReader {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testListenerInvocationBatch() throws IOException {
|
||||
public void testListenerInvocationBatch(@TempDir Path testDir) throws IOException {
|
||||
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
||||
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
||||
|
||||
File parent = f;
|
||||
|
@ -260,10 +263,10 @@ public class TestImageRecordReader {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testListenerInvocationSingle() throws IOException {
|
||||
public void testListenerInvocationSingle(@TempDir Path testDir) throws IOException {
|
||||
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
||||
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
||||
File parent = testDir.newFolder();
|
||||
File parent = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent);
|
||||
int numFiles = parent.list().length;
|
||||
rr.initialize(new FileSplit(parent));
|
||||
|
@ -315,7 +318,7 @@ public class TestImageRecordReader {
|
|||
|
||||
|
||||
@Test
|
||||
public void testImageRecordReaderPathMultiLabelGenerator() throws Exception {
|
||||
public void testImageRecordReaderPathMultiLabelGenerator(@TempDir Path testDir) throws Exception {
|
||||
Nd4j.setDataType(DataType.FLOAT);
|
||||
//Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively
|
||||
// PLUS single value (Writable) regression label
|
||||
|
@ -324,7 +327,7 @@ public class TestImageRecordReader {
|
|||
|
||||
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen);
|
||||
|
||||
File rootDir = testDir.newFolder();
|
||||
File rootDir = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
||||
FileSplit fs = new FileSplit(rootDir);
|
||||
rr.initialize(fs);
|
||||
|
@ -471,9 +474,9 @@ public class TestImageRecordReader {
|
|||
|
||||
|
||||
@Test
|
||||
public void testNCHW_NCHW() throws Exception {
|
||||
public void testNCHW_NCHW(@TempDir Path testDir) throws Exception {
|
||||
//Idea: labels order should be consistent regardless of input file order
|
||||
File f0 = testDir.newFolder();
|
||||
File f0 = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||
|
||||
FileSplit fs0 = new FileSplit(f0, new Random(12345));
|
||||
|
|
|
@ -35,9 +35,10 @@ import org.datavec.image.transform.FlipImageTransform;
|
|||
import org.datavec.image.transform.ImageTransform;
|
||||
import org.datavec.image.transform.PipelineImageTransform;
|
||||
import org.datavec.image.transform.ResizeImageTransform;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||
|
@ -46,24 +47,24 @@ import org.nd4j.common.io.ClassPathResource;
|
|||
|
||||
import java.io.File;
|
||||
import java.net.URI;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestObjectDetectionRecordReader {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
|
||||
@Test
|
||||
public void test() throws Exception {
|
||||
public void test(@TempDir Path testDir) throws Exception {
|
||||
for(boolean nchw : new boolean[]{true, false}) {
|
||||
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
|
||||
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
|
||||
|
||||
String path = new File(f, "000012.jpg").getParent();
|
||||
|
|
|
@ -21,27 +21,27 @@
|
|||
package org.datavec.image.recordreader.objdetect;
|
||||
|
||||
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestVocLabelProvider {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test
|
||||
public void testVocLabelProvider() throws Exception {
|
||||
public void testVocLabelProvider(@TempDir Path testDir) throws Exception {
|
||||
|
||||
File f = testDir.newFolder();
|
||||
File f = testDir.toFile();
|
||||
new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f);
|
||||
|
||||
String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent();
|
||||
|
|
|
@ -17,106 +17,70 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.image.transform;
|
||||
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
@DisplayName("Json Yaml Test")
|
||||
class JsonYamlTest {
|
||||
|
||||
public class JsonYamlTest {
|
||||
@Test
|
||||
public void testJsonYamlImageTransformProcess() throws IOException {
|
||||
@DisplayName("Test Json Yaml Image Transform Process")
|
||||
void testJsonYamlImageTransformProcess() throws IOException {
|
||||
int seed = 12345;
|
||||
Random random = new Random(seed);
|
||||
|
||||
//from org.bytedeco.javacpp.opencv_imgproc
|
||||
// from org.bytedeco.javacpp.opencv_imgproc
|
||||
int COLOR_BGR2Luv = 50;
|
||||
int CV_BGR2GRAY = 6;
|
||||
|
||||
|
||||
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv)
|
||||
.cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0)
|
||||
.resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3)
|
||||
.warpImageTransform((float) 0.5)
|
||||
|
||||
// Note : since randomCropTransform use random value
|
||||
// the results from each case(json, yaml, ImageTransformProcess)
|
||||
// can be different
|
||||
// don't use the below line
|
||||
// if you uncomment it, you will get fail from below assertions
|
||||
// .randomCropTransform(seed, 50, 50)
|
||||
|
||||
// Note : you will get "java.lang.NoClassDefFoundError: Could not initialize class org.bytedeco.javacpp.avutil"
|
||||
// it needs to add the below dependency
|
||||
// <dependency>
|
||||
// <groupId>org.bytedeco</groupId>
|
||||
// <artifactId>ffmpeg-platform</artifactId>
|
||||
// </dependency>
|
||||
// FFmpeg has license issues, be careful to use it
|
||||
//.filterImageTransform("noise=alls=20:allf=t+u,format=rgba", 100, 100, 4)
|
||||
|
||||
.build();
|
||||
|
||||
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv).cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0).resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3).warpImageTransform((float) 0.5).build();
|
||||
String asJson = itp.toJson();
|
||||
String asYaml = itp.toYaml();
|
||||
|
||||
// System.out.println(asJson);
|
||||
// System.out.println("\n\n\n");
|
||||
// System.out.println(asYaml);
|
||||
|
||||
// System.out.println(asJson);
|
||||
// System.out.println("\n\n\n");
|
||||
// System.out.println(asYaml);
|
||||
ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
|
||||
ImageWritable imgJson = new ImageWritable(img.getFrame().clone());
|
||||
ImageWritable imgYaml = new ImageWritable(img.getFrame().clone());
|
||||
ImageWritable imgAll = new ImageWritable(img.getFrame().clone());
|
||||
|
||||
ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson);
|
||||
ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml);
|
||||
|
||||
List<ImageTransform> transformList = itp.getTransformList();
|
||||
List<ImageTransform> transformListJson = itpFromJson.getTransformList();
|
||||
List<ImageTransform> transformListYaml = itpFromYaml.getTransformList();
|
||||
|
||||
for (int i = 0; i < transformList.size(); i++) {
|
||||
ImageTransform it = transformList.get(i);
|
||||
ImageTransform itJson = transformListJson.get(i);
|
||||
ImageTransform itYaml = transformListYaml.get(i);
|
||||
|
||||
System.out.println(i + "\t" + it);
|
||||
|
||||
img = it.transform(img);
|
||||
imgJson = itJson.transform(imgJson);
|
||||
imgYaml = itYaml.transform(imgYaml);
|
||||
|
||||
if (it instanceof RandomCropTransform) {
|
||||
assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight);
|
||||
assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth);
|
||||
|
||||
assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight);
|
||||
assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth);
|
||||
} else if (it instanceof FilterImageTransform) {
|
||||
assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight);
|
||||
assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth);
|
||||
assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels);
|
||||
|
||||
assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight);
|
||||
assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth);
|
||||
assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels);
|
||||
} else {
|
||||
assertEquals(img, imgJson);
|
||||
|
||||
assertEquals(img, imgYaml);
|
||||
}
|
||||
}
|
||||
|
||||
imgAll = itp.execute(imgAll);
|
||||
|
||||
assertEquals(imgAll, img);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,56 +17,50 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.image.transform;
|
||||
|
||||
import org.bytedeco.javacv.Frame;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class ResizeImageTransformTest {
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
@DisplayName("Resize Image Transform Test")
|
||||
class ResizeImageTransformTest {
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResizeUpscale1() throws Exception {
|
||||
@DisplayName("Test Resize Upscale 1")
|
||||
void testResizeUpscale1() throws Exception {
|
||||
ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3);
|
||||
|
||||
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
|
||||
|
||||
ImageWritable dstImg = transform.transform(srcImg);
|
||||
|
||||
Frame f = dstImg.getFrame();
|
||||
assertEquals(f.imageWidth, 200);
|
||||
assertEquals(f.imageHeight, 200);
|
||||
|
||||
float[] coordinates = {100, 200};
|
||||
float[] coordinates = { 100, 200 };
|
||||
float[] transformed = transform.query(coordinates);
|
||||
assertEquals(200f * 100 / 32, transformed[0], 0);
|
||||
assertEquals(200f * 200 / 32, transformed[1], 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResizeDownscale() throws Exception {
|
||||
@DisplayName("Test Resize Downscale")
|
||||
void testResizeDownscale() throws Exception {
|
||||
ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3);
|
||||
|
||||
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
|
||||
|
||||
ImageWritable dstImg = transform.transform(srcImg);
|
||||
|
||||
Frame f = dstImg.getFrame();
|
||||
assertEquals(f.imageWidth, 200);
|
||||
assertEquals(f.imageHeight, 200);
|
||||
|
||||
float[] coordinates = {300, 400};
|
||||
float[] coordinates = { 300, 400 };
|
||||
float[] transformed = transform.query(coordinates);
|
||||
assertEquals(200f * 300 / 443, transformed[0], 0);
|
||||
assertEquals(200f * 400 / 571, transformed[1], 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -28,8 +28,8 @@ import org.nd4j.common.io.ClassPathResource;
|
|||
import org.nd4j.common.primitives.Pair;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.awt.*;
|
||||
import java.util.LinkedList;
|
||||
|
@ -40,7 +40,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
|||
|
||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -255,7 +255,7 @@ public class TestImageTransform {
|
|||
assertEquals(22, transformed[1], 0);
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Disabled
|
||||
@Test
|
||||
public void testFilterImageTransform() throws Exception {
|
||||
ImageWritable writable = makeRandomImage(0, 0, 4);
|
||||
|
|
|
@ -59,10 +59,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -57,10 +57,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -17,37 +17,34 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.poi.excel;
|
||||
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class ExcelRecordReaderTest {
|
||||
@DisplayName("Excel Record Reader Test")
|
||||
class ExcelRecordReaderTest {
|
||||
|
||||
@Test
|
||||
public void testSimple() throws Exception {
|
||||
@DisplayName("Test Simple")
|
||||
void testSimple() throws Exception {
|
||||
RecordReader excel = new ExcelRecordReader();
|
||||
excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile()));
|
||||
assertTrue(excel.hasNext());
|
||||
List<Writable> next = excel.next();
|
||||
assertEquals(3,next.size());
|
||||
|
||||
assertEquals(3, next.size());
|
||||
RecordReader headerReader = new ExcelRecordReader(1);
|
||||
headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile()));
|
||||
assertTrue(excel.hasNext());
|
||||
List<Writable> next2 = excel.next();
|
||||
assertEquals(3,next2.size());
|
||||
|
||||
|
||||
assertEquals(3, next2.size());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.poi.excel;
|
||||
|
||||
import lombok.val;
|
||||
|
@ -26,44 +25,45 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
@DisplayName("Excel Record Writer Test")
|
||||
class ExcelRecordWriterTest {
|
||||
|
||||
public class ExcelRecordWriterTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@TempDir
|
||||
public Path testDir;
|
||||
|
||||
@Test
|
||||
public void testWriter() throws Exception {
|
||||
@DisplayName("Test Writer")
|
||||
void testWriter() throws Exception {
|
||||
ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter();
|
||||
val records = records();
|
||||
File tmpDir = testDir.newFolder();
|
||||
File outputFile = new File(tmpDir,"testexcel.xlsx");
|
||||
File tmpDir = testDir.toFile();
|
||||
File outputFile = new File(tmpDir, "testexcel.xlsx");
|
||||
outputFile.deleteOnExit();
|
||||
FileSplit fileSplit = new FileSplit(outputFile);
|
||||
excelRecordWriter.initialize(fileSplit,new NumberOfRecordsPartitioner());
|
||||
excelRecordWriter.initialize(fileSplit, new NumberOfRecordsPartitioner());
|
||||
excelRecordWriter.writeBatch(records.getRight());
|
||||
excelRecordWriter.close();
|
||||
File parentFile = outputFile.getParentFile();
|
||||
assertEquals(1,parentFile.list().length);
|
||||
|
||||
assertEquals(1, parentFile.list().length);
|
||||
ExcelRecordReader excelRecordReader = new ExcelRecordReader();
|
||||
excelRecordReader.initialize(fileSplit);
|
||||
List<List<Writable>> next = excelRecordReader.next(10);
|
||||
assertEquals(10,next.size());
|
||||
|
||||
assertEquals(10, next.size());
|
||||
}
|
||||
|
||||
private Triple<String,Schema,List<List<Writable>>> records() {
|
||||
private Triple<String, Schema, List<List<Writable>>> records() {
|
||||
List<List<Writable>> list = new ArrayList<>();
|
||||
StringBuilder sb = new StringBuilder();
|
||||
int numColumns = 3;
|
||||
|
@ -80,13 +80,10 @@ public class ExcelRecordWriterTest {
|
|||
}
|
||||
list.add(temp);
|
||||
}
|
||||
|
||||
|
||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||
for(int i = 0; i < numColumns; i++) {
|
||||
for (int i = 0; i < numColumns; i++) {
|
||||
schemaBuilder.addColumnInteger(String.valueOf(i));
|
||||
}
|
||||
|
||||
return Triple.of(sb.toString(),schemaBuilder.build(),list);
|
||||
return Triple.of(sb.toString(), schemaBuilder.build(), list);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,10 +65,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.api.records.reader.impl;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import java.io.File;
|
||||
import java.net.URI;
|
||||
import java.sql.Connection;
|
||||
|
@ -49,53 +47,57 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.LongWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
|
||||
public class JDBCRecordReaderTest {
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@DisplayName("Jdbc Record Reader Test")
|
||||
class JDBCRecordReaderTest {
|
||||
|
||||
@TempDir
|
||||
public Path testDir;
|
||||
|
||||
Connection conn;
|
||||
|
||||
EmbeddedDataSource dataSource;
|
||||
|
||||
private final String dbName = "datavecTests";
|
||||
|
||||
private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver";
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
File f = testDir.newFolder();
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
File f = testDir.toFile();
|
||||
System.setProperty("derby.system.home", f.getAbsolutePath());
|
||||
|
||||
dataSource = new EmbeddedDataSource();
|
||||
dataSource.setDatabaseName(dbName);
|
||||
dataSource.setCreateDatabase("create");
|
||||
conn = dataSource.getConnection();
|
||||
|
||||
TestDb.dropTables(conn);
|
||||
TestDb.buildCoffeeTable(conn);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws Exception {
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
DbUtils.closeQuietly(conn);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimpleIter() throws Exception {
|
||||
@DisplayName("Test Simple Iter")
|
||||
void testSimpleIter() throws Exception {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
List<List<Writable>> records = new ArrayList<>();
|
||||
while (reader.hasNext()) {
|
||||
List<Writable> values = reader.next();
|
||||
records.add(values);
|
||||
}
|
||||
|
||||
assertFalse(records.isEmpty());
|
||||
|
||||
List<Writable> first = records.get(0);
|
||||
assertEquals(new Text("Bolivian Dark"), first.get(0));
|
||||
assertEquals(new Text("14-001"), first.get(1));
|
||||
|
@ -104,39 +106,43 @@ public class JDBCRecordReaderTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSimpleWithListener() throws Exception {
|
||||
@DisplayName("Test Simple With Listener")
|
||||
void testSimpleWithListener() throws Exception {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
RecordListener recordListener = new LogRecordListener();
|
||||
reader.setListeners(recordListener);
|
||||
reader.next();
|
||||
|
||||
assertTrue(recordListener.invoked());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReset() throws Exception {
|
||||
@DisplayName("Test Reset")
|
||||
void testReset() throws Exception {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
List<List<Writable>> records = new ArrayList<>();
|
||||
records.add(reader.next());
|
||||
reader.reset();
|
||||
records.add(reader.next());
|
||||
|
||||
assertEquals(2, records.size());
|
||||
assertEquals(new Text("Bolivian Dark"), records.get(0).get(0));
|
||||
assertEquals(new Text("Bolivian Dark"), records.get(1).get(0));
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void testLackingDataSourceShouldFail() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Lacking Data Source Should Fail")
|
||||
void testLackingDataSourceShouldFail() {
|
||||
assertThrows(IllegalStateException.class, () -> {
|
||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
|
||||
reader.initialize(null);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConfigurationDataSourceInitialization() throws Exception {
|
||||
@DisplayName("Test Configuration Data Source Initialization")
|
||||
void testConfigurationDataSourceInitialization() throws Exception {
|
||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
|
||||
Configuration conf = new Configuration();
|
||||
conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true");
|
||||
|
@ -146,28 +152,33 @@ public class JDBCRecordReaderTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testInitConfigurationMissingParametersShouldFail() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Init Configuration Missing Parameters Should Fail")
|
||||
void testInitConfigurationMissingParametersShouldFail() {
|
||||
assertThrows(IllegalArgumentException.class, () -> {
|
||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
|
||||
Configuration conf = new Configuration();
|
||||
conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway");
|
||||
reader.initialize(conf, null);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = UnsupportedOperationException.class)
|
||||
public void testRecordDataInputStreamShouldFail() throws Exception {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
reader.record(null, null);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLoadFromMetaData() throws Exception {
|
||||
@DisplayName("Test Record Data Input Stream Should Fail")
|
||||
void testRecordDataInputStreamShouldFail() {
|
||||
assertThrows(UnsupportedOperationException.class, () -> {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()),
|
||||
"SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass());
|
||||
reader.record(null, null);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Test Load From Meta Data")
|
||||
void testLoadFromMetaData() throws Exception {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), "SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass());
|
||||
Record res = reader.loadFromMetaData(rmd);
|
||||
assertNotNull(res);
|
||||
assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0));
|
||||
|
@ -177,7 +188,8 @@ public class JDBCRecordReaderTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNextRecord() throws Exception {
|
||||
@DisplayName("Test Next Record")
|
||||
void testNextRecord() throws Exception {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
Record r = reader.nextRecord();
|
||||
List<Writable> fields = r.getRecord();
|
||||
|
@ -193,7 +205,8 @@ public class JDBCRecordReaderTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNextRecordAndRecover() throws Exception {
|
||||
@DisplayName("Test Next Record And Recover")
|
||||
void testNextRecordAndRecover() throws Exception {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
Record r = reader.nextRecord();
|
||||
List<Writable> fields = r.getRecord();
|
||||
|
@ -208,8 +221,10 @@ public class JDBCRecordReaderTest {
|
|||
}
|
||||
|
||||
// Resetting the record reader when initialized as forward only should fail
|
||||
@Test(expected = RuntimeException.class)
|
||||
public void testResetForwardOnlyShouldFail() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Reset Forward Only Should Fail")
|
||||
void testResetForwardOnlyShouldFail() {
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) {
|
||||
Configuration conf = new Configuration();
|
||||
conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY);
|
||||
|
@ -217,58 +232,78 @@ public class JDBCRecordReaderTest {
|
|||
reader.next();
|
||||
reader.reset();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReadAllTypes() throws Exception {
|
||||
@DisplayName("Test Read All Types")
|
||||
void testReadAllTypes() throws Exception {
|
||||
TestDb.buildAllTypesTable(conn);
|
||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM AllTypes", dataSource)) {
|
||||
reader.initialize(null);
|
||||
List<Writable> item = reader.next();
|
||||
|
||||
assertEquals(item.size(), 15);
|
||||
assertEquals(BooleanWritable.class, item.get(0).getClass()); // boolean to boolean
|
||||
assertEquals(Text.class, item.get(1).getClass()); // date to text
|
||||
assertEquals(Text.class, item.get(2).getClass()); // time to text
|
||||
assertEquals(Text.class, item.get(3).getClass()); // timestamp to text
|
||||
assertEquals(Text.class, item.get(4).getClass()); // char to text
|
||||
assertEquals(Text.class, item.get(5).getClass()); // long varchar to text
|
||||
assertEquals(Text.class, item.get(6).getClass()); // varchar to text
|
||||
assertEquals(DoubleWritable.class,
|
||||
item.get(7).getClass()); // float to double (derby's float is an alias of double by default)
|
||||
assertEquals(FloatWritable.class, item.get(8).getClass()); // real to float
|
||||
assertEquals(DoubleWritable.class, item.get(9).getClass()); // decimal to double
|
||||
assertEquals(DoubleWritable.class, item.get(10).getClass()); // numeric to double
|
||||
assertEquals(DoubleWritable.class, item.get(11).getClass()); // double to double
|
||||
assertEquals(IntWritable.class, item.get(12).getClass()); // integer to integer
|
||||
assertEquals(IntWritable.class, item.get(13).getClass()); // small int to integer
|
||||
assertEquals(LongWritable.class, item.get(14).getClass()); // bigint to long
|
||||
|
||||
// boolean to boolean
|
||||
assertEquals(BooleanWritable.class, item.get(0).getClass());
|
||||
// date to text
|
||||
assertEquals(Text.class, item.get(1).getClass());
|
||||
// time to text
|
||||
assertEquals(Text.class, item.get(2).getClass());
|
||||
// timestamp to text
|
||||
assertEquals(Text.class, item.get(3).getClass());
|
||||
// char to text
|
||||
assertEquals(Text.class, item.get(4).getClass());
|
||||
// long varchar to text
|
||||
assertEquals(Text.class, item.get(5).getClass());
|
||||
// varchar to text
|
||||
assertEquals(Text.class, item.get(6).getClass());
|
||||
assertEquals(DoubleWritable.class, // float to double (derby's float is an alias of double by default)
|
||||
item.get(7).getClass());
|
||||
// real to float
|
||||
assertEquals(FloatWritable.class, item.get(8).getClass());
|
||||
// decimal to double
|
||||
assertEquals(DoubleWritable.class, item.get(9).getClass());
|
||||
// numeric to double
|
||||
assertEquals(DoubleWritable.class, item.get(10).getClass());
|
||||
// double to double
|
||||
assertEquals(DoubleWritable.class, item.get(11).getClass());
|
||||
// integer to integer
|
||||
assertEquals(IntWritable.class, item.get(12).getClass());
|
||||
// small int to integer
|
||||
assertEquals(IntWritable.class, item.get(13).getClass());
|
||||
// bigint to long
|
||||
assertEquals(LongWritable.class, item.get(14).getClass());
|
||||
}
|
||||
}
|
||||
|
||||
@Test(expected = RuntimeException.class)
|
||||
public void testNextNoMoreShouldFail() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Next No More Should Fail")
|
||||
void testNextNoMoreShouldFail() {
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
while (reader.hasNext()) {
|
||||
reader.next();
|
||||
}
|
||||
reader.next();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testInvalidMetadataShouldFail() throws Exception {
|
||||
@Test
|
||||
@DisplayName("Test Invalid Metadata Should Fail")
|
||||
void testInvalidMetadataShouldFail() {
|
||||
assertThrows(IllegalArgumentException.class, () -> {
|
||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||
RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class);
|
||||
reader.loadFromMetaData(md);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private JDBCRecordReader getInitializedReader(String query) throws Exception {
|
||||
int[] indices = {1}; // ProdNum column
|
||||
JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?",
|
||||
indices);
|
||||
// ProdNum column
|
||||
int[] indices = { 1 };
|
||||
JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?", indices);
|
||||
reader.setTrimStrings(true);
|
||||
reader.initialize(null);
|
||||
return reader;
|
||||
|
|
|
@ -61,25 +61,18 @@
|
|||
<artifactId>nd4j-common</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-geo</artifactId>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>python4j-numpy</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-python</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -36,14 +36,14 @@ import org.datavec.api.writable.LongWritable;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class LocalTransformProcessRecordReaderTests {
|
||||
|
||||
|
|
|
@ -29,9 +29,9 @@ import org.datavec.api.transform.schema.Schema;
|
|||
import org.datavec.api.util.ndarray.RecordConverter;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.local.transforms.AnalyzeLocal;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -39,12 +39,11 @@ import org.nd4j.common.io.ClassPathResource;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestAnalyzeLocal {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
|
||||
@Test
|
||||
public void testAnalysisBasic() throws Exception {
|
||||
|
@ -72,7 +71,7 @@ public class TestAnalyzeLocal {
|
|||
INDArray mean = arr.mean(0);
|
||||
INDArray std = arr.std(0);
|
||||
|
||||
for( int i=0; i<5; i++ ){
|
||||
for( int i = 0; i < 5; i++) {
|
||||
double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
|
||||
double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
|
||||
assertEquals(mean.getDouble(i), m, 1e-3);
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -36,8 +36,8 @@ import java.util.List;
|
|||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TestLineRecordReaderFunction {
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
|
|||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -33,7 +33,7 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestNDArrayToWritablesFunction {
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
|
|||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestWritablesToNDArrayFunction {
|
||||
|
||||
|
|
|
@ -30,12 +30,12 @@ import org.datavec.api.writable.Writable;
|
|||
|
||||
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
|
||||
import org.datavec.local.transforms.misc.WritablesToStringFunction;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestWritablesToStringFunctions {
|
||||
|
||||
|
|
|
@ -17,10 +17,8 @@
|
|||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.local.transforms.transform;
|
||||
|
||||
|
||||
import org.datavec.api.transform.MathFunction;
|
||||
import org.datavec.api.transform.MathOp;
|
||||
import org.datavec.api.transform.ReduceOp;
|
||||
|
@ -31,108 +29,85 @@ import org.datavec.api.transform.reduce.Reducer;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.transform.schema.SequenceSchema;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.datavec.python.PythonTransform;
|
||||
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.util.*;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import static java.time.Duration.ofMillis;
|
||||
import static org.junit.jupiter.api.Assertions.assertTimeout;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class ExecutionTest {
|
||||
@DisplayName("Execution Test")
|
||||
class ExecutionTest {
|
||||
|
||||
@Test
|
||||
public void testExecutionNdarray() {
|
||||
Schema schema = new Schema.Builder()
|
||||
.addColumnNDArray("first",new long[]{1,32577})
|
||||
.addColumnNDArray("second",new long[]{1,32577}).build();
|
||||
|
||||
TransformProcess transformProcess = new TransformProcess.Builder(schema)
|
||||
.ndArrayMathFunctionTransform("first", MathFunction.SIN)
|
||||
.ndArrayMathFunctionTransform("second",MathFunction.COS)
|
||||
.build();
|
||||
|
||||
@DisplayName("Test Execution Ndarray")
|
||||
void testExecutionNdarray() {
|
||||
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
|
||||
TransformProcess transformProcess = new TransformProcess.Builder(schema).ndArrayMathFunctionTransform("first", MathFunction.SIN).ndArrayMathFunctionTransform("second", MathFunction.COS).build();
|
||||
List<List<Writable>> functions = new ArrayList<>();
|
||||
List<Writable> firstRow = new ArrayList<>();
|
||||
INDArray firstArr = Nd4j.linspace(1,4,4);
|
||||
INDArray secondArr = Nd4j.linspace(1,4,4);
|
||||
INDArray firstArr = Nd4j.linspace(1, 4, 4);
|
||||
INDArray secondArr = Nd4j.linspace(1, 4, 4);
|
||||
firstRow.add(new NDArrayWritable(firstArr));
|
||||
firstRow.add(new NDArrayWritable(secondArr));
|
||||
functions.add(firstRow);
|
||||
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
|
||||
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
|
||||
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
|
||||
|
||||
INDArray expected = Transforms.sin(firstArr);
|
||||
INDArray secondExpected = Transforms.cos(secondArr);
|
||||
assertEquals(expected,firstResult);
|
||||
assertEquals(secondExpected,secondResult);
|
||||
|
||||
assertEquals(expected, firstResult);
|
||||
assertEquals(secondExpected, secondResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExecutionSimple() {
|
||||
Schema schema = new Schema.Builder().addColumnInteger("col0")
|
||||
.addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").
|
||||
addColumnFloat("col3").build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
|
||||
.doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
|
||||
|
||||
@DisplayName("Test Execution Simple")
|
||||
void testExecutionSimple() {
|
||||
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
|
||||
List<List<Writable>> inputData = new ArrayList<>();
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
|
||||
|
||||
List<List<Writable>> rdd = (inputData);
|
||||
|
||||
List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
|
||||
|
||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
||||
}
|
||||
});
|
||||
|
||||
List<List<Writable>> expected = new ArrayList<>();
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f)));
|
||||
|
||||
assertEquals(expected, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFilter() {
|
||||
Schema filterSchema = new Schema.Builder()
|
||||
.addColumnDouble("col1").addColumnDouble("col2")
|
||||
.addColumnDouble("col3").build();
|
||||
@DisplayName("Test Filter")
|
||||
void testFilter() {
|
||||
Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build();
|
||||
List<List<Writable>> inputData = new ArrayList<>();
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
|
||||
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema)
|
||||
.filter(new DoubleColumnCondition("col1",ConditionOp.LessThan,1)).build();
|
||||
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
|
||||
assertEquals(2,execute.size());
|
||||
assertEquals(2, execute.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExecutionSequence() {
|
||||
|
||||
Schema schema = new SequenceSchema.Builder().addColumnInteger("col0")
|
||||
.addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
|
||||
.doubleMathOp("col2", MathOp.Add, 10.0).build();
|
||||
|
||||
@DisplayName("Test Execution Sequence")
|
||||
void testExecutionSequence() {
|
||||
Schema schema = new SequenceSchema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
|
||||
List<List<List<Writable>>> inputSequences = new ArrayList<>();
|
||||
List<List<Writable>> seq1 = new ArrayList<>();
|
||||
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
|
@ -141,21 +116,17 @@ public class ExecutionTest {
|
|||
List<List<Writable>> seq2 = new ArrayList<>();
|
||||
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
|
||||
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
|
||||
|
||||
inputSequences.add(seq1);
|
||||
inputSequences.add(seq2);
|
||||
|
||||
List<List<List<Writable>>> rdd = (inputSequences);
|
||||
|
||||
List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
|
||||
|
||||
Collections.sort(out, new Comparator<List<List<Writable>>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
|
||||
return -Integer.compare(o1.size(), o2.size());
|
||||
}
|
||||
});
|
||||
|
||||
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
|
||||
List<List<Writable>> seq1e = new ArrayList<>();
|
||||
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||
|
@ -164,121 +135,37 @@ public class ExecutionTest {
|
|||
List<List<Writable>> seq2e = new ArrayList<>();
|
||||
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
|
||||
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
|
||||
|
||||
expectedSequence.add(seq1e);
|
||||
expectedSequence.add(seq2e);
|
||||
|
||||
assertEquals(expectedSequence, out);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testReductionGlobal() {
|
||||
|
||||
List<List<Writable>> in = Arrays.asList(
|
||||
Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)),
|
||||
Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))
|
||||
);
|
||||
|
||||
@DisplayName("Test Reduction Global")
|
||||
void testReductionGlobal() {
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
|
||||
List<List<Writable>> inData = in;
|
||||
|
||||
Schema s = new Schema.Builder()
|
||||
.addColumnString("textCol")
|
||||
.addColumnDouble("doubleCol")
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(s)
|
||||
.reduce(new Reducer.Builder(ReduceOp.TakeFirst)
|
||||
.takeFirstColumns("textCol")
|
||||
.meanColumns("doubleCol").build())
|
||||
.build();
|
||||
|
||||
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
|
||||
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
|
||||
|
||||
List<List<Writable>> out = outRdd;
|
||||
|
||||
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
|
||||
|
||||
assertEquals(expOut, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReductionByKey(){
|
||||
|
||||
List<List<Writable>> in = Arrays.asList(
|
||||
Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)),
|
||||
Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)),
|
||||
Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)),
|
||||
Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))
|
||||
);
|
||||
|
||||
@DisplayName("Test Reduction By Key")
|
||||
void testReductionByKey() {
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
|
||||
List<List<Writable>> inData = in;
|
||||
|
||||
Schema s = new Schema.Builder()
|
||||
.addColumnInteger("intCol")
|
||||
.addColumnString("textCol")
|
||||
.addColumnDouble("doubleCol")
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(s)
|
||||
.reduce(new Reducer.Builder(ReduceOp.TakeFirst)
|
||||
.keyColumns("intCol")
|
||||
.takeFirstColumns("textCol")
|
||||
.meanColumns("doubleCol").build())
|
||||
.build();
|
||||
|
||||
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
|
||||
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
|
||||
|
||||
List<List<Writable>> out = outRdd;
|
||||
|
||||
List<List<Writable>> expOut = Arrays.asList(
|
||||
Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)),
|
||||
Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||
|
||||
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||
out = new ArrayList<>(out);
|
||||
Collections.sort(
|
||||
out, new Comparator<List<Writable>>() {
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
|
||||
assertEquals(expOut, out);
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
|
||||
public void testPythonExecutionNdarray()throws Exception{
|
||||
Schema schema = new Schema.Builder()
|
||||
.addColumnNDArray("first",new long[]{1,32577})
|
||||
.addColumnNDArray("second",new long[]{1,32577}).build();
|
||||
|
||||
TransformProcess transformProcess = new TransformProcess.Builder(schema)
|
||||
.transform(
|
||||
PythonTransform.builder().code(
|
||||
"first = np.sin(first)\nsecond = np.cos(second)")
|
||||
.outputSchema(schema).build())
|
||||
.build();
|
||||
|
||||
List<List<Writable>> functions = new ArrayList<>();
|
||||
List<Writable> firstRow = new ArrayList<>();
|
||||
INDArray firstArr = Nd4j.linspace(1,4,4);
|
||||
INDArray secondArr = Nd4j.linspace(1,4,4);
|
||||
firstRow.add(new NDArrayWritable(firstArr));
|
||||
firstRow.add(new NDArrayWritable(secondArr));
|
||||
functions.add(firstRow);
|
||||
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
|
||||
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
|
||||
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
|
||||
|
||||
INDArray expected = Transforms.sin(firstArr);
|
||||
INDArray secondExpected = Transforms.cos(secondArr);
|
||||
assertEquals(expected,firstResult);
|
||||
assertEquals(secondExpected,secondResult);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,154 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.local.transforms.transform;
|
||||
|
||||
import org.datavec.api.transform.ColumnType;
|
||||
import org.datavec.api.transform.Transform;
|
||||
import org.datavec.api.transform.geo.LocationType;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
|
||||
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
|
||||
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
|
||||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
* @author saudet
|
||||
*/
|
||||
public class TestGeoTransforms {
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
|
||||
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
|
||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void afterClass(){
|
||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testCoordinatesDistanceTransform() throws Exception {
|
||||
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
|
||||
.build();
|
||||
|
||||
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
|
||||
transform.setInputSchema(schema);
|
||||
|
||||
Schema out = transform.transform(schema);
|
||||
assertEquals(4, out.numColumns());
|
||||
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
|
||||
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
|
||||
out.getColumnTypes());
|
||||
|
||||
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
|
||||
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
|
||||
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
|
||||
new DoubleWritable(Math.sqrt(160))),
|
||||
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
|
||||
new Text("10|5"))));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIPAddressToCoordinatesTransform() throws Exception {
|
||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
||||
|
||||
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
|
||||
transform.setInputSchema(schema);
|
||||
|
||||
Schema out = transform.transform(schema);
|
||||
|
||||
assertEquals(1, out.getColumnMetaData().size());
|
||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
||||
|
||||
String in = "81.2.69.160";
|
||||
double latitude = 51.5142;
|
||||
double longitude = -0.0931;
|
||||
|
||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
||||
assertEquals(1, writables.size());
|
||||
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
||||
assertEquals(2, coordinates.length);
|
||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
||||
|
||||
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
||||
oos.writeObject(transform);
|
||||
|
||||
byte[] bytes = baos.toByteArray();
|
||||
|
||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
||||
ObjectInputStream ois = new ObjectInputStream(bais);
|
||||
|
||||
Transform deserialized = (Transform) ois.readObject();
|
||||
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
|
||||
assertEquals(1, writables.size());
|
||||
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
||||
//System.out.println(Arrays.toString(coordinates));
|
||||
assertEquals(2, coordinates.length);
|
||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIPAddressToLocationTransform() throws Exception {
|
||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
||||
LocationType[] locationTypes = LocationType.values();
|
||||
String in = "81.2.69.160";
|
||||
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
|
||||
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
|
||||
|
||||
for (int i = 0; i < locationTypes.length; i++) {
|
||||
LocationType locationType = locationTypes[i];
|
||||
String location = locations[i];
|
||||
|
||||
Transform transform = new IPAddressToLocationTransform("column", locationType);
|
||||
transform.setInputSchema(schema);
|
||||
|
||||
Schema out = transform.transform(schema);
|
||||
|
||||
assertEquals(1, out.getColumnMetaData().size());
|
||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
||||
|
||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
||||
assertEquals(1, writables.size());
|
||||
assertEquals(location, writables.get(0).toString());
|
||||
//System.out.println(location);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,379 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.local.transforms.transform;
|
||||
|
||||
import org.datavec.api.transform.TransformProcess;
|
||||
import org.datavec.api.transform.condition.Condition;
|
||||
import org.datavec.api.transform.filter.ConditionFilter;
|
||||
import org.datavec.api.transform.filter.Filter;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
|
||||
import org.datavec.api.writable.*;
|
||||
import org.datavec.python.PythonCondition;
|
||||
import org.datavec.python.PythonTransform;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import javax.annotation.concurrent.NotThreadSafe;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static junit.framework.TestCase.assertTrue;
|
||||
import static org.datavec.api.transform.schema.Schema.Builder;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@NotThreadSafe
|
||||
public class TestPythonTransformProcess {
|
||||
|
||||
|
||||
@Test()
|
||||
public void testStringConcat() throws Exception{
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnString("col1")
|
||||
.addColumnString("col2");
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnString("col3");
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build()
|
||||
).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals((outputs.get(0)).toString(), "Hello ");
|
||||
assertEquals((outputs.get(1)).toString(), "World!");
|
||||
assertEquals((outputs.get(2)).toString(), "Hello World!");
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testMixedTypes() throws Exception{
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnInteger("col1")
|
||||
.addColumnFloat("col2")
|
||||
.addColumnString("col3")
|
||||
.addColumnDouble("col4");
|
||||
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnInteger("col5");
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.inputSchema(initialSchema)
|
||||
.build() ).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
|
||||
new FloatWritable(3.5f),
|
||||
new Text("5"),
|
||||
new DoubleWritable(2.0)
|
||||
);
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testNDArray() throws Exception{
|
||||
long[] shape = new long[]{3, 2};
|
||||
INDArray arr1 = Nd4j.rand(shape);
|
||||
INDArray arr2 = Nd4j.rand(shape);
|
||||
|
||||
INDArray expectedOutput = arr1.add(arr2);
|
||||
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnNDArray("col1", shape)
|
||||
.addColumnNDArray("col2", shape);
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnNDArray("col3", shape);
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build() ).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(
|
||||
(Writable)
|
||||
new NDArrayWritable(arr1),
|
||||
new NDArrayWritable(arr2)
|
||||
);
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testNDArray2() throws Exception{
|
||||
long[] shape = new long[]{3, 2};
|
||||
INDArray arr1 = Nd4j.rand(shape);
|
||||
INDArray arr2 = Nd4j.rand(shape);
|
||||
|
||||
INDArray expectedOutput = arr1.add(arr2);
|
||||
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnNDArray("col1", shape)
|
||||
.addColumnNDArray("col2", shape);
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnNDArray("col3", shape);
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build() ).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(
|
||||
(Writable)
|
||||
new NDArrayWritable(arr1),
|
||||
new NDArrayWritable(arr2)
|
||||
);
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testNDArrayMixed() throws Exception{
|
||||
long[] shape = new long[]{3, 2};
|
||||
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
|
||||
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
|
||||
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
|
||||
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnNDArray("col1", shape)
|
||||
.addColumnNDArray("col2", shape);
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnNDArray("col3", shape);
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build()
|
||||
).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(
|
||||
(Writable)
|
||||
new NDArrayWritable(arr1),
|
||||
new NDArrayWritable(arr2)
|
||||
);
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testPythonFilter() {
|
||||
Schema schema = new Builder().addColumnInteger("column").build();
|
||||
|
||||
Condition condition = new PythonCondition(
|
||||
"f = lambda: column < 0"
|
||||
);
|
||||
|
||||
condition.setInputSchema(schema);
|
||||
|
||||
Filter filter = new ConditionFilter(condition);
|
||||
|
||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
|
||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
|
||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
|
||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
|
||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testPythonFilterAndTransform() throws Exception{
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnInteger("col1")
|
||||
.addColumnFloat("col2")
|
||||
.addColumnString("col3")
|
||||
.addColumnDouble("col4");
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnString("col6");
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
Condition condition = new PythonCondition(
|
||||
"f = lambda: col1 < 0 and col2 > 10.0"
|
||||
);
|
||||
|
||||
condition.setInputSchema(initialSchema);
|
||||
|
||||
Filter filter = new ConditionFilter(condition);
|
||||
|
||||
String pythonCode = "col6 = str(col1 + col2)";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build()
|
||||
).filter(
|
||||
filter
|
||||
).build();
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(
|
||||
Arrays.asList(
|
||||
(Writable)
|
||||
new IntWritable(5),
|
||||
new FloatWritable(3.0f),
|
||||
new Text("abcd"),
|
||||
new DoubleWritable(2.1))
|
||||
);
|
||||
inputs.add(
|
||||
Arrays.asList(
|
||||
(Writable)
|
||||
new IntWritable(-3),
|
||||
new FloatWritable(3.0f),
|
||||
new Text("abcd"),
|
||||
new DoubleWritable(2.1))
|
||||
);
|
||||
inputs.add(
|
||||
Arrays.asList(
|
||||
(Writable)
|
||||
new IntWritable(5),
|
||||
new FloatWritable(11.2f),
|
||||
new Text("abcd"),
|
||||
new DoubleWritable(2.1))
|
||||
);
|
||||
|
||||
LocalTransformExecutor.execute(inputs,tp);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testPythonTransformNoOutputSpecified() throws Exception {
|
||||
PythonTransform pythonTransform = PythonTransform.builder()
|
||||
.code("a += 2; b = 'hello world'")
|
||||
.returnAllInputs(true)
|
||||
.build();
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
|
||||
Schema inputSchema = new Builder()
|
||||
.addColumnInteger("a")
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
||||
.transform(pythonTransform)
|
||||
.build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
||||
assertEquals(3,execute.get(0).get(0).toInt());
|
||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNumpyTransform() {
|
||||
PythonTransform pythonTransform = PythonTransform.builder()
|
||||
.code("a += 2; b = 'hello world'")
|
||||
.returnAllInputs(true)
|
||||
.build();
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
|
||||
Schema inputSchema = new Builder()
|
||||
.addColumnNDArray("a",new long[]{1,1})
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
||||
.transform(pythonTransform)
|
||||
.build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
||||
assertFalse(execute.isEmpty());
|
||||
assertNotNull(execute.get(0));
|
||||
assertNotNull(execute.get(0).get(0));
|
||||
assertNotNull(execute.get(0).get(1));
|
||||
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
|
||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWithSetupRun() throws Exception {
|
||||
|
||||
PythonTransform pythonTransform = PythonTransform.builder()
|
||||
.code("five=None\n" +
|
||||
"def setup():\n" +
|
||||
" global five\n"+
|
||||
" five = 5\n\n" +
|
||||
"def run(a, b):\n" +
|
||||
" c = a + b + five\n"+
|
||||
" return {'c':c}\n\n")
|
||||
.returnAllInputs(true)
|
||||
.setupAndRun(true)
|
||||
.build();
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
|
||||
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
|
||||
Schema inputSchema = new Builder()
|
||||
.addColumnNDArray("a",new long[]{1,1})
|
||||
.addColumnNDArray("b", new long[]{1, 1})
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
||||
.transform(pythonTransform)
|
||||
.build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
||||
assertFalse(execute.isEmpty());
|
||||
assertNotNull(execute.get(0));
|
||||
assertNotNull(execute.get(0).get(0));
|
||||
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
|
||||
}
|
||||
|
||||
}
|
|
@ -28,11 +28,11 @@ import org.datavec.api.writable.*;
|
|||
|
||||
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestJoin {
|
||||
|
||||
|
|
|
@ -31,13 +31,13 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
|||
|
||||
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestCalculateSortedRank {
|
||||
|
||||
|
|
|
@ -31,14 +31,14 @@ import org.datavec.api.writable.Writable;
|
|||
|
||||
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TestConvertToSequence {
|
||||
|
||||
|
|
|
@ -41,6 +41,12 @@
|
|||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.tdunning</groupId>
|
||||
<artifactId>t-digest</artifactId>
|
||||
<version>3.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
|
@ -122,10 +128,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue