commit
c505a11ed6
|
@ -31,7 +31,7 @@ jobs:
|
||||||
protoc --version
|
protoc --version
|
||||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
windows-x86_64:
|
windows-x86_64:
|
||||||
runs-on: windows-2019
|
runs-on: windows-2019
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,5 +60,5 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ jobs:
|
||||||
protoc --version
|
protoc --version
|
||||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
windows-x86_64:
|
windows-x86_64:
|
||||||
runs-on: windows-2019
|
runs-on: windows-2019
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,5 +60,5 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,3 @@
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
jobs:
|
|
||||||
# Wait for up to a minute for previous run to complete, abort if not done by then
|
|
||||||
pre-ci:
|
|
||||||
run
|
|
||||||
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
jobs:
|
jobs:
|
||||||
|
@ -42,5 +34,5 @@ jobs:
|
||||||
cmake --version
|
cmake --version
|
||||||
protoc --version
|
protoc --version
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pdl4j-integration-tests -Pnd4j-tests-cpu clean test
|
mvn -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pnd4j-tests-cpu clean test -rf :rl4j-core
|
||||||
|
|
||||||
|
|
|
@ -34,5 +34,5 @@ jobs:
|
||||||
cmake --version
|
cmake --version
|
||||||
protoc --version
|
protoc --version
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test
|
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
<commons.dbutils.version>1.7</commons.dbutils.version>
|
<commons.dbutils.version>1.7</commons.dbutils.version>
|
||||||
<lombok.version>1.18.8</lombok.version>
|
<lombok.version>1.18.8</lombok.version>
|
||||||
<logback.version>1.1.7</logback.version>
|
<logback.version>1.1.7</logback.version>
|
||||||
<junit.version>4.12</junit.version>
|
<junit.version>5.8.0-M1</junit.version>
|
||||||
<junit-jupiter.version>5.4.2</junit-jupiter.version>
|
<junit-jupiter.version>5.4.2</junit-jupiter.version>
|
||||||
<java.version>1.8</java.version>
|
<java.version>1.8</java.version>
|
||||||
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
|
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
|
||||||
|
|
|
@ -17,13 +17,14 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.nd4j.codegen.ir;
|
package org.nd4j.codegen.ir;
|
||||||
|
|
||||||
public class SerializationTest {
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
public static void main(String...args) {
|
@DisplayName("Serialization Test")
|
||||||
|
class SerializationTest {
|
||||||
|
|
||||||
|
public static void main(String... args) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,29 +17,23 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.nd4j.codegen.dsl;
|
package org.nd4j.codegen.dsl;
|
||||||
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.codegen.impl.java.DocsGenerator;
|
import org.nd4j.codegen.impl.java.DocsGenerator;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
public class DocsGeneratorTest {
|
@DisplayName("Docs Generator Test")
|
||||||
|
class DocsGeneratorTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testJDtoMDAdapter() {
|
@DisplayName("Test J Dto MD Adapter")
|
||||||
String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" +
|
void testJDtoMDAdapter() {
|
||||||
" eye:\n" +
|
String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
|
||||||
" [ 1, 0]\n" +
|
String expected = "{ INDArray eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
|
||||||
" [ 0, 1]\n" +
|
|
||||||
" [ 0, 0]}";
|
|
||||||
String expected = "{ INDArray eye = eye(3,2)\n" +
|
|
||||||
" eye:\n" +
|
|
||||||
" [ 1, 0]\n" +
|
|
||||||
" [ 0, 1]\n" +
|
|
||||||
" [ 0, 0]}";
|
|
||||||
DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original);
|
DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original);
|
||||||
String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString();
|
String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString();
|
||||||
assertEquals(out, expected);
|
assertEquals(out, expected);
|
||||||
|
|
|
@ -34,6 +34,14 @@
|
||||||
<artifactId>datavec-api</artifactId>
|
<artifactId>datavec-api</artifactId>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-api</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.vintage</groupId>
|
||||||
|
<artifactId>junit-vintage-engine</artifactId>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.commons</groupId>
|
<groupId>org.apache.commons</groupId>
|
||||||
<artifactId>commons-lang3</artifactId>
|
<artifactId>commons-lang3</artifactId>
|
||||||
|
@ -101,10 +109,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -26,47 +25,38 @@ import org.datavec.api.records.reader.impl.csv.CSVLineSequenceRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Csv Line Sequence Record Reader Test")
|
||||||
|
class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
@TempDir
|
||||||
|
public Path testDir;
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test() throws Exception {
|
@DisplayName("Test")
|
||||||
|
void test(@TempDir Path testDir) throws Exception {
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
File source = new File(f, "temp.csv");
|
File source = new File(f, "temp.csv");
|
||||||
String str = "a,b,c\n1,2,3,4";
|
String str = "a,b,c\n1,2,3,4";
|
||||||
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
|
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
|
||||||
|
|
||||||
SequenceRecordReader rr = new CSVLineSequenceRecordReader();
|
SequenceRecordReader rr = new CSVLineSequenceRecordReader();
|
||||||
rr.initialize(new FileSplit(source));
|
rr.initialize(new FileSplit(source));
|
||||||
|
List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
|
||||||
List<List<Writable>> exp0 = Arrays.asList(
|
List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
|
||||||
Collections.<Writable>singletonList(new Text("a")),
|
for (int i = 0; i < 3; i++) {
|
||||||
Collections.<Writable>singletonList(new Text("b")),
|
|
||||||
Collections.<Writable>singletonList(new Text("c")));
|
|
||||||
|
|
||||||
List<List<Writable>> exp1 = Arrays.asList(
|
|
||||||
Collections.<Writable>singletonList(new Text("1")),
|
|
||||||
Collections.<Writable>singletonList(new Text("2")),
|
|
||||||
Collections.<Writable>singletonList(new Text("3")),
|
|
||||||
Collections.<Writable>singletonList(new Text("4")));
|
|
||||||
|
|
||||||
for( int i=0; i<3; i++ ) {
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
List<List<Writable>> next = rr.sequenceRecord();
|
List<List<Writable>> next = rr.sequenceRecord();
|
||||||
|
@ -76,9 +66,7 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
assertEquals(exp1, next);
|
assertEquals(exp1, next);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(2, count);
|
assertEquals(2, count);
|
||||||
|
|
||||||
rr.reset();
|
rr.reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -26,33 +25,37 @@ import org.datavec.api.records.reader.impl.csv.CSVMultiSequenceRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Csv Multi Sequence Record Reader Test")
|
||||||
import static org.junit.Assert.assertFalse;
|
class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
@TempDir
|
||||||
|
public Path testDir;
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConcatMode() throws Exception {
|
@DisplayName("Test Concat Mode")
|
||||||
for( int i=0; i<3; i++ ) {
|
@Disabled
|
||||||
|
void testConcatMode() throws Exception {
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
String seqSep;
|
String seqSep;
|
||||||
String seqSepRegex;
|
String seqSepRegex;
|
||||||
switch (i){
|
switch(i) {
|
||||||
case 0:
|
case 0:
|
||||||
seqSep = "";
|
seqSep = "";
|
||||||
seqSepRegex = "^$";
|
seqSepRegex = "^$";
|
||||||
|
@ -68,31 +71,23 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
|
|
||||||
String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C";
|
String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C";
|
||||||
File f = testDir.newFile();
|
File f = testDir.toFile();
|
||||||
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
|
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
|
||||||
|
|
||||||
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT);
|
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT);
|
||||||
seqRR.initialize(new FileSplit(f));
|
seqRR.initialize(new FileSplit(f));
|
||||||
|
|
||||||
|
|
||||||
List<List<Writable>> exp0 = new ArrayList<>();
|
List<List<Writable>> exp0 = new ArrayList<>();
|
||||||
for (String s : "a,b,c,1,2,3,4,x,y".split(",")) {
|
for (String s : "a,b,c,1,2,3,4,x,y".split(",")) {
|
||||||
exp0.add(Collections.<Writable>singletonList(new Text(s)));
|
exp0.add(Collections.<Writable>singletonList(new Text(s)));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<List<Writable>> exp1 = new ArrayList<>();
|
List<List<Writable>> exp1 = new ArrayList<>();
|
||||||
for (String s : "A,B,C".split(",")) {
|
for (String s : "A,B,C".split(",")) {
|
||||||
exp1.add(Collections.<Writable>singletonList(new Text(s)));
|
exp1.add(Collections.<Writable>singletonList(new Text(s)));
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(exp0, seqRR.sequenceRecord());
|
assertEquals(exp0, seqRR.sequenceRecord());
|
||||||
assertEquals(exp1, seqRR.sequenceRecord());
|
assertEquals(exp1, seqRR.sequenceRecord());
|
||||||
assertFalse(seqRR.hasNext());
|
assertFalse(seqRR.hasNext());
|
||||||
|
|
||||||
seqRR.reset();
|
seqRR.reset();
|
||||||
|
|
||||||
assertEquals(exp0, seqRR.sequenceRecord());
|
assertEquals(exp0, seqRR.sequenceRecord());
|
||||||
assertEquals(exp1, seqRR.sequenceRecord());
|
assertEquals(exp1, seqRR.sequenceRecord());
|
||||||
assertFalse(seqRR.hasNext());
|
assertFalse(seqRR.hasNext());
|
||||||
|
@ -100,13 +95,13 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEqualLength() throws Exception {
|
@DisplayName("Test Equal Length")
|
||||||
|
@Disabled
|
||||||
for( int i=0; i<3; i++ ) {
|
void testEqualLength() throws Exception {
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
String seqSep;
|
String seqSep;
|
||||||
String seqSepRegex;
|
String seqSepRegex;
|
||||||
switch (i) {
|
switch(i) {
|
||||||
case 0:
|
case 0:
|
||||||
seqSep = "";
|
seqSep = "";
|
||||||
seqSepRegex = "^$";
|
seqSepRegex = "^$";
|
||||||
|
@ -122,27 +117,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
|
|
||||||
String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC";
|
String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC";
|
||||||
File f = testDir.newFile();
|
File f = testDir.toFile();
|
||||||
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
|
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
|
||||||
|
|
||||||
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH);
|
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH);
|
||||||
seqRR.initialize(new FileSplit(f));
|
seqRR.initialize(new FileSplit(f));
|
||||||
|
List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
|
||||||
|
|
||||||
List<List<Writable>> exp0 = Arrays.asList(
|
|
||||||
Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
|
|
||||||
Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
|
|
||||||
|
|
||||||
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
|
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
|
||||||
|
|
||||||
assertEquals(exp0, seqRR.sequenceRecord());
|
assertEquals(exp0, seqRR.sequenceRecord());
|
||||||
assertEquals(exp1, seqRR.sequenceRecord());
|
assertEquals(exp1, seqRR.sequenceRecord());
|
||||||
assertFalse(seqRR.hasNext());
|
assertFalse(seqRR.hasNext());
|
||||||
|
|
||||||
seqRR.reset();
|
seqRR.reset();
|
||||||
|
|
||||||
assertEquals(exp0, seqRR.sequenceRecord());
|
assertEquals(exp0, seqRR.sequenceRecord());
|
||||||
assertEquals(exp1, seqRR.sequenceRecord());
|
assertEquals(exp1, seqRR.sequenceRecord());
|
||||||
assertFalse(seqRR.hasNext());
|
assertFalse(seqRR.hasNext());
|
||||||
|
@ -150,13 +135,13 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPadding() throws Exception {
|
@DisplayName("Test Padding")
|
||||||
|
@Disabled
|
||||||
for( int i=0; i<3; i++ ) {
|
void testPadding() throws Exception {
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
String seqSep;
|
String seqSep;
|
||||||
String seqSepRegex;
|
String seqSepRegex;
|
||||||
switch (i) {
|
switch(i) {
|
||||||
case 0:
|
case 0:
|
||||||
seqSep = "";
|
seqSep = "";
|
||||||
seqSepRegex = "^$";
|
seqSepRegex = "^$";
|
||||||
|
@ -172,27 +157,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
|
|
||||||
String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC";
|
String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC";
|
||||||
File f = testDir.newFile();
|
File f = testDir.toFile();
|
||||||
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
|
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
|
||||||
|
|
||||||
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD"));
|
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD"));
|
||||||
seqRR.initialize(new FileSplit(f));
|
seqRR.initialize(new FileSplit(f));
|
||||||
|
List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
|
||||||
|
|
||||||
List<List<Writable>> exp0 = Arrays.asList(
|
|
||||||
Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
|
|
||||||
Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
|
|
||||||
|
|
||||||
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
|
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
|
||||||
|
|
||||||
assertEquals(exp0, seqRR.sequenceRecord());
|
assertEquals(exp0, seqRR.sequenceRecord());
|
||||||
assertEquals(exp1, seqRR.sequenceRecord());
|
assertEquals(exp1, seqRR.sequenceRecord());
|
||||||
assertFalse(seqRR.hasNext());
|
assertFalse(seqRR.hasNext());
|
||||||
|
|
||||||
seqRR.reset();
|
seqRR.reset();
|
||||||
|
|
||||||
assertEquals(exp0, seqRR.sequenceRecord());
|
assertEquals(exp0, seqRR.sequenceRecord());
|
||||||
assertEquals(exp1, seqRR.sequenceRecord());
|
assertEquals(exp1, seqRR.sequenceRecord());
|
||||||
assertFalse(seqRR.hasNext());
|
assertFalse(seqRR.hasNext());
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.datavec.api.records.SequenceRecord;
|
import org.datavec.api.records.SequenceRecord;
|
||||||
|
@ -27,61 +26,53 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Csvn Lines Sequence Record Reader Test")
|
||||||
|
class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCSVNLinesSequenceRecordReader() throws Exception {
|
@DisplayName("Test CSVN Lines Sequence Record Reader")
|
||||||
|
void testCSVNLinesSequenceRecordReader() throws Exception {
|
||||||
int nLinesPerSequence = 10;
|
int nLinesPerSequence = 10;
|
||||||
|
|
||||||
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
|
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
|
||||||
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
CSVRecordReader rr = new CSVRecordReader();
|
CSVRecordReader rr = new CSVRecordReader();
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while (seqRR.hasNext()) {
|
while (seqRR.hasNext()) {
|
||||||
List<List<Writable>> next = seqRR.sequenceRecord();
|
List<List<Writable>> next = seqRR.sequenceRecord();
|
||||||
|
|
||||||
List<List<Writable>> expected = new ArrayList<>();
|
List<List<Writable>> expected = new ArrayList<>();
|
||||||
for (int i = 0; i < nLinesPerSequence; i++) {
|
for (int i = 0; i < nLinesPerSequence; i++) {
|
||||||
expected.add(rr.next());
|
expected.add(rr.next());
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(10, next.size());
|
assertEquals(10, next.size());
|
||||||
assertEquals(expected, next);
|
assertEquals(expected, next);
|
||||||
|
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(150 / nLinesPerSequence, count);
|
assertEquals(150 / nLinesPerSequence, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
|
@DisplayName("Test CSV Nlines Sequence Record Reader Meta Data")
|
||||||
|
void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
|
||||||
int nLinesPerSequence = 10;
|
int nLinesPerSequence = 10;
|
||||||
|
|
||||||
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
|
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
|
||||||
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
CSVRecordReader rr = new CSVRecordReader();
|
CSVRecordReader rr = new CSVRecordReader();
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
List<List<List<Writable>>> out = new ArrayList<>();
|
List<List<List<Writable>>> out = new ArrayList<>();
|
||||||
while (seqRR.hasNext()) {
|
while (seqRR.hasNext()) {
|
||||||
List<List<Writable>> next = seqRR.sequenceRecord();
|
List<List<Writable>> next = seqRR.sequenceRecord();
|
||||||
out.add(next);
|
out.add(next);
|
||||||
}
|
}
|
||||||
|
|
||||||
seqRR.reset();
|
seqRR.reset();
|
||||||
List<List<List<Writable>>> out2 = new ArrayList<>();
|
List<List<List<Writable>>> out2 = new ArrayList<>();
|
||||||
List<SequenceRecord> out3 = new ArrayList<>();
|
List<SequenceRecord> out3 = new ArrayList<>();
|
||||||
|
@ -92,11 +83,8 @@ public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
meta.add(seq.getMetaData());
|
meta.add(seq.getMetaData());
|
||||||
out3.add(seq);
|
out3.add(seq);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(out, out2);
|
assertEquals(out, out2);
|
||||||
|
|
||||||
List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
|
List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
|
||||||
assertEquals(out3, out4);
|
assertEquals(out3, out4);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -34,10 +33,10 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
|
@ -47,41 +46,44 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.NoSuchElementException;
|
import java.util.NoSuchElementException;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
|
||||||
|
@DisplayName("Csv Record Reader Test")
|
||||||
|
class CSVRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class CSVRecordReaderTest extends BaseND4JTest {
|
|
||||||
@Test
|
@Test
|
||||||
public void testNext() throws Exception {
|
@DisplayName("Test Next")
|
||||||
|
void testNext() throws Exception {
|
||||||
CSVRecordReader reader = new CSVRecordReader();
|
CSVRecordReader reader = new CSVRecordReader();
|
||||||
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1"));
|
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1"));
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
List<Writable> vals = reader.next();
|
List<Writable> vals = reader.next();
|
||||||
List<Writable> arr = new ArrayList<>(vals);
|
List<Writable> arr = new ArrayList<>(vals);
|
||||||
|
assertEquals(23, vals.size(), "Entry count");
|
||||||
assertEquals("Entry count", 23, vals.size());
|
|
||||||
Text lastEntry = (Text) arr.get(arr.size() - 1);
|
Text lastEntry = (Text) arr.get(arr.size() - 1);
|
||||||
assertEquals("Last entry garbage", 1, lastEntry.getLength());
|
assertEquals(1, lastEntry.getLength(), "Last entry garbage");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEmptyEntries() throws Exception {
|
@DisplayName("Test Empty Entries")
|
||||||
|
void testEmptyEntries() throws Exception {
|
||||||
CSVRecordReader reader = new CSVRecordReader();
|
CSVRecordReader reader = new CSVRecordReader();
|
||||||
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,"));
|
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,"));
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
List<Writable> vals = reader.next();
|
List<Writable> vals = reader.next();
|
||||||
assertEquals("Entry count", 23, vals.size());
|
assertEquals(23, vals.size(), "Entry count");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReset() throws Exception {
|
@DisplayName("Test Reset")
|
||||||
|
void testReset() throws Exception {
|
||||||
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
int nResets = 5;
|
int nResets = 5;
|
||||||
for (int i = 0; i < nResets; i++) {
|
for (int i = 0; i < nResets; i++) {
|
||||||
|
|
||||||
int lineCount = 0;
|
int lineCount = 0;
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
List<Writable> line = rr.next();
|
List<Writable> line = rr.next();
|
||||||
|
@ -95,7 +97,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testResetWithSkipLines() throws Exception {
|
@DisplayName("Test Reset With Skip Lines")
|
||||||
|
void testResetWithSkipLines() throws Exception {
|
||||||
CSVRecordReader rr = new CSVRecordReader(10, ',');
|
CSVRecordReader rr = new CSVRecordReader(10, ',');
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
int lineCount = 0;
|
int lineCount = 0;
|
||||||
|
@ -114,7 +117,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWrite() throws Exception {
|
@DisplayName("Test Write")
|
||||||
|
void testWrite() throws Exception {
|
||||||
List<List<Writable>> list = new ArrayList<>();
|
List<List<Writable>> list = new ArrayList<>();
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
|
@ -130,81 +134,72 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
list.add(temp);
|
list.add(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
String expected = sb.toString();
|
String expected = sb.toString();
|
||||||
|
|
||||||
Path p = Files.createTempFile("csvwritetest", "csv");
|
Path p = Files.createTempFile("csvwritetest", "csv");
|
||||||
p.toFile().deleteOnExit();
|
p.toFile().deleteOnExit();
|
||||||
|
|
||||||
FileRecordWriter writer = new CSVRecordWriter();
|
FileRecordWriter writer = new CSVRecordWriter();
|
||||||
FileSplit fileSplit = new FileSplit(p.toFile());
|
FileSplit fileSplit = new FileSplit(p.toFile());
|
||||||
writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
|
||||||
for (List<Writable> c : list) {
|
for (List<Writable> c : list) {
|
||||||
writer.write(c);
|
writer.write(c);
|
||||||
}
|
}
|
||||||
writer.close();
|
writer.close();
|
||||||
|
// Read file back in; compare
|
||||||
//Read file back in; compare
|
|
||||||
String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name());
|
String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name());
|
||||||
|
// System.out.println(expected);
|
||||||
// System.out.println(expected);
|
// System.out.println("----------");
|
||||||
// System.out.println("----------");
|
// System.out.println(fileContents);
|
||||||
// System.out.println(fileContents);
|
|
||||||
|
|
||||||
assertEquals(expected, fileContents);
|
assertEquals(expected, fileContents);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTabsAsSplit1() throws Exception {
|
@DisplayName("Test Tabs As Split 1")
|
||||||
|
void testTabsAsSplit1() throws Exception {
|
||||||
CSVRecordReader reader = new CSVRecordReader(0, '\t');
|
CSVRecordReader reader = new CSVRecordReader(0, '\t');
|
||||||
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile()));
|
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile()));
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
List<Writable> list = new ArrayList<>(reader.next());
|
List<Writable> list = new ArrayList<>(reader.next());
|
||||||
|
|
||||||
assertEquals(2, list.size());
|
assertEquals(2, list.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPipesAsSplit() throws Exception {
|
@DisplayName("Test Pipes As Split")
|
||||||
|
void testPipesAsSplit() throws Exception {
|
||||||
CSVRecordReader reader = new CSVRecordReader(0, '|');
|
CSVRecordReader reader = new CSVRecordReader(0, '|');
|
||||||
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/issue414.csv").getFile()));
|
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/issue414.csv").getFile()));
|
||||||
int lineidx = 0;
|
int lineidx = 0;
|
||||||
List<Integer> sixthColumn = Arrays.asList(13, 95, 15, 25);
|
List<Integer> sixthColumn = Arrays.asList(13, 95, 15, 25);
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
List<Writable> list = new ArrayList<>(reader.next());
|
List<Writable> list = new ArrayList<>(reader.next());
|
||||||
|
|
||||||
assertEquals(10, list.size());
|
assertEquals(10, list.size());
|
||||||
assertEquals((long)sixthColumn.get(lineidx), list.get(5).toInt());
|
assertEquals((long) sixthColumn.get(lineidx), list.get(5).toInt());
|
||||||
lineidx++;
|
lineidx++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWithQuotes() throws Exception {
|
@DisplayName("Test With Quotes")
|
||||||
|
void testWithQuotes() throws Exception {
|
||||||
CSVRecordReader reader = new CSVRecordReader(0, ',', '\"');
|
CSVRecordReader reader = new CSVRecordReader(0, ',', '\"');
|
||||||
reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\""));
|
reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\""));
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
List<Writable> vals = reader.next();
|
List<Writable> vals = reader.next();
|
||||||
assertEquals("Entry count", 6, vals.size());
|
assertEquals(6, vals.size(), "Entry count");
|
||||||
assertEquals("1", vals.get(0).toString());
|
assertEquals(vals.get(0).toString(), "1");
|
||||||
assertEquals("0", vals.get(1).toString());
|
assertEquals(vals.get(1).toString(), "0");
|
||||||
assertEquals("3", vals.get(2).toString());
|
assertEquals(vals.get(2).toString(), "3");
|
||||||
assertEquals("Braund, Mr. Owen Harris", vals.get(3).toString());
|
assertEquals(vals.get(3).toString(), "Braund, Mr. Owen Harris");
|
||||||
assertEquals("male", vals.get(4).toString());
|
assertEquals(vals.get(4).toString(), "male");
|
||||||
assertEquals("\"", vals.get(5).toString());
|
assertEquals(vals.get(5).toString(), "\"");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMeta() throws Exception {
|
@DisplayName("Test Meta")
|
||||||
|
void testMeta() throws Exception {
|
||||||
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
int lineCount = 0;
|
int lineCount = 0;
|
||||||
List<RecordMetaData> metaList = new ArrayList<>();
|
List<RecordMetaData> metaList = new ArrayList<>();
|
||||||
List<List<Writable>> writables = new ArrayList<>();
|
List<List<Writable>> writables = new ArrayList<>();
|
||||||
|
@ -213,30 +208,25 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
||||||
assertEquals(5, r.getRecord().size());
|
assertEquals(5, r.getRecord().size());
|
||||||
lineCount++;
|
lineCount++;
|
||||||
RecordMetaData meta = r.getMetaData();
|
RecordMetaData meta = r.getMetaData();
|
||||||
// System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
|
// System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
|
||||||
|
|
||||||
metaList.add(meta);
|
metaList.add(meta);
|
||||||
writables.add(r.getRecord());
|
writables.add(r.getRecord());
|
||||||
}
|
}
|
||||||
assertFalse(rr.hasNext());
|
assertFalse(rr.hasNext());
|
||||||
assertEquals(150, lineCount);
|
assertEquals(150, lineCount);
|
||||||
rr.reset();
|
rr.reset();
|
||||||
|
|
||||||
|
|
||||||
System.out.println("\n\n\n--------------------------------");
|
System.out.println("\n\n\n--------------------------------");
|
||||||
List<Record> contents = rr.loadFromMetaData(metaList);
|
List<Record> contents = rr.loadFromMetaData(metaList);
|
||||||
assertEquals(150, contents.size());
|
assertEquals(150, contents.size());
|
||||||
// for(Record r : contents ){
|
// for(Record r : contents ){
|
||||||
// System.out.println(r);
|
// System.out.println(r);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
List<RecordMetaData> meta2 = new ArrayList<>();
|
List<RecordMetaData> meta2 = new ArrayList<>();
|
||||||
meta2.add(metaList.get(100));
|
meta2.add(metaList.get(100));
|
||||||
meta2.add(metaList.get(90));
|
meta2.add(metaList.get(90));
|
||||||
meta2.add(metaList.get(80));
|
meta2.add(metaList.get(80));
|
||||||
meta2.add(metaList.get(70));
|
meta2.add(metaList.get(70));
|
||||||
meta2.add(metaList.get(60));
|
meta2.add(metaList.get(60));
|
||||||
|
|
||||||
List<Record> contents2 = rr.loadFromMetaData(meta2);
|
List<Record> contents2 = rr.loadFromMetaData(meta2);
|
||||||
assertEquals(writables.get(100), contents2.get(0).getRecord());
|
assertEquals(writables.get(100), contents2.get(0).getRecord());
|
||||||
assertEquals(writables.get(90), contents2.get(1).getRecord());
|
assertEquals(writables.get(90), contents2.get(1).getRecord());
|
||||||
|
@ -246,50 +236,49 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRegex() throws Exception {
|
@DisplayName("Test Regex")
|
||||||
CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] {null, "(.+) (.+) (.+)"});
|
void testRegex() throws Exception {
|
||||||
|
CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] { null, "(.+) (.+) (.+)" });
|
||||||
reader.initialize(new StringSplit("normal,1.2.3.4 space separator"));
|
reader.initialize(new StringSplit("normal,1.2.3.4 space separator"));
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
List<Writable> vals = reader.next();
|
List<Writable> vals = reader.next();
|
||||||
assertEquals("Entry count", 4, vals.size());
|
assertEquals(4, vals.size(), "Entry count");
|
||||||
assertEquals("normal", vals.get(0).toString());
|
assertEquals(vals.get(0).toString(), "normal");
|
||||||
assertEquals("1.2.3.4", vals.get(1).toString());
|
assertEquals(vals.get(1).toString(), "1.2.3.4");
|
||||||
assertEquals("space", vals.get(2).toString());
|
assertEquals(vals.get(2).toString(), "space");
|
||||||
assertEquals("separator", vals.get(3).toString());
|
assertEquals(vals.get(3).toString(), "separator");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = NoSuchElementException.class)
|
@Test
|
||||||
public void testCsvSkipAllLines() throws IOException, InterruptedException {
|
@DisplayName("Test Csv Skip All Lines")
|
||||||
final int numLines = 4;
|
void testCsvSkipAllLines() {
|
||||||
final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1),
|
assertThrows(NoSuchElementException.class, () -> {
|
||||||
(Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
|
final int numLines = 4;
|
||||||
String header = ",one,two,three";
|
final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
|
||||||
List<String> lines = new ArrayList<>();
|
String header = ",one,two,three";
|
||||||
for (int i = 0; i < numLines; i++)
|
List<String> lines = new ArrayList<>();
|
||||||
lines.add(Integer.toString(i) + header);
|
for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
|
||||||
File tempFile = File.createTempFile("csvSkipLines", ".csv");
|
File tempFile = File.createTempFile("csvSkipLines", ".csv");
|
||||||
FileUtils.writeLines(tempFile, lines);
|
FileUtils.writeLines(tempFile, lines);
|
||||||
|
CSVRecordReader rr = new CSVRecordReader(numLines, ',');
|
||||||
CSVRecordReader rr = new CSVRecordReader(numLines, ',');
|
rr.initialize(new FileSplit(tempFile));
|
||||||
rr.initialize(new FileSplit(tempFile));
|
rr.reset();
|
||||||
rr.reset();
|
assertTrue(!rr.hasNext());
|
||||||
assertTrue(!rr.hasNext());
|
rr.next();
|
||||||
rr.next();
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
|
@DisplayName("Test Csv Skip All But One Line")
|
||||||
|
void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
|
||||||
final int numLines = 4;
|
final int numLines = 4;
|
||||||
final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)),
|
final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)), new Text("one"), new Text("two"), new Text("three"));
|
||||||
new Text("one"), new Text("two"), new Text("three"));
|
|
||||||
String header = ",one,two,three";
|
String header = ",one,two,three";
|
||||||
List<String> lines = new ArrayList<>();
|
List<String> lines = new ArrayList<>();
|
||||||
for (int i = 0; i < numLines; i++)
|
for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
|
||||||
lines.add(Integer.toString(i) + header);
|
|
||||||
File tempFile = File.createTempFile("csvSkipLines", ".csv");
|
File tempFile = File.createTempFile("csvSkipLines", ".csv");
|
||||||
FileUtils.writeLines(tempFile, lines);
|
FileUtils.writeLines(tempFile, lines);
|
||||||
|
|
||||||
CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
|
CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
|
||||||
rr.initialize(new FileSplit(tempFile));
|
rr.initialize(new FileSplit(tempFile));
|
||||||
rr.reset();
|
rr.reset();
|
||||||
|
@ -297,50 +286,45 @@ public class CSVRecordReaderTest extends BaseND4JTest {
|
||||||
assertEquals(rr.next(), lineList);
|
assertEquals(rr.next(), lineList);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStreamReset() throws Exception {
|
@DisplayName("Test Stream Reset")
|
||||||
|
void testStreamReset() throws Exception {
|
||||||
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
||||||
rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));
|
rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while(rr.hasNext()){
|
while (rr.hasNext()) {
|
||||||
assertNotNull(rr.next());
|
assertNotNull(rr.next());
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
assertEquals(150, count);
|
assertEquals(150, count);
|
||||||
|
|
||||||
assertFalse(rr.resetSupported());
|
assertFalse(rr.resetSupported());
|
||||||
|
try {
|
||||||
try{
|
|
||||||
rr.reset();
|
rr.reset();
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (Exception e){
|
} catch (Exception e) {
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
String msg2 = e.getCause().getMessage();
|
String msg2 = e.getCause().getMessage();
|
||||||
assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
|
assertTrue(msg.contains("Error during LineRecordReader reset"),msg);
|
||||||
assertTrue(msg2, msg2.contains("Reset not supported from streams"));
|
assertTrue(msg2.contains("Reset not supported from streams"),msg2);
|
||||||
// e.printStackTrace();
|
// e.printStackTrace();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUsefulExceptionNoInit(){
|
@DisplayName("Test Useful Exception No Init")
|
||||||
|
void testUsefulExceptionNoInit() {
|
||||||
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
CSVRecordReader rr = new CSVRecordReader(0, ',');
|
||||||
|
try {
|
||||||
try{
|
|
||||||
rr.hasNext();
|
rr.hasNext();
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (Exception e){
|
} catch (Exception e) {
|
||||||
assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
|
assertTrue( e.getMessage().contains("initialized"),e.getMessage());
|
||||||
}
|
}
|
||||||
|
try {
|
||||||
try{
|
|
||||||
rr.next();
|
rr.next();
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (Exception e){
|
} catch (Exception e) {
|
||||||
assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
|
assertTrue(e.getMessage().contains("initialized"),e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.datavec.api.records.SequenceRecord;
|
import org.datavec.api.records.SequenceRecord;
|
||||||
|
@ -27,12 +26,11 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||||
import org.datavec.api.split.InputSplit;
|
import org.datavec.api.split.InputSplit;
|
||||||
import org.datavec.api.split.NumberedFileInputSplit;
|
import org.datavec.api.split.NumberedFileInputSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
@ -41,25 +39,27 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Csv Sequence Record Reader Test")
|
||||||
|
class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
@TempDir
|
||||||
|
public Path tempDir;
|
||||||
@Rule
|
|
||||||
public TemporaryFolder tempDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test() throws Exception {
|
@DisplayName("Test")
|
||||||
|
void test() throws Exception {
|
||||||
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
|
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
|
||||||
seqReader.initialize(new TestInputSplit());
|
seqReader.initialize(new TestInputSplit());
|
||||||
|
|
||||||
int sequenceCount = 0;
|
int sequenceCount = 0;
|
||||||
while (seqReader.hasNext()) {
|
while (seqReader.hasNext()) {
|
||||||
List<List<Writable>> sequence = seqReader.sequenceRecord();
|
List<List<Writable>> sequence = seqReader.sequenceRecord();
|
||||||
assertEquals(4, sequence.size()); //4 lines, plus 1 header line
|
// 4 lines, plus 1 header line
|
||||||
|
assertEquals(4, sequence.size());
|
||||||
Iterator<List<Writable>> timeStepIter = sequence.iterator();
|
Iterator<List<Writable>> timeStepIter = sequence.iterator();
|
||||||
int lineCount = 0;
|
int lineCount = 0;
|
||||||
while (timeStepIter.hasNext()) {
|
while (timeStepIter.hasNext()) {
|
||||||
|
@ -80,19 +80,18 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReset() throws Exception {
|
@DisplayName("Test Reset")
|
||||||
|
void testReset() throws Exception {
|
||||||
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
|
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
|
||||||
seqReader.initialize(new TestInputSplit());
|
seqReader.initialize(new TestInputSplit());
|
||||||
|
|
||||||
int nTests = 5;
|
int nTests = 5;
|
||||||
for (int i = 0; i < nTests; i++) {
|
for (int i = 0; i < nTests; i++) {
|
||||||
seqReader.reset();
|
seqReader.reset();
|
||||||
|
|
||||||
int sequenceCount = 0;
|
int sequenceCount = 0;
|
||||||
while (seqReader.hasNext()) {
|
while (seqReader.hasNext()) {
|
||||||
List<List<Writable>> sequence = seqReader.sequenceRecord();
|
List<List<Writable>> sequence = seqReader.sequenceRecord();
|
||||||
assertEquals(4, sequence.size()); //4 lines, plus 1 header line
|
// 4 lines, plus 1 header line
|
||||||
|
assertEquals(4, sequence.size());
|
||||||
Iterator<List<Writable>> timeStepIter = sequence.iterator();
|
Iterator<List<Writable>> timeStepIter = sequence.iterator();
|
||||||
int lineCount = 0;
|
int lineCount = 0;
|
||||||
while (timeStepIter.hasNext()) {
|
while (timeStepIter.hasNext()) {
|
||||||
|
@ -107,15 +106,15 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMetaData() throws Exception {
|
@DisplayName("Test Meta Data")
|
||||||
|
void testMetaData() throws Exception {
|
||||||
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
|
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
|
||||||
seqReader.initialize(new TestInputSplit());
|
seqReader.initialize(new TestInputSplit());
|
||||||
|
|
||||||
List<List<List<Writable>>> l = new ArrayList<>();
|
List<List<List<Writable>>> l = new ArrayList<>();
|
||||||
while (seqReader.hasNext()) {
|
while (seqReader.hasNext()) {
|
||||||
List<List<Writable>> sequence = seqReader.sequenceRecord();
|
List<List<Writable>> sequence = seqReader.sequenceRecord();
|
||||||
assertEquals(4, sequence.size()); //4 lines, plus 1 header line
|
// 4 lines, plus 1 header line
|
||||||
|
assertEquals(4, sequence.size());
|
||||||
Iterator<List<Writable>> timeStepIter = sequence.iterator();
|
Iterator<List<Writable>> timeStepIter = sequence.iterator();
|
||||||
int lineCount = 0;
|
int lineCount = 0;
|
||||||
while (timeStepIter.hasNext()) {
|
while (timeStepIter.hasNext()) {
|
||||||
|
@ -123,10 +122,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
lineCount++;
|
lineCount++;
|
||||||
}
|
}
|
||||||
assertEquals(4, lineCount);
|
assertEquals(4, lineCount);
|
||||||
|
|
||||||
l.add(sequence);
|
l.add(sequence);
|
||||||
}
|
}
|
||||||
|
|
||||||
List<SequenceRecord> l2 = new ArrayList<>();
|
List<SequenceRecord> l2 = new ArrayList<>();
|
||||||
List<RecordMetaData> meta = new ArrayList<>();
|
List<RecordMetaData> meta = new ArrayList<>();
|
||||||
seqReader.reset();
|
seqReader.reset();
|
||||||
|
@ -136,7 +133,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
meta.add(sr.getMetaData());
|
meta.add(sr.getMetaData());
|
||||||
}
|
}
|
||||||
assertEquals(3, l2.size());
|
assertEquals(3, l2.size());
|
||||||
|
|
||||||
List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta);
|
List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta);
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
assertEquals(l.get(i), l2.get(i).getSequenceRecord());
|
assertEquals(l.get(i), l2.get(i).getSequenceRecord());
|
||||||
|
@ -144,8 +140,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class
|
@DisplayName("Test Input Split")
|
||||||
TestInputSplit implements InputSplit {
|
private static class TestInputSplit implements InputSplit {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean canWriteToLocation(URI location) {
|
public boolean canWriteToLocation(URI location) {
|
||||||
|
@ -164,7 +160,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void updateSplitLocations(boolean reset) {
|
public void updateSplitLocations(boolean reset) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -174,7 +169,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void bootStrapForWrite() {
|
public void bootStrapForWrite() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -222,38 +216,30 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void reset() {
|
public void reset() {
|
||||||
//No op
|
// No op
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean resetSupported() {
|
public boolean resetSupported() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCsvSeqAndNumberedFileSplit() throws Exception {
|
@DisplayName("Test Csv Seq And Numbered File Split")
|
||||||
File baseDir = tempDir.newFolder();
|
void testCsvSeqAndNumberedFileSplit(@TempDir Path tempDir) throws Exception {
|
||||||
//Simple sanity check unit test
|
File baseDir = tempDir.toFile();
|
||||||
|
// Simple sanity check unit test
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
|
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
|
||||||
}
|
}
|
||||||
|
// Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
|
||||||
//Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
|
|
||||||
ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
|
ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
|
||||||
String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();
|
String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();
|
||||||
|
|
||||||
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
|
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
|
||||||
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
|
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
|
||||||
|
while (featureReader.hasNext()) {
|
||||||
while(featureReader.hasNext()){
|
|
||||||
featureReader.nextSequence();
|
featureReader.nextSequence();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.datavec.api.records.reader.SequenceRecordReader;
|
import org.datavec.api.records.reader.SequenceRecordReader;
|
||||||
|
@ -25,94 +24,87 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
|
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.util.LinkedList;
|
import java.util.LinkedList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Csv Variable Sliding Window Record Reader Test")
|
||||||
|
class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
|
||||||
public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCSVVariableSlidingWindowRecordReader() throws Exception {
|
@DisplayName("Test CSV Variable Sliding Window Record Reader")
|
||||||
|
void testCSVVariableSlidingWindowRecordReader() throws Exception {
|
||||||
int maxLinesPerSequence = 3;
|
int maxLinesPerSequence = 3;
|
||||||
|
|
||||||
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
|
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
|
||||||
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
CSVRecordReader rr = new CSVRecordReader();
|
CSVRecordReader rr = new CSVRecordReader();
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while (seqRR.hasNext()) {
|
while (seqRR.hasNext()) {
|
||||||
List<List<Writable>> next = seqRR.sequenceRecord();
|
List<List<Writable>> next = seqRR.sequenceRecord();
|
||||||
|
if (count == maxLinesPerSequence - 1) {
|
||||||
if(count==maxLinesPerSequence-1) {
|
|
||||||
LinkedList<List<Writable>> expected = new LinkedList<>();
|
LinkedList<List<Writable>> expected = new LinkedList<>();
|
||||||
for (int i = 0; i < maxLinesPerSequence; i++) {
|
for (int i = 0; i < maxLinesPerSequence; i++) {
|
||||||
expected.addFirst(rr.next());
|
expected.addFirst(rr.next());
|
||||||
}
|
}
|
||||||
assertEquals(expected, next);
|
assertEquals(expected, next);
|
||||||
|
|
||||||
}
|
}
|
||||||
if(count==maxLinesPerSequence) {
|
if (count == maxLinesPerSequence) {
|
||||||
assertEquals(maxLinesPerSequence, next.size());
|
assertEquals(maxLinesPerSequence, next.size());
|
||||||
}
|
}
|
||||||
if(count==0) { // first seq should be length 1
|
if (count == 0) {
|
||||||
|
// first seq should be length 1
|
||||||
assertEquals(1, next.size());
|
assertEquals(1, next.size());
|
||||||
}
|
}
|
||||||
if(count>151) { // last seq should be length 1
|
if (count > 151) {
|
||||||
|
// last seq should be length 1
|
||||||
assertEquals(1, next.size());
|
assertEquals(1, next.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(152, count);
|
assertEquals(152, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
|
@DisplayName("Test CSV Variable Sliding Window Record Reader Stride")
|
||||||
|
void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
|
||||||
int maxLinesPerSequence = 3;
|
int maxLinesPerSequence = 3;
|
||||||
int stride = 2;
|
int stride = 2;
|
||||||
|
|
||||||
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
|
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
|
||||||
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
CSVRecordReader rr = new CSVRecordReader();
|
CSVRecordReader rr = new CSVRecordReader();
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while (seqRR.hasNext()) {
|
while (seqRR.hasNext()) {
|
||||||
List<List<Writable>> next = seqRR.sequenceRecord();
|
List<List<Writable>> next = seqRR.sequenceRecord();
|
||||||
|
if (count == maxLinesPerSequence - 1) {
|
||||||
if(count==maxLinesPerSequence-1) {
|
|
||||||
LinkedList<List<Writable>> expected = new LinkedList<>();
|
LinkedList<List<Writable>> expected = new LinkedList<>();
|
||||||
for(int s = 0; s < stride; s++) {
|
for (int s = 0; s < stride; s++) {
|
||||||
expected = new LinkedList<>();
|
expected = new LinkedList<>();
|
||||||
for (int i = 0; i < maxLinesPerSequence; i++) {
|
for (int i = 0; i < maxLinesPerSequence; i++) {
|
||||||
expected.addFirst(rr.next());
|
expected.addFirst(rr.next());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assertEquals(expected, next);
|
assertEquals(expected, next);
|
||||||
|
|
||||||
}
|
}
|
||||||
if(count==maxLinesPerSequence) {
|
if (count == maxLinesPerSequence) {
|
||||||
assertEquals(maxLinesPerSequence, next.size());
|
assertEquals(maxLinesPerSequence, next.size());
|
||||||
}
|
}
|
||||||
if(count==0) { // first seq should be length 2
|
if (count == 0) {
|
||||||
|
// first seq should be length 2
|
||||||
assertEquals(2, next.size());
|
assertEquals(2, next.size());
|
||||||
}
|
}
|
||||||
if(count>151) { // last seq should be length 1
|
if (count > 151) {
|
||||||
|
// last seq should be length 1
|
||||||
assertEquals(1, next.size());
|
assertEquals(1, next.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(76, count);
|
assertEquals(76, count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -28,45 +27,44 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
|
import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
|
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.loader.FileBatch;
|
import org.nd4j.common.loader.FileBatch;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
@DisplayName("File Batch Record Reader Test")
|
||||||
|
public class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||||
public class FileBatchRecordReaderTest extends BaseND4JTest {
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCsv() throws Exception {
|
|
||||||
|
|
||||||
//This is an unrealistic use case - one line/record per CSV
|
|
||||||
File baseDir = testDir.newFolder();
|
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@DisplayName("Test Csv")
|
||||||
|
void testCsv(Nd4jBackend backend) throws Exception {
|
||||||
|
// This is an unrealistic use case - one line/record per CSV
|
||||||
|
File baseDir = testDir.toFile();
|
||||||
List<File> fileList = new ArrayList<>();
|
List<File> fileList = new ArrayList<>();
|
||||||
for( int i=0; i<10; i++ ){
|
for (int i = 0; i < 10; i++) {
|
||||||
String s = "file_" + i + "," + i + "," + i;
|
String s = "file_" + i + "," + i + "," + i;
|
||||||
File f = new File(baseDir, "origFile" + i + ".csv");
|
File f = new File(baseDir, "origFile" + i + ".csv");
|
||||||
FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8);
|
FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8);
|
||||||
fileList.add(f);
|
fileList.add(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
FileBatch fb = FileBatch.forFiles(fileList);
|
FileBatch fb = FileBatch.forFiles(fileList);
|
||||||
|
|
||||||
RecordReader rr = new CSVRecordReader();
|
RecordReader rr = new CSVRecordReader();
|
||||||
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
|
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
|
||||||
|
for (int test = 0; test < 3; test++) {
|
||||||
|
|
||||||
for( int test=0; test<3; test++) {
|
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
assertTrue(fbrr.hasNext());
|
assertTrue(fbrr.hasNext());
|
||||||
List<Writable> next = fbrr.next();
|
List<Writable> next = fbrr.next();
|
||||||
|
@ -82,16 +80,17 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
public void testCsvSequence() throws Exception {
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
//CSV sequence - 3 lines per file, 10 files
|
@DisplayName("Test Csv Sequence")
|
||||||
File baseDir = testDir.newFolder();
|
void testCsvSequence(Nd4jBackend backend) throws Exception {
|
||||||
|
// CSV sequence - 3 lines per file, 10 files
|
||||||
|
File baseDir = testDir.toFile();
|
||||||
List<File> fileList = new ArrayList<>();
|
List<File> fileList = new ArrayList<>();
|
||||||
for( int i=0; i<10; i++ ){
|
for (int i = 0; i < 10; i++) {
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
for( int j=0; j<3; j++ ){
|
for (int j = 0; j < 3; j++) {
|
||||||
if(j > 0)
|
if (j > 0)
|
||||||
sb.append("\n");
|
sb.append("\n");
|
||||||
sb.append("file_" + i + "," + i + "," + j);
|
sb.append("file_" + i + "," + i + "," + j);
|
||||||
}
|
}
|
||||||
|
@ -99,19 +98,16 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||||
FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8);
|
FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8);
|
||||||
fileList.add(f);
|
fileList.add(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
FileBatch fb = FileBatch.forFiles(fileList);
|
FileBatch fb = FileBatch.forFiles(fileList);
|
||||||
SequenceRecordReader rr = new CSVSequenceRecordReader();
|
SequenceRecordReader rr = new CSVSequenceRecordReader();
|
||||||
FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb);
|
FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb);
|
||||||
|
for (int test = 0; test < 3; test++) {
|
||||||
|
|
||||||
for( int test=0; test<3; test++) {
|
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
assertTrue(fbrr.hasNext());
|
assertTrue(fbrr.hasNext());
|
||||||
List<List<Writable>> next = fbrr.sequenceRecord();
|
List<List<Writable>> next = fbrr.sequenceRecord();
|
||||||
assertEquals(3, next.size());
|
assertEquals(3, next.size());
|
||||||
int count = 0;
|
int count = 0;
|
||||||
for(List<Writable> step : next ){
|
for (List<Writable> step : next) {
|
||||||
String s1 = "file_" + i;
|
String s1 = "file_" + i;
|
||||||
assertEquals(s1, step.get(0).toString());
|
assertEquals(s1, step.get(0).toString());
|
||||||
assertEquals(String.valueOf(i), step.get(1).toString());
|
assertEquals(String.valueOf(i), step.get(1).toString());
|
||||||
|
@ -123,5 +119,4 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||||
fbrr.reset();
|
fbrr.reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.datavec.api.records.Record;
|
import org.datavec.api.records.Record;
|
||||||
|
@ -26,28 +25,28 @@ import org.datavec.api.split.CollectionInputSplit;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.split.InputSplit;
|
import org.datavec.api.split.InputSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("File Record Reader Test")
|
||||||
import static org.junit.Assert.assertFalse;
|
class FileRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class FileRecordReaderTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReset() throws Exception {
|
@DisplayName("Test Reset")
|
||||||
|
void testReset() throws Exception {
|
||||||
FileRecordReader rr = new FileRecordReader();
|
FileRecordReader rr = new FileRecordReader();
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
|
||||||
|
|
||||||
int nResets = 5;
|
int nResets = 5;
|
||||||
for (int i = 0; i < nResets; i++) {
|
for (int i = 0; i < nResets; i++) {
|
||||||
|
|
||||||
int lineCount = 0;
|
int lineCount = 0;
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
List<Writable> line = rr.next();
|
List<Writable> line = rr.next();
|
||||||
|
@ -61,25 +60,20 @@ public class FileRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMeta() throws Exception {
|
@DisplayName("Test Meta")
|
||||||
|
void testMeta() throws Exception {
|
||||||
FileRecordReader rr = new FileRecordReader();
|
FileRecordReader rr = new FileRecordReader();
|
||||||
|
|
||||||
|
|
||||||
URI[] arr = new URI[3];
|
URI[] arr = new URI[3];
|
||||||
arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI();
|
arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI();
|
||||||
arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI();
|
arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI();
|
||||||
arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI();
|
arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI();
|
||||||
|
|
||||||
InputSplit is = new CollectionInputSplit(Arrays.asList(arr));
|
InputSplit is = new CollectionInputSplit(Arrays.asList(arr));
|
||||||
rr.initialize(is);
|
rr.initialize(is);
|
||||||
|
|
||||||
List<List<Writable>> out = new ArrayList<>();
|
List<List<Writable>> out = new ArrayList<>();
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
out.add(rr.next());
|
out.add(rr.next());
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(3, out.size());
|
assertEquals(3, out.size());
|
||||||
|
|
||||||
rr.reset();
|
rr.reset();
|
||||||
List<List<Writable>> out2 = new ArrayList<>();
|
List<List<Writable>> out2 = new ArrayList<>();
|
||||||
List<Record> out3 = new ArrayList<>();
|
List<Record> out3 = new ArrayList<>();
|
||||||
|
@ -90,13 +84,10 @@ public class FileRecordReaderTest extends BaseND4JTest {
|
||||||
out2.add(r.getRecord());
|
out2.add(r.getRecord());
|
||||||
out3.add(r);
|
out3.add(r);
|
||||||
meta.add(r.getMetaData());
|
meta.add(r.getMetaData());
|
||||||
|
|
||||||
assertEquals(arr[count++], r.getMetaData().getURI());
|
assertEquals(arr[count++], r.getMetaData().getURI());
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(out, out2);
|
assertEquals(out, out2);
|
||||||
List<Record> fromMeta = rr.loadFromMetaData(meta);
|
List<Record> fromMeta = rr.loadFromMetaData(meta);
|
||||||
assertEquals(out3, fromMeta);
|
assertEquals(out3, fromMeta);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
import org.datavec.api.records.reader.RecordReader;
|
||||||
|
@ -28,97 +27,81 @@ import org.datavec.api.split.CollectionInputSplit;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Jackson Line Record Reader Test")
|
||||||
|
class JacksonLineRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class JacksonLineRecordReaderTest extends BaseND4JTest {
|
@TempDir
|
||||||
|
public Path testDir;
|
||||||
|
|
||||||
@Rule
|
public JacksonLineRecordReaderTest() {
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
}
|
||||||
|
|
||||||
public JacksonLineRecordReaderTest() {
|
|
||||||
}
|
|
||||||
|
|
||||||
private static FieldSelection getFieldSelection() {
|
private static FieldSelection getFieldSelection() {
|
||||||
return new FieldSelection.Builder().addField("value1").
|
return new FieldSelection.Builder().addField("value1").addField("value2").addField("value3").addField("value4").addField("value5").addField("value6").addField("value7").addField("value8").addField("value9").addField("value10").build();
|
||||||
addField("value2").
|
|
||||||
addField("value3").
|
|
||||||
addField("value4").
|
|
||||||
addField("value5").
|
|
||||||
addField("value6").
|
|
||||||
addField("value7").
|
|
||||||
addField("value8").
|
|
||||||
addField("value9").
|
|
||||||
addField("value10").build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReadJSON() throws Exception {
|
@DisplayName("Test Read JSON")
|
||||||
|
void testReadJSON() throws Exception {
|
||||||
RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
|
RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile()));
|
||||||
|
|
||||||
testJacksonRecordReader(rr);
|
testJacksonRecordReader(rr);
|
||||||
}
|
|
||||||
|
|
||||||
private static void testJacksonRecordReader(RecordReader rr) {
|
|
||||||
while (rr.hasNext()) {
|
|
||||||
List<Writable> json0 = rr.next();
|
|
||||||
//System.out.println(json0);
|
|
||||||
assert(json0.size() > 0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void testJacksonRecordReader(RecordReader rr) {
|
||||||
|
while (rr.hasNext()) {
|
||||||
|
List<Writable> json0 = rr.next();
|
||||||
|
// System.out.println(json0);
|
||||||
|
assert (json0.size() > 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testJacksonLineSequenceRecordReader() throws Exception {
|
@DisplayName("Test Jackson Line Sequence Record Reader")
|
||||||
File dir = testDir.newFolder();
|
void testJacksonLineSequenceRecordReader(@TempDir Path testDir) throws Exception {
|
||||||
new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
|
File dir = testDir.toFile();
|
||||||
|
new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
|
||||||
FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
|
FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
|
||||||
.addField(new Text("MISSING_CX"), "c", "x").build();
|
JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
|
||||||
|
File[] files = dir.listFiles();
|
||||||
JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
|
Arrays.sort(files);
|
||||||
File[] files = dir.listFiles();
|
URI[] u = new URI[files.length];
|
||||||
Arrays.sort(files);
|
for (int i = 0; i < files.length; i++) {
|
||||||
URI[] u = new URI[files.length];
|
u[i] = files[i].toURI();
|
||||||
for( int i=0; i<files.length; i++ ){
|
}
|
||||||
u[i] = files[i].toURI();
|
rr.initialize(new CollectionInputSplit(u));
|
||||||
}
|
List<List<Writable>> expSeq0 = new ArrayList<>();
|
||||||
rr.initialize(new CollectionInputSplit(u));
|
expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
|
||||||
|
expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
|
||||||
List<List<Writable>> expSeq0 = new ArrayList<>();
|
expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
|
||||||
expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
|
List<List<Writable>> expSeq1 = new ArrayList<>();
|
||||||
expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
|
expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
|
||||||
expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
|
int count = 0;
|
||||||
|
while (rr.hasNext()) {
|
||||||
List<List<Writable>> expSeq1 = new ArrayList<>();
|
List<List<Writable>> next = rr.sequenceRecord();
|
||||||
expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
|
if (count++ == 0) {
|
||||||
|
assertEquals(expSeq0, next);
|
||||||
|
} else {
|
||||||
int count = 0;
|
assertEquals(expSeq1, next);
|
||||||
while(rr.hasNext()){
|
}
|
||||||
List<List<Writable>> next = rr.sequenceRecord();
|
}
|
||||||
if(count++ == 0){
|
assertEquals(2, count);
|
||||||
assertEquals(expSeq0, next);
|
}
|
||||||
} else {
|
|
||||||
assertEquals(expSeq1, next);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assertEquals(2, count);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.datavec.api.io.labels.PathLabelGenerator;
|
import org.datavec.api.io.labels.PathLabelGenerator;
|
||||||
|
@ -31,114 +30,95 @@ import org.datavec.api.split.NumberedFileInputSplit;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
|
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
|
||||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Jackson Record Reader Test")
|
||||||
import static org.junit.Assert.assertFalse;
|
class JacksonRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class JacksonRecordReaderTest extends BaseND4JTest {
|
@TempDir
|
||||||
|
public Path testDir;
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReadingJson() throws Exception {
|
@DisplayName("Test Reading Json")
|
||||||
//Load 3 values from 3 JSON files
|
void testReadingJson(@TempDir Path testDir) throws Exception {
|
||||||
//stricture: a:value, b:value, c:x:value, c:y:value
|
// Load 3 values from 3 JSON files
|
||||||
//And we want to load only a:value, b:value and c:x:value
|
// stricture: a:value, b:value, c:x:value, c:y:value
|
||||||
//For first JSON file: all values are present
|
// And we want to load only a:value, b:value and c:x:value
|
||||||
//For second JSON file: b:value is missing
|
// For first JSON file: all values are present
|
||||||
//For third JSON file: c:x:value is missing
|
// For second JSON file: b:value is missing
|
||||||
|
// For third JSON file: c:x:value is missing
|
||||||
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
|
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
cpr.copyDirectory(f);
|
cpr.copyDirectory(f);
|
||||||
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
|
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
|
||||||
|
|
||||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||||
|
|
||||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
|
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
|
||||||
rr.initialize(is);
|
rr.initialize(is);
|
||||||
|
|
||||||
testJacksonRecordReader(rr);
|
testJacksonRecordReader(rr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReadingYaml() throws Exception {
|
@DisplayName("Test Reading Yaml")
|
||||||
//Exact same information as JSON format, but in YAML format
|
void testReadingYaml(@TempDir Path testDir) throws Exception {
|
||||||
|
// Exact same information as JSON format, but in YAML format
|
||||||
ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/");
|
ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/");
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
cpr.copyDirectory(f);
|
cpr.copyDirectory(f);
|
||||||
String path = new File(f, "yaml_test_%d.txt").getAbsolutePath();
|
String path = new File(f, "yaml_test_%d.txt").getAbsolutePath();
|
||||||
|
|
||||||
|
|
||||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||||
|
|
||||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory()));
|
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory()));
|
||||||
rr.initialize(is);
|
rr.initialize(is);
|
||||||
|
|
||||||
testJacksonRecordReader(rr);
|
testJacksonRecordReader(rr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReadingXml() throws Exception {
|
@DisplayName("Test Reading Xml")
|
||||||
//Exact same information as JSON format, but in XML format
|
void testReadingXml(@TempDir Path testDir) throws Exception {
|
||||||
|
// Exact same information as JSON format, but in XML format
|
||||||
ClassPathResource cpr = new ClassPathResource("datavec-api/xml/");
|
ClassPathResource cpr = new ClassPathResource("datavec-api/xml/");
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
cpr.copyDirectory(f);
|
cpr.copyDirectory(f);
|
||||||
String path = new File(f, "xml_test_%d.txt").getAbsolutePath();
|
String path = new File(f, "xml_test_%d.txt").getAbsolutePath();
|
||||||
|
|
||||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||||
|
|
||||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory()));
|
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory()));
|
||||||
rr.initialize(is);
|
rr.initialize(is);
|
||||||
|
|
||||||
testJacksonRecordReader(rr);
|
testJacksonRecordReader(rr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private static FieldSelection getFieldSelection() {
|
private static FieldSelection getFieldSelection() {
|
||||||
return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
|
return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
|
||||||
.addField(new Text("MISSING_CX"), "c", "x").build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
private static void testJacksonRecordReader(RecordReader rr) {
|
private static void testJacksonRecordReader(RecordReader rr) {
|
||||||
|
|
||||||
List<Writable> json0 = rr.next();
|
List<Writable> json0 = rr.next();
|
||||||
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
|
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
|
||||||
assertEquals(exp0, json0);
|
assertEquals(exp0, json0);
|
||||||
|
|
||||||
List<Writable> json1 = rr.next();
|
List<Writable> json1 = rr.next();
|
||||||
List<Writable> exp1 =
|
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
|
||||||
Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
|
|
||||||
assertEquals(exp1, json1);
|
assertEquals(exp1, json1);
|
||||||
|
|
||||||
List<Writable> json2 = rr.next();
|
List<Writable> json2 = rr.next();
|
||||||
List<Writable> exp2 =
|
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
|
||||||
Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
|
|
||||||
assertEquals(exp2, json2);
|
assertEquals(exp2, json2);
|
||||||
|
|
||||||
assertFalse(rr.hasNext());
|
assertFalse(rr.hasNext());
|
||||||
|
// Test reset
|
||||||
//Test reset
|
|
||||||
rr.reset();
|
rr.reset();
|
||||||
assertEquals(exp0, rr.next());
|
assertEquals(exp0, rr.next());
|
||||||
assertEquals(exp1, rr.next());
|
assertEquals(exp1, rr.next());
|
||||||
|
@ -147,72 +127,50 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAppendingLabels() throws Exception {
|
@DisplayName("Test Appending Labels")
|
||||||
|
void testAppendingLabels(@TempDir Path testDir) throws Exception {
|
||||||
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
|
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
cpr.copyDirectory(f);
|
cpr.copyDirectory(f);
|
||||||
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
|
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
|
||||||
|
|
||||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||||
|
// Insert at the end:
|
||||||
//Insert at the end:
|
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
|
||||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
|
|
||||||
new LabelGen());
|
|
||||||
rr.initialize(is);
|
rr.initialize(is);
|
||||||
|
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0));
|
||||||
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"),
|
|
||||||
new IntWritable(0));
|
|
||||||
assertEquals(exp0, rr.next());
|
assertEquals(exp0, rr.next());
|
||||||
|
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1));
|
||||||
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"),
|
|
||||||
new IntWritable(1));
|
|
||||||
assertEquals(exp1, rr.next());
|
assertEquals(exp1, rr.next());
|
||||||
|
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2));
|
||||||
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"),
|
|
||||||
new IntWritable(2));
|
|
||||||
assertEquals(exp2, rr.next());
|
assertEquals(exp2, rr.next());
|
||||||
|
// Insert at position 0:
|
||||||
//Insert at position 0:
|
rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0);
|
||||||
rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
|
|
||||||
new LabelGen(), 0);
|
|
||||||
rr.initialize(is);
|
rr.initialize(is);
|
||||||
|
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
|
||||||
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"),
|
|
||||||
new Text("cxValue0"));
|
|
||||||
assertEquals(exp0, rr.next());
|
assertEquals(exp0, rr.next());
|
||||||
|
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
|
||||||
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"),
|
|
||||||
new Text("cxValue1"));
|
|
||||||
assertEquals(exp1, rr.next());
|
assertEquals(exp1, rr.next());
|
||||||
|
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
|
||||||
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"),
|
|
||||||
new Text("MISSING_CX"));
|
|
||||||
assertEquals(exp2, rr.next());
|
assertEquals(exp2, rr.next());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAppendingLabelsMetaData() throws Exception {
|
@DisplayName("Test Appending Labels Meta Data")
|
||||||
|
void testAppendingLabelsMetaData(@TempDir Path testDir) throws Exception {
|
||||||
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
|
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
cpr.copyDirectory(f);
|
cpr.copyDirectory(f);
|
||||||
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
|
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
|
||||||
|
|
||||||
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
|
||||||
|
// Insert at the end:
|
||||||
//Insert at the end:
|
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
|
||||||
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
|
|
||||||
new LabelGen());
|
|
||||||
rr.initialize(is);
|
rr.initialize(is);
|
||||||
|
|
||||||
List<List<Writable>> out = new ArrayList<>();
|
List<List<Writable>> out = new ArrayList<>();
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
out.add(rr.next());
|
out.add(rr.next());
|
||||||
}
|
}
|
||||||
assertEquals(3, out.size());
|
assertEquals(3, out.size());
|
||||||
|
|
||||||
rr.reset();
|
rr.reset();
|
||||||
|
|
||||||
List<List<Writable>> out2 = new ArrayList<>();
|
List<List<Writable>> out2 = new ArrayList<>();
|
||||||
List<Record> outRecord = new ArrayList<>();
|
List<Record> outRecord = new ArrayList<>();
|
||||||
List<RecordMetaData> meta = new ArrayList<>();
|
List<RecordMetaData> meta = new ArrayList<>();
|
||||||
|
@ -222,14 +180,12 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
|
||||||
outRecord.add(r);
|
outRecord.add(r);
|
||||||
meta.add(r.getMetaData());
|
meta.add(r.getMetaData());
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(out, out2);
|
assertEquals(out, out2);
|
||||||
|
|
||||||
List<Record> fromMeta = rr.loadFromMetaData(meta);
|
List<Record> fromMeta = rr.loadFromMetaData(meta);
|
||||||
assertEquals(outRecord, fromMeta);
|
assertEquals(outRecord, fromMeta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@DisplayName("Label Gen")
|
||||||
private static class LabelGen implements PathLabelGenerator {
|
private static class LabelGen implements PathLabelGenerator {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -252,5 +208,4 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.datavec.api.conf.Configuration;
|
import org.datavec.api.conf.Configuration;
|
||||||
|
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
|
import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
|
||||||
public class LibSvmRecordReaderTest extends BaseND4JTest {
|
@DisplayName("Lib Svm Record Reader Test")
|
||||||
|
class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicRecord() throws IOException, InterruptedException {
|
@DisplayName("Test Basic Record")
|
||||||
|
void testBasicRecord() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 7 2:1 4:2 6:3 8:4 10:5
|
// 7 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5),
|
|
||||||
new IntWritable(7)));
|
|
||||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO,
|
|
||||||
new IntWritable(2)));
|
|
||||||
// 33
|
// 33
|
||||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
new IntWritable(33)));
|
|
||||||
|
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -80,27 +66,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNoAppendLabel() throws IOException, InterruptedException {
|
@DisplayName("Test No Append Label")
|
||||||
|
void testNoAppendLabel() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 7 2:1 4:2 6:3 8:4 10:5
|
// 7 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5)));
|
|
||||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO));
|
|
||||||
// 33
|
// 33
|
||||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO));
|
|
||||||
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -117,33 +91,17 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNoLabel() throws IOException, InterruptedException {
|
@DisplayName("Test No Label")
|
||||||
|
void testNoLabel() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 2:1 4:2 6:3 8:4 10:5
|
// 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
|
||||||
ZERO, new DoubleWritable(2),
|
// qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
ZERO, new DoubleWritable(3),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
|
||||||
ZERO, new DoubleWritable(4),
|
// 1:1.0
|
||||||
ZERO, new DoubleWritable(5)));
|
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||||
// qid:42 1:0.1 2:2 6:6.6 8:80
|
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO));
|
|
||||||
// 1:1.0
|
|
||||||
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO));
|
|
||||||
//
|
//
|
||||||
correct.put(3, Arrays.asList(ZERO, ZERO,
|
correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO));
|
|
||||||
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -160,33 +118,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultioutputRecord() throws IOException, InterruptedException {
|
@DisplayName("Test Multioutput Record")
|
||||||
|
void testMultioutputRecord() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
|
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5),
|
|
||||||
new IntWritable(7), new DoubleWritable(2.45),
|
|
||||||
new IntWritable(9)));
|
|
||||||
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO,
|
|
||||||
new IntWritable(2), new IntWritable(3),
|
|
||||||
new IntWritable(4)));
|
|
||||||
// 33,32.0,31.9
|
// 33,32.0,31.9
|
||||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
new IntWritable(33), new DoubleWritable(32.0),
|
|
||||||
new DoubleWritable(31.9)));
|
|
||||||
|
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -202,51 +142,20 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||||
assertEquals(i, correct.size());
|
assertEquals(i, correct.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultilabelRecord() throws IOException, InterruptedException {
|
@DisplayName("Test Multilabel Record")
|
||||||
|
void testMultilabelRecord() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 1,3 2:1 4:2 6:3 8:4 10:5
|
// 1,3 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5),
|
|
||||||
LABEL_ONE, LABEL_ZERO,
|
|
||||||
LABEL_ONE, LABEL_ZERO));
|
|
||||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ONE,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
// 1,2,4
|
// 1,2,4
|
||||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
|
||||||
ZERO, ZERO,
|
// 1:1.0
|
||||||
ZERO, ZERO,
|
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ONE, LABEL_ONE,
|
|
||||||
LABEL_ZERO, LABEL_ONE));
|
|
||||||
// 1:1.0
|
|
||||||
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
//
|
//
|
||||||
correct.put(4, Arrays.asList(ZERO, ZERO,
|
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
|
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -265,63 +174,24 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testZeroBasedIndexing() throws IOException, InterruptedException {
|
@DisplayName("Test Zero Based Indexing")
|
||||||
|
void testZeroBasedIndexing() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 1,3 2:1 4:2 6:3 8:4 10:5
|
// 1,3 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO,
|
correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
|
||||||
ZERO, ONE,
|
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5),
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ONE, LABEL_ZERO,
|
|
||||||
LABEL_ONE, LABEL_ZERO));
|
|
||||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(ZERO,
|
correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
|
||||||
new DoubleWritable(0.1), new DoubleWritable(2),
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ONE,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
// 1,2,4
|
// 1,2,4
|
||||||
correct.put(2, Arrays.asList(ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
|
||||||
ZERO, ZERO,
|
// 1:1.0
|
||||||
ZERO, ZERO,
|
correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ONE, LABEL_ONE,
|
|
||||||
LABEL_ZERO, LABEL_ONE));
|
|
||||||
// 1:1.0
|
|
||||||
correct.put(3, Arrays.asList(ZERO,
|
|
||||||
new DoubleWritable(1.0), ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
//
|
//
|
||||||
correct.put(4, Arrays.asList(ZERO,
|
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
|
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
// Zero-based indexing is default
|
// Zero-based indexing is default
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
|
// NOT STANDARD!
|
||||||
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
|
||||||
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
||||||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
|
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
|
||||||
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
||||||
|
@ -336,87 +206,107 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||||
assertEquals(i, correct.size());
|
assertEquals(i, correct.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = NoSuchElementException.class)
|
@Test
|
||||||
public void testNoSuchElementException() throws Exception {
|
@DisplayName("Test No Such Element Exception")
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
void testNoSuchElementException() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(NoSuchElementException.class, () -> {
|
||||||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
Configuration config = new Configuration();
|
||||||
while (rr.hasNext())
|
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||||
|
while (rr.hasNext()) rr.next();
|
||||||
rr.next();
|
rr.next();
|
||||||
rr.next();
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = UnsupportedOperationException.class)
|
@Test
|
||||||
public void failedToSetNumFeaturesException() throws Exception {
|
@DisplayName("Failed To Set Num Features Exception")
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
void failedToSetNumFeaturesException() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(UnsupportedOperationException.class, () -> {
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
while (rr.hasNext())
|
Configuration config = new Configuration();
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||||
|
while (rr.hasNext()) rr.next();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("Test Inconsistent Num Labels Exception")
|
||||||
|
void testInconsistentNumLabelsException() {
|
||||||
|
assertThrows(UnsupportedOperationException.class, () -> {
|
||||||
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
|
Configuration config = new Configuration();
|
||||||
|
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
|
||||||
|
while (rr.hasNext()) rr.next();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("Test Inconsistent Num Multiabels Exception")
|
||||||
|
void testInconsistentNumMultiabelsException() {
|
||||||
|
assertThrows(UnsupportedOperationException.class, () -> {
|
||||||
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
|
Configuration config = new Configuration();
|
||||||
|
config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
|
||||||
|
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
|
||||||
|
while (rr.hasNext()) rr.next();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("Test Feature Index Exceeds Num Features")
|
||||||
|
void testFeatureIndexExceedsNumFeatures() {
|
||||||
|
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||||
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
|
Configuration config = new Configuration();
|
||||||
|
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||||
rr.next();
|
rr.next();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = UnsupportedOperationException.class)
|
@Test
|
||||||
public void testInconsistentNumLabelsException() throws Exception {
|
@DisplayName("Test Label Index Exceeds Num Labels")
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
void testLabelIndexExceedsNumLabels() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
|
Configuration config = new Configuration();
|
||||||
while (rr.hasNext())
|
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
||||||
|
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||||
|
config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||||
rr.next();
|
rr.next();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = UnsupportedOperationException.class)
|
@Test
|
||||||
public void testInconsistentNumMultiabelsException() throws Exception {
|
@DisplayName("Test Zero Index Feature Without Using Zero Indexing")
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||||
config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
Configuration config = new Configuration();
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
|
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
while (rr.hasNext())
|
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
||||||
|
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
|
||||||
rr.next();
|
rr.next();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IndexOutOfBoundsException.class)
|
@Test
|
||||||
public void testFeatureIndexExceedsNumFeatures() throws Exception {
|
@DisplayName("Test Zero Index Label Without Using Zero Indexing")
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
void testZeroIndexLabelWithoutUsingZeroIndexing() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
Configuration config = new Configuration();
|
||||||
rr.next();
|
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
||||||
}
|
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||||
|
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
||||||
@Test(expected = IndexOutOfBoundsException.class)
|
config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
|
||||||
public void testLabelIndexExceedsNumLabels() throws Exception {
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
rr.next();
|
||||||
Configuration config = new Configuration();
|
});
|
||||||
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
|
||||||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
|
||||||
config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
|
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
|
||||||
rr.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(expected = IndexOutOfBoundsException.class)
|
|
||||||
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
|
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
|
||||||
Configuration config = new Configuration();
|
|
||||||
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
|
||||||
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
|
||||||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
|
|
||||||
rr.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(expected = IndexOutOfBoundsException.class)
|
|
||||||
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
|
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
|
||||||
Configuration config = new Configuration();
|
|
||||||
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
|
|
||||||
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
|
||||||
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
|
||||||
config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
|
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
|
|
||||||
rr.next();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -30,11 +29,10 @@ import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.split.InputSplit;
|
import org.datavec.api.split.InputSplit;
|
||||||
import org.datavec.api.split.InputStreamInputSplit;
|
import org.datavec.api.split.InputStreamInputSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.FileInputStream;
|
import java.io.FileInputStream;
|
||||||
import java.io.FileOutputStream;
|
import java.io.FileOutputStream;
|
||||||
|
@ -45,34 +43,31 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.zip.GZIPInputStream;
|
import java.util.zip.GZIPInputStream;
|
||||||
import java.util.zip.GZIPOutputStream;
|
import java.util.zip.GZIPOutputStream;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Line Reader Test")
|
||||||
|
class LineReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class LineReaderTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLineReader() throws Exception {
|
@DisplayName("Test Line Reader")
|
||||||
File tmpdir = testDir.newFolder();
|
void testLineReader(@TempDir Path tmpDir) throws Exception {
|
||||||
|
File tmpdir = tmpDir.toFile();
|
||||||
if (tmpdir.exists())
|
if (tmpdir.exists())
|
||||||
tmpdir.delete();
|
tmpdir.delete();
|
||||||
tmpdir.mkdir();
|
tmpdir.mkdir();
|
||||||
|
|
||||||
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
|
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
|
||||||
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
|
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
|
||||||
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
|
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
|
||||||
|
|
||||||
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
|
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
|
||||||
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
|
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
|
||||||
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
|
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
|
||||||
|
|
||||||
InputSplit split = new FileSplit(tmpdir);
|
InputSplit split = new FileSplit(tmpdir);
|
||||||
|
|
||||||
RecordReader reader = new LineRecordReader();
|
RecordReader reader = new LineRecordReader();
|
||||||
reader.initialize(split);
|
reader.initialize(split);
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
List<List<Writable>> list = new ArrayList<>();
|
List<List<Writable>> list = new ArrayList<>();
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
|
@ -81,34 +76,27 @@ public class LineReaderTest extends BaseND4JTest {
|
||||||
list.add(l);
|
list.add(l);
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(9, count);
|
assertEquals(9, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLineReaderMetaData() throws Exception {
|
@DisplayName("Test Line Reader Meta Data")
|
||||||
File tmpdir = testDir.newFolder();
|
void testLineReaderMetaData(@TempDir Path tmpDir) throws Exception {
|
||||||
|
File tmpdir = tmpDir.toFile();
|
||||||
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
|
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
|
||||||
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
|
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
|
||||||
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
|
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
|
||||||
|
|
||||||
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
|
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
|
||||||
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
|
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
|
||||||
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
|
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
|
||||||
|
|
||||||
InputSplit split = new FileSplit(tmpdir);
|
InputSplit split = new FileSplit(tmpdir);
|
||||||
|
|
||||||
RecordReader reader = new LineRecordReader();
|
RecordReader reader = new LineRecordReader();
|
||||||
reader.initialize(split);
|
reader.initialize(split);
|
||||||
|
|
||||||
List<List<Writable>> list = new ArrayList<>();
|
List<List<Writable>> list = new ArrayList<>();
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
list.add(reader.next());
|
list.add(reader.next());
|
||||||
}
|
}
|
||||||
assertEquals(9, list.size());
|
assertEquals(9, list.size());
|
||||||
|
|
||||||
|
|
||||||
List<List<Writable>> out2 = new ArrayList<>();
|
List<List<Writable>> out2 = new ArrayList<>();
|
||||||
List<Record> out3 = new ArrayList<>();
|
List<Record> out3 = new ArrayList<>();
|
||||||
List<RecordMetaData> meta = new ArrayList<>();
|
List<RecordMetaData> meta = new ArrayList<>();
|
||||||
|
@ -124,13 +112,10 @@ public class LineReaderTest extends BaseND4JTest {
|
||||||
assertEquals(uri, split.locations()[fileIdx]);
|
assertEquals(uri, split.locations()[fileIdx]);
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(list, out2);
|
assertEquals(list, out2);
|
||||||
|
|
||||||
List<Record> fromMeta = reader.loadFromMetaData(meta);
|
List<Record> fromMeta = reader.loadFromMetaData(meta);
|
||||||
assertEquals(out3, fromMeta);
|
assertEquals(out3, fromMeta);
|
||||||
|
// try: second line of second and third files only...
|
||||||
//try: second line of second and third files only...
|
|
||||||
List<RecordMetaData> subsetMeta = new ArrayList<>();
|
List<RecordMetaData> subsetMeta = new ArrayList<>();
|
||||||
subsetMeta.add(meta.get(4));
|
subsetMeta.add(meta.get(4));
|
||||||
subsetMeta.add(meta.get(7));
|
subsetMeta.add(meta.get(7));
|
||||||
|
@ -141,27 +126,22 @@ public class LineReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLineReaderWithInputStreamInputSplit() throws Exception {
|
@DisplayName("Test Line Reader With Input Stream Input Split")
|
||||||
File tmpdir = testDir.newFolder();
|
void testLineReaderWithInputStreamInputSplit(@TempDir Path testDir) throws Exception {
|
||||||
|
File tmpdir = testDir.toFile();
|
||||||
File tmp1 = new File(tmpdir, "tmp1.txt.gz");
|
File tmp1 = new File(tmpdir, "tmp1.txt.gz");
|
||||||
|
|
||||||
OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
|
OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
|
||||||
IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
|
IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
|
||||||
os.flush();
|
os.flush();
|
||||||
os.close();
|
os.close();
|
||||||
|
|
||||||
InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
|
InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
|
||||||
|
|
||||||
RecordReader reader = new LineRecordReader();
|
RecordReader reader = new LineRecordReader();
|
||||||
reader.initialize(split);
|
reader.initialize(split);
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
assertEquals(1, reader.next().size());
|
assertEquals(1, reader.next().size());
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(9, count);
|
assertEquals(9, count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.datavec.api.records.Record;
|
import org.datavec.api.records.Record;
|
||||||
|
@ -33,44 +32,41 @@ import org.datavec.api.split.InputSplit;
|
||||||
import org.datavec.api.split.NumberedFileInputSplit;
|
import org.datavec.api.split.NumberedFileInputSplit;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Regex Record Reader Test")
|
||||||
import static org.junit.Assert.assertFalse;
|
class RegexRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class RegexRecordReaderTest extends BaseND4JTest {
|
@TempDir
|
||||||
|
public Path testDir;
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRegexLineRecordReader() throws Exception {
|
@DisplayName("Test Regex Line Record Reader")
|
||||||
|
void testRegexLineRecordReader() throws Exception {
|
||||||
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
||||||
|
|
||||||
RecordReader rr = new RegexLineRecordReader(regex, 1);
|
RecordReader rr = new RegexLineRecordReader(regex, 1);
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
|
||||||
|
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"));
|
||||||
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"),
|
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"));
|
||||||
new Text("DEBUG"), new Text("First entry message!"));
|
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"));
|
||||||
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"),
|
|
||||||
new Text("INFO"), new Text("Second entry message!"));
|
|
||||||
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"),
|
|
||||||
new Text("WARN"), new Text("Third entry message!"));
|
|
||||||
assertEquals(exp0, rr.next());
|
assertEquals(exp0, rr.next());
|
||||||
assertEquals(exp1, rr.next());
|
assertEquals(exp1, rr.next());
|
||||||
assertEquals(exp2, rr.next());
|
assertEquals(exp2, rr.next());
|
||||||
assertFalse(rr.hasNext());
|
assertFalse(rr.hasNext());
|
||||||
|
// Test reset:
|
||||||
//Test reset:
|
|
||||||
rr.reset();
|
rr.reset();
|
||||||
assertEquals(exp0, rr.next());
|
assertEquals(exp0, rr.next());
|
||||||
assertEquals(exp1, rr.next());
|
assertEquals(exp1, rr.next());
|
||||||
|
@ -79,74 +75,57 @@ public class RegexRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRegexLineRecordReaderMeta() throws Exception {
|
@DisplayName("Test Regex Line Record Reader Meta")
|
||||||
|
void testRegexLineRecordReaderMeta() throws Exception {
|
||||||
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
||||||
|
|
||||||
RecordReader rr = new RegexLineRecordReader(regex, 1);
|
RecordReader rr = new RegexLineRecordReader(regex, 1);
|
||||||
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
|
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
|
||||||
|
|
||||||
List<List<Writable>> list = new ArrayList<>();
|
List<List<Writable>> list = new ArrayList<>();
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
list.add(rr.next());
|
list.add(rr.next());
|
||||||
}
|
}
|
||||||
assertEquals(3, list.size());
|
assertEquals(3, list.size());
|
||||||
|
|
||||||
List<Record> list2 = new ArrayList<>();
|
List<Record> list2 = new ArrayList<>();
|
||||||
List<List<Writable>> list3 = new ArrayList<>();
|
List<List<Writable>> list3 = new ArrayList<>();
|
||||||
List<RecordMetaData> meta = new ArrayList<>();
|
List<RecordMetaData> meta = new ArrayList<>();
|
||||||
rr.reset();
|
rr.reset();
|
||||||
int count = 1; //Start by skipping 1 line
|
// Start by skipping 1 line
|
||||||
|
int count = 1;
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
Record r = rr.nextRecord();
|
Record r = rr.nextRecord();
|
||||||
list2.add(r);
|
list2.add(r);
|
||||||
list3.add(r.getRecord());
|
list3.add(r.getRecord());
|
||||||
meta.add(r.getMetaData());
|
meta.add(r.getMetaData());
|
||||||
|
|
||||||
assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber());
|
assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber());
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Record> fromMeta = rr.loadFromMetaData(meta);
|
List<Record> fromMeta = rr.loadFromMetaData(meta);
|
||||||
|
|
||||||
assertEquals(list, list3);
|
assertEquals(list, list3);
|
||||||
assertEquals(list2, fromMeta);
|
assertEquals(list2, fromMeta);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRegexSequenceRecordReader() throws Exception {
|
@DisplayName("Test Regex Sequence Record Reader")
|
||||||
|
void testRegexSequenceRecordReader(@TempDir Path testDir) throws Exception {
|
||||||
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
||||||
|
|
||||||
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
|
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
cpr.copyDirectory(f);
|
cpr.copyDirectory(f);
|
||||||
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
|
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
|
||||||
|
|
||||||
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
|
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
|
||||||
|
|
||||||
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
|
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
|
||||||
rr.initialize(is);
|
rr.initialize(is);
|
||||||
|
|
||||||
List<List<Writable>> exp0 = new ArrayList<>();
|
List<List<Writable>> exp0 = new ArrayList<>();
|
||||||
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"),
|
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!")));
|
||||||
new Text("First entry message!")));
|
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!")));
|
||||||
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"),
|
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!")));
|
||||||
new Text("Second entry message!")));
|
|
||||||
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"),
|
|
||||||
new Text("Third entry message!")));
|
|
||||||
|
|
||||||
|
|
||||||
List<List<Writable>> exp1 = new ArrayList<>();
|
List<List<Writable>> exp1 = new ArrayList<>();
|
||||||
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"),
|
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!")));
|
||||||
new Text("First entry message!")));
|
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!")));
|
||||||
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"),
|
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!")));
|
||||||
new Text("Second entry message!")));
|
|
||||||
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"),
|
|
||||||
new Text("Third entry message!")));
|
|
||||||
|
|
||||||
assertEquals(exp0, rr.sequenceRecord());
|
assertEquals(exp0, rr.sequenceRecord());
|
||||||
assertEquals(exp1, rr.sequenceRecord());
|
assertEquals(exp1, rr.sequenceRecord());
|
||||||
assertFalse(rr.hasNext());
|
assertFalse(rr.hasNext());
|
||||||
|
// Test resetting:
|
||||||
//Test resetting:
|
|
||||||
rr.reset();
|
rr.reset();
|
||||||
assertEquals(exp0, rr.sequenceRecord());
|
assertEquals(exp0, rr.sequenceRecord());
|
||||||
assertEquals(exp1, rr.sequenceRecord());
|
assertEquals(exp1, rr.sequenceRecord());
|
||||||
|
@ -154,24 +133,20 @@ public class RegexRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRegexSequenceRecordReaderMeta() throws Exception {
|
@DisplayName("Test Regex Sequence Record Reader Meta")
|
||||||
|
void testRegexSequenceRecordReaderMeta(@TempDir Path testDir) throws Exception {
|
||||||
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
|
||||||
|
|
||||||
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
|
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
cpr.copyDirectory(f);
|
cpr.copyDirectory(f);
|
||||||
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
|
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
|
||||||
|
|
||||||
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
|
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
|
||||||
|
|
||||||
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
|
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
|
||||||
rr.initialize(is);
|
rr.initialize(is);
|
||||||
|
|
||||||
List<List<List<Writable>>> out = new ArrayList<>();
|
List<List<List<Writable>>> out = new ArrayList<>();
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
out.add(rr.sequenceRecord());
|
out.add(rr.sequenceRecord());
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(2, out.size());
|
assertEquals(2, out.size());
|
||||||
List<List<List<Writable>>> out2 = new ArrayList<>();
|
List<List<List<Writable>>> out2 = new ArrayList<>();
|
||||||
List<SequenceRecord> out3 = new ArrayList<>();
|
List<SequenceRecord> out3 = new ArrayList<>();
|
||||||
|
@ -183,11 +158,8 @@ public class RegexRecordReaderTest extends BaseND4JTest {
|
||||||
out3.add(seqr);
|
out3.add(seqr);
|
||||||
meta.add(seqr.getMetaData());
|
meta.add(seqr.getMetaData());
|
||||||
}
|
}
|
||||||
|
|
||||||
List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);
|
List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);
|
||||||
|
|
||||||
assertEquals(out, out2);
|
assertEquals(out, out2);
|
||||||
assertEquals(out3, fromMeta);
|
assertEquals(out3, fromMeta);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import org.datavec.api.conf.Configuration;
|
import org.datavec.api.conf.Configuration;
|
||||||
|
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
|
import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
|
||||||
public class SVMLightRecordReaderTest extends BaseND4JTest {
|
@DisplayName("Svm Light Record Reader Test")
|
||||||
|
class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicRecord() throws IOException, InterruptedException {
|
@DisplayName("Test Basic Record")
|
||||||
|
void testBasicRecord() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 7 2:1 4:2 6:3 8:4 10:5
|
// 7 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5),
|
|
||||||
new IntWritable(7)));
|
|
||||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO,
|
|
||||||
new IntWritable(2)));
|
|
||||||
// 33
|
// 33
|
||||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
new IntWritable(33)));
|
|
||||||
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -79,27 +65,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNoAppendLabel() throws IOException, InterruptedException {
|
@DisplayName("Test No Append Label")
|
||||||
|
void testNoAppendLabel() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 7 2:1 4:2 6:3 8:4 10:5
|
// 7 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5)));
|
|
||||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO));
|
|
||||||
// 33
|
// 33
|
||||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO));
|
|
||||||
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -116,33 +90,17 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNoLabel() throws IOException, InterruptedException {
|
@DisplayName("Test No Label")
|
||||||
|
void testNoLabel() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 2:1 4:2 6:3 8:4 10:5
|
// 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
|
||||||
ZERO, new DoubleWritable(2),
|
// qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
ZERO, new DoubleWritable(3),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
|
||||||
ZERO, new DoubleWritable(4),
|
// 1:1.0
|
||||||
ZERO, new DoubleWritable(5)));
|
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||||
// qid:42 1:0.1 2:2 6:6.6 8:80
|
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO));
|
|
||||||
// 1:1.0
|
|
||||||
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO));
|
|
||||||
//
|
//
|
||||||
correct.put(3, Arrays.asList(ZERO, ZERO,
|
correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO));
|
|
||||||
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -159,33 +117,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultioutputRecord() throws IOException, InterruptedException {
|
@DisplayName("Test Multioutput Record")
|
||||||
|
void testMultioutputRecord() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
|
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5),
|
|
||||||
new IntWritable(7), new DoubleWritable(2.45),
|
|
||||||
new IntWritable(9)));
|
|
||||||
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO,
|
|
||||||
new IntWritable(2), new IntWritable(3),
|
|
||||||
new IntWritable(4)));
|
|
||||||
// 33,32.0,31.9
|
// 33,32.0,31.9
|
||||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
new IntWritable(33), new DoubleWritable(32.0),
|
|
||||||
new DoubleWritable(31.9)));
|
|
||||||
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -200,51 +140,20 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||||
assertEquals(i, correct.size());
|
assertEquals(i, correct.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultilabelRecord() throws IOException, InterruptedException {
|
@DisplayName("Test Multilabel Record")
|
||||||
|
void testMultilabelRecord() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 1,3 2:1 4:2 6:3 8:4 10:5
|
// 1,3 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO, ONE,
|
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5),
|
|
||||||
LABEL_ONE, LABEL_ZERO,
|
|
||||||
LABEL_ONE, LABEL_ZERO));
|
|
||||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
|
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ONE,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
// 1,2,4
|
// 1,2,4
|
||||||
correct.put(2, Arrays.asList(ZERO, ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
|
||||||
ZERO, ZERO,
|
// 1:1.0
|
||||||
ZERO, ZERO,
|
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ONE, LABEL_ONE,
|
|
||||||
LABEL_ZERO, LABEL_ONE));
|
|
||||||
// 1:1.0
|
|
||||||
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
//
|
//
|
||||||
correct.put(4, Arrays.asList(ZERO, ZERO,
|
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
@ -262,63 +171,24 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testZeroBasedIndexing() throws IOException, InterruptedException {
|
@DisplayName("Test Zero Based Indexing")
|
||||||
|
void testZeroBasedIndexing() throws IOException, InterruptedException {
|
||||||
Map<Integer, List<Writable>> correct = new HashMap<>();
|
Map<Integer, List<Writable>> correct = new HashMap<>();
|
||||||
// 1,3 2:1 4:2 6:3 8:4 10:5
|
// 1,3 2:1 4:2 6:3 8:4 10:5
|
||||||
correct.put(0, Arrays.asList(ZERO,
|
correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
|
||||||
ZERO, ONE,
|
|
||||||
ZERO, new DoubleWritable(2),
|
|
||||||
ZERO, new DoubleWritable(3),
|
|
||||||
ZERO, new DoubleWritable(4),
|
|
||||||
ZERO, new DoubleWritable(5),
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ONE, LABEL_ZERO,
|
|
||||||
LABEL_ONE, LABEL_ZERO));
|
|
||||||
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
|
||||||
correct.put(1, Arrays.asList(ZERO,
|
correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
|
||||||
new DoubleWritable(0.1), new DoubleWritable(2),
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, new DoubleWritable(6.6),
|
|
||||||
ZERO, new DoubleWritable(80),
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ONE,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
// 1,2,4
|
// 1,2,4
|
||||||
correct.put(2, Arrays.asList(ZERO,
|
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
|
||||||
ZERO, ZERO,
|
// 1:1.0
|
||||||
ZERO, ZERO,
|
correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ONE, LABEL_ONE,
|
|
||||||
LABEL_ZERO, LABEL_ONE));
|
|
||||||
// 1:1.0
|
|
||||||
correct.put(3, Arrays.asList(ZERO,
|
|
||||||
new DoubleWritable(1.0), ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
//
|
//
|
||||||
correct.put(4, Arrays.asList(ZERO,
|
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
ZERO, ZERO,
|
|
||||||
LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO,
|
|
||||||
LABEL_ZERO, LABEL_ZERO));
|
|
||||||
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
// Zero-based indexing is default
|
// Zero-based indexing is default
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
|
// NOT STANDARD!
|
||||||
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
|
||||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
|
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
|
||||||
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
||||||
config.setInt(SVMLightRecordReader.NUM_LABELS, 5);
|
config.setInt(SVMLightRecordReader.NUM_LABELS, 5);
|
||||||
|
@ -333,20 +203,19 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNextRecord() throws IOException, InterruptedException {
|
@DisplayName("Test Next Record")
|
||||||
|
void testNextRecord() throws IOException, InterruptedException {
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
Configuration config = new Configuration();
|
Configuration config = new Configuration();
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||||
config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
|
config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||||
|
|
||||||
Record record = rr.nextRecord();
|
Record record = rr.nextRecord();
|
||||||
List<Writable> recordList = record.getRecord();
|
List<Writable> recordList = record.getRecord();
|
||||||
assertEquals(new DoubleWritable(1.0), recordList.get(1));
|
assertEquals(new DoubleWritable(1.0), recordList.get(1));
|
||||||
assertEquals(new DoubleWritable(3.0), recordList.get(5));
|
assertEquals(new DoubleWritable(3.0), recordList.get(5));
|
||||||
assertEquals(new DoubleWritable(4.0), recordList.get(7));
|
assertEquals(new DoubleWritable(4.0), recordList.get(7));
|
||||||
|
|
||||||
record = rr.nextRecord();
|
record = rr.nextRecord();
|
||||||
recordList = record.getRecord();
|
recordList = record.getRecord();
|
||||||
assertEquals(new DoubleWritable(0.1), recordList.get(0));
|
assertEquals(new DoubleWritable(0.1), recordList.get(0));
|
||||||
|
@ -354,82 +223,102 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||||
assertEquals(new DoubleWritable(80.0), recordList.get(7));
|
assertEquals(new DoubleWritable(80.0), recordList.get(7));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = NoSuchElementException.class)
|
@Test
|
||||||
public void testNoSuchElementException() throws Exception {
|
@DisplayName("Test No Such Element Exception")
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
void testNoSuchElementException() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(NoSuchElementException.class, () -> {
|
||||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
Configuration config = new Configuration();
|
||||||
while (rr.hasNext())
|
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||||
|
while (rr.hasNext()) rr.next();
|
||||||
rr.next();
|
rr.next();
|
||||||
rr.next();
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = UnsupportedOperationException.class)
|
@Test
|
||||||
public void failedToSetNumFeaturesException() throws Exception {
|
@DisplayName("Failed To Set Num Features Exception")
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
void failedToSetNumFeaturesException() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(UnsupportedOperationException.class, () -> {
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
while (rr.hasNext())
|
Configuration config = new Configuration();
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||||
|
while (rr.hasNext()) rr.next();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("Test Inconsistent Num Labels Exception")
|
||||||
|
void testInconsistentNumLabelsException() {
|
||||||
|
assertThrows(UnsupportedOperationException.class, () -> {
|
||||||
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
|
Configuration config = new Configuration();
|
||||||
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
|
||||||
|
while (rr.hasNext()) rr.next();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("Failed To Set Num Multiabels Exception")
|
||||||
|
void failedToSetNumMultiabelsException() {
|
||||||
|
assertThrows(UnsupportedOperationException.class, () -> {
|
||||||
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
|
Configuration config = new Configuration();
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
|
||||||
|
while (rr.hasNext()) rr.next();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("Test Feature Index Exceeds Num Features")
|
||||||
|
void testFeatureIndexExceedsNumFeatures() {
|
||||||
|
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||||
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
|
Configuration config = new Configuration();
|
||||||
|
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||||
rr.next();
|
rr.next();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = UnsupportedOperationException.class)
|
@Test
|
||||||
public void testInconsistentNumLabelsException() throws Exception {
|
@DisplayName("Test Label Index Exceeds Num Labels")
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
void testLabelIndexExceedsNumLabels() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
|
Configuration config = new Configuration();
|
||||||
while (rr.hasNext())
|
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||||
|
config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
||||||
rr.next();
|
rr.next();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = UnsupportedOperationException.class)
|
@Test
|
||||||
public void failedToSetNumMultiabelsException() throws Exception {
|
@DisplayName("Test Zero Index Feature Without Using Zero Indexing")
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
while (rr.hasNext())
|
Configuration config = new Configuration();
|
||||||
|
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||||
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
|
||||||
rr.next();
|
rr.next();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IndexOutOfBoundsException.class)
|
@Test
|
||||||
public void testFeatureIndexExceedsNumFeatures() throws Exception {
|
@DisplayName("Test Zero Index Label Without Using Zero Indexing")
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
void testZeroIndexLabelWithoutUsingZeroIndexing() {
|
||||||
Configuration config = new Configuration();
|
assertThrows(IndexOutOfBoundsException.class, () -> {
|
||||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
Configuration config = new Configuration();
|
||||||
rr.next();
|
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||||
}
|
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
||||||
|
config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
|
||||||
@Test(expected = IndexOutOfBoundsException.class)
|
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
|
||||||
public void testLabelIndexExceedsNumLabels() throws Exception {
|
rr.next();
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
});
|
||||||
Configuration config = new Configuration();
|
|
||||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
|
||||||
config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
|
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
|
|
||||||
rr.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(expected = IndexOutOfBoundsException.class)
|
|
||||||
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
|
||||||
Configuration config = new Configuration();
|
|
||||||
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
|
||||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
|
|
||||||
rr.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(expected = IndexOutOfBoundsException.class)
|
|
||||||
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
|
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
|
||||||
Configuration config = new Configuration();
|
|
||||||
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
|
||||||
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
|
||||||
config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
|
|
||||||
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
|
|
||||||
rr.next();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,14 +26,14 @@ import org.datavec.api.records.reader.SequenceRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
|
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestCollectionRecordReaders extends BaseND4JTest {
|
public class TestCollectionRecordReaders extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -23,11 +23,11 @@ package org.datavec.api.records.reader.impl;
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
import org.datavec.api.records.reader.RecordReader;
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestConcatenatingRecordReader extends BaseND4JTest {
|
public class TestConcatenatingRecordReader extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ import org.datavec.api.transform.TransformProcess;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||||
|
@ -47,7 +47,7 @@ import java.io.*;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestSerialization extends BaseND4JTest {
|
public class TestSerialization extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.LongWritable;
|
import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
|
@ -38,8 +38,8 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TransformProcessRecordReaderTests extends BaseND4JTest {
|
public class TransformProcessRecordReaderTests extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.writer.impl;
|
package org.datavec.api.records.writer.impl;
|
||||||
|
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
|
@ -26,44 +25,42 @@ import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Before;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Csv Record Writer Test")
|
||||||
|
class CSVRecordWriterTest extends BaseND4JTest {
|
||||||
public class CSVRecordWriterTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Before
|
|
||||||
public void setUp() throws Exception {
|
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws Exception {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWrite() throws Exception {
|
@DisplayName("Test Write")
|
||||||
|
void testWrite() throws Exception {
|
||||||
File tempFile = File.createTempFile("datavec", "writer");
|
File tempFile = File.createTempFile("datavec", "writer");
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
FileSplit fileSplit = new FileSplit(tempFile);
|
FileSplit fileSplit = new FileSplit(tempFile);
|
||||||
CSVRecordWriter writer = new CSVRecordWriter();
|
CSVRecordWriter writer = new CSVRecordWriter();
|
||||||
writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
|
||||||
List<Writable> collection = new ArrayList<>();
|
List<Writable> collection = new ArrayList<>();
|
||||||
collection.add(new Text("12"));
|
collection.add(new Text("12"));
|
||||||
collection.add(new Text("13"));
|
collection.add(new Text("13"));
|
||||||
collection.add(new Text("14"));
|
collection.add(new Text("14"));
|
||||||
|
|
||||||
writer.write(collection);
|
writer.write(collection);
|
||||||
|
|
||||||
CSVRecordReader reader = new CSVRecordReader(0);
|
CSVRecordReader reader = new CSVRecordReader(0);
|
||||||
reader.initialize(new FileSplit(tempFile));
|
reader.initialize(new FileSplit(tempFile));
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
List<Writable> line = new ArrayList<>(reader.next());
|
List<Writable> line = new ArrayList<>(reader.next());
|
||||||
assertEquals(3, line.size());
|
assertEquals(3, line.size());
|
||||||
|
|
||||||
assertEquals(12, line.get(0).toInt());
|
assertEquals(12, line.get(0).toInt());
|
||||||
assertEquals(13, line.get(1).toInt());
|
assertEquals(13, line.get(1).toInt());
|
||||||
assertEquals(14, line.get(2).toInt());
|
assertEquals(14, line.get(2).toInt());
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.writer.impl;
|
package org.datavec.api.records.writer.impl;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -30,93 +29,90 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.NDArrayWritable;
|
import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.regex.Matcher;
|
import java.util.regex.Matcher;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Lib Svm Record Writer Test")
|
||||||
|
class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||||
public class LibSvmRecordWriterTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasic() throws Exception {
|
@DisplayName("Test Basic")
|
||||||
|
void testBasic() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||||
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNoLabel() throws Exception {
|
@DisplayName("Test No Label")
|
||||||
|
void testNoLabel() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||||
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultioutputRecord() throws Exception {
|
@DisplayName("Test Multioutput Record")
|
||||||
|
void testMultioutputRecord() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||||
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultilabelRecord() throws Exception {
|
@DisplayName("Test Multilabel Record")
|
||||||
|
void testMultilabelRecord() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
|
||||||
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
||||||
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4);
|
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4);
|
||||||
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testZeroBasedIndexing() throws Exception {
|
@DisplayName("Test Zero Based Indexing")
|
||||||
|
void testZeroBasedIndexing() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
|
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10);
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
|
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
|
||||||
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
|
||||||
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5);
|
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
@ -127,10 +123,9 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
LibSvmRecordReader rr = new LibSvmRecordReader();
|
LibSvmRecordReader rr = new LibSvmRecordReader();
|
||||||
rr.initialize(configReader, new FileSplit(inputFile));
|
rr.initialize(configReader, new FileSplit(inputFile));
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
|
@ -138,7 +133,6 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX));
|
Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX));
|
||||||
List<String> linesOriginal = new ArrayList<>();
|
List<String> linesOriginal = new ArrayList<>();
|
||||||
for (String line : FileUtils.readLines(inputFile)) {
|
for (String line : FileUtils.readLines(inputFile)) {
|
||||||
|
@ -159,7 +153,8 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayWritables() throws Exception {
|
@DisplayName("Test ND Array Writables")
|
||||||
|
void testNDArrayWritables() throws Exception {
|
||||||
INDArray arr2 = Nd4j.zeros(2);
|
INDArray arr2 = Nd4j.zeros(2);
|
||||||
arr2.putScalar(0, 11);
|
arr2.putScalar(0, 11);
|
||||||
arr2.putScalar(1, 12);
|
arr2.putScalar(1, 12);
|
||||||
|
@ -167,35 +162,28 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||||
arr3.putScalar(0, 13);
|
arr3.putScalar(0, 13);
|
||||||
arr3.putScalar(1, 14);
|
arr3.putScalar(1, 14);
|
||||||
arr3.putScalar(2, 15);
|
arr3.putScalar(2, 15);
|
||||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
|
||||||
new NDArrayWritable(arr2),
|
|
||||||
new IntWritable(2),
|
|
||||||
new DoubleWritable(3),
|
|
||||||
new NDArrayWritable(arr3),
|
|
||||||
new IntWritable(4));
|
|
||||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
||||||
|
|
||||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
|
|
||||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||||
assertEquals(lineOriginal, lineNew);
|
assertEquals(lineOriginal, lineNew);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayWritablesMultilabel() throws Exception {
|
@DisplayName("Test ND Array Writables Multilabel")
|
||||||
|
void testNDArrayWritablesMultilabel() throws Exception {
|
||||||
INDArray arr2 = Nd4j.zeros(2);
|
INDArray arr2 = Nd4j.zeros(2);
|
||||||
arr2.putScalar(0, 11);
|
arr2.putScalar(0, 11);
|
||||||
arr2.putScalar(1, 12);
|
arr2.putScalar(1, 12);
|
||||||
|
@ -203,36 +191,29 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||||
arr3.putScalar(0, 0);
|
arr3.putScalar(0, 0);
|
||||||
arr3.putScalar(1, 1);
|
arr3.putScalar(1, 1);
|
||||||
arr3.putScalar(2, 0);
|
arr3.putScalar(2, 0);
|
||||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
|
||||||
new NDArrayWritable(arr2),
|
|
||||||
new IntWritable(2),
|
|
||||||
new DoubleWritable(3),
|
|
||||||
new NDArrayWritable(arr3),
|
|
||||||
new DoubleWritable(1));
|
|
||||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
||||||
|
|
||||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
|
|
||||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||||
assertEquals(lineOriginal, lineNew);
|
assertEquals(lineOriginal, lineNew);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayWritablesZeroIndex() throws Exception {
|
@DisplayName("Test ND Array Writables Zero Index")
|
||||||
|
void testNDArrayWritablesZeroIndex() throws Exception {
|
||||||
INDArray arr2 = Nd4j.zeros(2);
|
INDArray arr2 = Nd4j.zeros(2);
|
||||||
arr2.putScalar(0, 11);
|
arr2.putScalar(0, 11);
|
||||||
arr2.putScalar(1, 12);
|
arr2.putScalar(1, 12);
|
||||||
|
@ -240,99 +221,91 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||||
arr3.putScalar(0, 0);
|
arr3.putScalar(0, 0);
|
||||||
arr3.putScalar(1, 1);
|
arr3.putScalar(1, 1);
|
||||||
arr3.putScalar(2, 0);
|
arr3.putScalar(2, 0);
|
||||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
|
||||||
new NDArrayWritable(arr2),
|
|
||||||
new IntWritable(2),
|
|
||||||
new DoubleWritable(3),
|
|
||||||
new NDArrayWritable(arr3),
|
|
||||||
new DoubleWritable(1));
|
|
||||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
|
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
|
||||||
|
|
||||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
|
// NOT STANDARD!
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
|
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
|
||||||
|
// NOT STANDARD!
|
||||||
|
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
|
|
||||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||||
assertEquals(lineOriginal, lineNew);
|
assertEquals(lineOriginal, lineNew);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNonIntegerButValidMultilabel() throws Exception {
|
@DisplayName("Test Non Integer But Valid Multilabel")
|
||||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
|
void testNonIntegerButValidMultilabel() throws Exception {
|
||||||
new IntWritable(2),
|
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
|
||||||
new DoubleWritable(1.0));
|
|
||||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = NumberFormatException.class)
|
@Test
|
||||||
public void nonIntegerMultilabel() throws Exception {
|
@DisplayName("Non Integer Multilabel")
|
||||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
|
void nonIntegerMultilabel() {
|
||||||
new IntWritable(2),
|
assertThrows(NumberFormatException.class, () -> {
|
||||||
new DoubleWritable(1.2));
|
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
|
||||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
Configuration configWriter = new Configuration();
|
||||||
Configuration configWriter = new Configuration();
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
|
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.write(record);
|
||||||
writer.write(record);
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = NumberFormatException.class)
|
@Test
|
||||||
public void nonBinaryMultilabel() throws Exception {
|
@DisplayName("Non Binary Multilabel")
|
||||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(0),
|
void nonBinaryMultilabel() {
|
||||||
new IntWritable(1),
|
assertThrows(NumberFormatException.class, () -> {
|
||||||
new IntWritable(2));
|
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
|
||||||
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
||||||
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
|
Configuration configWriter = new Configuration();
|
||||||
Configuration configWriter = new Configuration();
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN,0);
|
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||||
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN,1);
|
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
|
||||||
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL,true);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.write(record);
|
||||||
writer.write(record);
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.writer.impl;
|
package org.datavec.api.records.writer.impl;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -27,93 +26,90 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.regex.Matcher;
|
import java.util.regex.Matcher;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Svm Light Record Writer Test")
|
||||||
|
class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||||
public class SVMLightRecordWriterTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasic() throws Exception {
|
@DisplayName("Test Basic")
|
||||||
|
void testBasic() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||||
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNoLabel() throws Exception {
|
@DisplayName("Test No Label")
|
||||||
|
void testNoLabel() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||||
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultioutputRecord() throws Exception {
|
@DisplayName("Test Multioutput Record")
|
||||||
|
void testMultioutputRecord() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||||
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultilabelRecord() throws Exception {
|
@DisplayName("Test Multilabel Record")
|
||||||
|
void testMultilabelRecord() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
|
||||||
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
||||||
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4);
|
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4);
|
||||||
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testZeroBasedIndexing() throws Exception {
|
@DisplayName("Test Zero Based Indexing")
|
||||||
|
void testZeroBasedIndexing() throws Exception {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
|
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10);
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||||
|
|
||||||
Configuration configReader = new Configuration();
|
Configuration configReader = new Configuration();
|
||||||
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
|
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
|
||||||
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
|
||||||
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5);
|
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5);
|
||||||
|
|
||||||
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
|
||||||
executeTest(configWriter, configReader, inputFile);
|
executeTest(configWriter, configReader, inputFile);
|
||||||
}
|
}
|
||||||
|
@ -124,10 +120,9 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
SVMLightRecordReader rr = new SVMLightRecordReader();
|
SVMLightRecordReader rr = new SVMLightRecordReader();
|
||||||
rr.initialize(configReader, new FileSplit(inputFile));
|
rr.initialize(configReader, new FileSplit(inputFile));
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
|
@ -135,7 +130,6 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX));
|
Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX));
|
||||||
List<String> linesOriginal = new ArrayList<>();
|
List<String> linesOriginal = new ArrayList<>();
|
||||||
for (String line : FileUtils.readLines(inputFile)) {
|
for (String line : FileUtils.readLines(inputFile)) {
|
||||||
|
@ -156,7 +150,8 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayWritables() throws Exception {
|
@DisplayName("Test ND Array Writables")
|
||||||
|
void testNDArrayWritables() throws Exception {
|
||||||
INDArray arr2 = Nd4j.zeros(2);
|
INDArray arr2 = Nd4j.zeros(2);
|
||||||
arr2.putScalar(0, 11);
|
arr2.putScalar(0, 11);
|
||||||
arr2.putScalar(1, 12);
|
arr2.putScalar(1, 12);
|
||||||
|
@ -164,35 +159,28 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||||
arr3.putScalar(0, 13);
|
arr3.putScalar(0, 13);
|
||||||
arr3.putScalar(1, 14);
|
arr3.putScalar(1, 14);
|
||||||
arr3.putScalar(2, 15);
|
arr3.putScalar(2, 15);
|
||||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
|
||||||
new NDArrayWritable(arr2),
|
|
||||||
new IntWritable(2),
|
|
||||||
new DoubleWritable(3),
|
|
||||||
new NDArrayWritable(arr3),
|
|
||||||
new IntWritable(4));
|
|
||||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
||||||
|
|
||||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
|
|
||||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||||
assertEquals(lineOriginal, lineNew);
|
assertEquals(lineOriginal, lineNew);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayWritablesMultilabel() throws Exception {
|
@DisplayName("Test ND Array Writables Multilabel")
|
||||||
|
void testNDArrayWritablesMultilabel() throws Exception {
|
||||||
INDArray arr2 = Nd4j.zeros(2);
|
INDArray arr2 = Nd4j.zeros(2);
|
||||||
arr2.putScalar(0, 11);
|
arr2.putScalar(0, 11);
|
||||||
arr2.putScalar(1, 12);
|
arr2.putScalar(1, 12);
|
||||||
|
@ -200,36 +188,29 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||||
arr3.putScalar(0, 0);
|
arr3.putScalar(0, 0);
|
||||||
arr3.putScalar(1, 1);
|
arr3.putScalar(1, 1);
|
||||||
arr3.putScalar(2, 0);
|
arr3.putScalar(2, 0);
|
||||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
|
||||||
new NDArrayWritable(arr2),
|
|
||||||
new IntWritable(2),
|
|
||||||
new DoubleWritable(3),
|
|
||||||
new NDArrayWritable(arr3),
|
|
||||||
new DoubleWritable(1));
|
|
||||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
|
||||||
|
|
||||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
|
|
||||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||||
assertEquals(lineOriginal, lineNew);
|
assertEquals(lineOriginal, lineNew);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayWritablesZeroIndex() throws Exception {
|
@DisplayName("Test ND Array Writables Zero Index")
|
||||||
|
void testNDArrayWritablesZeroIndex() throws Exception {
|
||||||
INDArray arr2 = Nd4j.zeros(2);
|
INDArray arr2 = Nd4j.zeros(2);
|
||||||
arr2.putScalar(0, 11);
|
arr2.putScalar(0, 11);
|
||||||
arr2.putScalar(1, 12);
|
arr2.putScalar(1, 12);
|
||||||
|
@ -237,99 +218,91 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||||
arr3.putScalar(0, 0);
|
arr3.putScalar(0, 0);
|
||||||
arr3.putScalar(1, 1);
|
arr3.putScalar(1, 1);
|
||||||
arr3.putScalar(2, 0);
|
arr3.putScalar(2, 0);
|
||||||
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
|
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
|
||||||
new NDArrayWritable(arr2),
|
|
||||||
new IntWritable(2),
|
|
||||||
new DoubleWritable(3),
|
|
||||||
new NDArrayWritable(arr3),
|
|
||||||
new DoubleWritable(1));
|
|
||||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
|
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
|
||||||
|
|
||||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
|
// NOT STANDARD!
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
|
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
|
||||||
|
// NOT STANDARD!
|
||||||
|
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
|
|
||||||
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
String lineNew = FileUtils.readFileToString(tempFile).trim();
|
||||||
assertEquals(lineOriginal, lineNew);
|
assertEquals(lineOriginal, lineNew);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNonIntegerButValidMultilabel() throws Exception {
|
@DisplayName("Test Non Integer But Valid Multilabel")
|
||||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
|
void testNonIntegerButValidMultilabel() throws Exception {
|
||||||
new IntWritable(2),
|
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
|
||||||
new DoubleWritable(1.0));
|
|
||||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
|
||||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||||
Configuration configWriter = new Configuration();
|
Configuration configWriter = new Configuration();
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.write(record);
|
writer.write(record);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = NumberFormatException.class)
|
@Test
|
||||||
public void nonIntegerMultilabel() throws Exception {
|
@DisplayName("Non Integer Multilabel")
|
||||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
|
void nonIntegerMultilabel() {
|
||||||
new IntWritable(2),
|
assertThrows(NumberFormatException.class, () -> {
|
||||||
new DoubleWritable(1.2));
|
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
|
||||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
Configuration configWriter = new Configuration();
|
||||||
Configuration configWriter = new Configuration();
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
|
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.write(record);
|
||||||
writer.write(record);
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = NumberFormatException.class)
|
@Test
|
||||||
public void nonBinaryMultilabel() throws Exception {
|
@DisplayName("Non Binary Multilabel")
|
||||||
List<Writable> record = Arrays.asList((Writable) new IntWritable(0),
|
void nonBinaryMultilabel() {
|
||||||
new IntWritable(1),
|
assertThrows(NumberFormatException.class, () -> {
|
||||||
new IntWritable(2));
|
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
|
||||||
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
|
||||||
tempFile.setWritable(true);
|
tempFile.setWritable(true);
|
||||||
tempFile.deleteOnExit();
|
tempFile.deleteOnExit();
|
||||||
if (tempFile.exists())
|
if (tempFile.exists())
|
||||||
tempFile.delete();
|
tempFile.delete();
|
||||||
|
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
||||||
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
|
Configuration configWriter = new Configuration();
|
||||||
Configuration configWriter = new Configuration();
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
|
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
|
||||||
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
|
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
||||||
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
|
FileSplit outputSplit = new FileSplit(tempFile);
|
||||||
FileSplit outputSplit = new FileSplit(tempFile);
|
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
|
||||||
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
|
writer.write(record);
|
||||||
writer.write(record);
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.datavec.api.io.filters.BalancedPathFilter;
|
||||||
import org.datavec.api.io.filters.RandomPathFilter;
|
import org.datavec.api.io.filters.RandomPathFilter;
|
||||||
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
||||||
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
@ -34,8 +34,9 @@ import java.net.URISyntaxException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.datavec.api.split;
|
package org.datavec.api.split;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
public class NumberedFileInputSplitTests extends BaseND4JTest {
|
public class NumberedFileInputSplitTests extends BaseND4JTest {
|
||||||
@Test
|
@Test
|
||||||
|
@ -69,60 +68,81 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithLeadingSpaces() {
|
public void testNumberedFileInputSplitWithLeadingSpaces() {
|
||||||
String baseString = "/path/to/files/prefix-%5d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix-%5d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() {
|
public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%5d.suffix";
|
assertThrows(IllegalArgumentException.class, () -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%5d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithLeadingPlusInPadding() {
|
public void testNumberedFileInputSplitWithLeadingPlusInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%+5d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%+5d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithLeadingMinusInPadding() {
|
public void testNumberedFileInputSplitWithLeadingMinusInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%-5d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%-5d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithTwoDigitsInPadding() {
|
public void testNumberedFileInputSplitWithTwoDigitsInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%011d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%011d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithInnerZerosInPadding() {
|
public void testNumberedFileInputSplitWithInnerZerosInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%101d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%101d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() {
|
public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%0505d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%0505d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -135,7 +155,7 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
|
||||||
String path = locs[j++].getPath();
|
String path = locs[j++].getPath();
|
||||||
String exp = String.format(baseString, i);
|
String exp = String.format(baseString, i);
|
||||||
String msg = exp + " vs " + path;
|
String msg = exp + " vs " + path;
|
||||||
assertTrue(msg, path.endsWith(exp)); //Note: on Windows, Java can prepend drive to path - "/C:/"
|
assertTrue(path.endsWith(exp),msg); //Note: on Windows, Java can prepend drive to path - "/C:/"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,9 +25,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.function.Function;
|
import org.nd4j.common.function.Function;
|
||||||
|
|
||||||
|
@ -37,22 +38,22 @@ import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertNotEquals;
|
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||||
|
|
||||||
public class TestStreamInputSplit extends BaseND4JTest {
|
public class TestStreamInputSplit extends BaseND4JTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCsvSimple() throws Exception {
|
public void testCsvSimple(@TempDir Path testDir) throws Exception {
|
||||||
File dir = testDir.newFolder();
|
File dir = testDir.toFile();
|
||||||
File f1 = new File(dir, "file1.txt");
|
File f1 = new File(dir, "file1.txt");
|
||||||
File f2 = new File(dir, "file2.txt");
|
File f2 = new File(dir, "file2.txt");
|
||||||
|
|
||||||
|
@ -93,9 +94,9 @@ public class TestStreamInputSplit extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCsvSequenceSimple() throws Exception {
|
public void testCsvSequenceSimple(@TempDir Path testDir) throws Exception {
|
||||||
|
|
||||||
File dir = testDir.newFolder();
|
File dir = testDir.toFile();
|
||||||
File f1 = new File(dir, "file1.txt");
|
File f1 = new File(dir, "file1.txt");
|
||||||
File f2 = new File(dir, "file2.txt");
|
File f2 = new File(dir, "file2.txt");
|
||||||
|
|
||||||
|
@ -137,8 +138,8 @@ public class TestStreamInputSplit extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testShuffle() throws Exception {
|
public void testShuffle(@TempDir Path testDir) throws Exception {
|
||||||
File dir = testDir.newFolder();
|
File dir = testDir.toFile();
|
||||||
File f1 = new File(dir, "file1.txt");
|
File f1 = new File(dir, "file1.txt");
|
||||||
File f2 = new File(dir, "file2.txt");
|
File f2 = new File(dir, "file2.txt");
|
||||||
File f3 = new File(dir, "file3.txt");
|
File f3 = new File(dir, "file3.txt");
|
||||||
|
|
|
@ -17,44 +17,43 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.split;
|
package org.datavec.api.split;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.net.URISyntaxException;
|
import java.net.URISyntaxException;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
|
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author Ede Meijer
|
* @author Ede Meijer
|
||||||
*/
|
*/
|
||||||
public class TransformSplitTest extends BaseND4JTest {
|
@DisplayName("Transform Split Test")
|
||||||
@Test
|
class TransformSplitTest extends BaseND4JTest {
|
||||||
public void testTransform() throws URISyntaxException {
|
|
||||||
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("Test Transform")
|
||||||
|
void testTransform() throws URISyntaxException {
|
||||||
|
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
|
||||||
InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() {
|
InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public URI apply(URI uri) throws URISyntaxException {
|
public URI apply(URI uri) throws URISyntaxException {
|
||||||
return uri.normalize();
|
return uri.normalize();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
assertArrayEquals(new URI[] { new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv") }, SUT.locations());
|
||||||
assertArrayEquals(new URI[] {new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv")}, SUT.locations());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSearchReplace() throws URISyntaxException {
|
@DisplayName("Test Search Replace")
|
||||||
|
void testSearchReplace() throws URISyntaxException {
|
||||||
Collection<URI> inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv"));
|
Collection<URI> inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv"));
|
||||||
|
|
||||||
InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv");
|
InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv");
|
||||||
|
assertArrayEquals(new URI[] { new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv") }, SUT.locations());
|
||||||
assertArrayEquals(new URI[] {new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv")},
|
|
||||||
SUT.locations());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,14 +27,12 @@ import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||||
import org.datavec.api.split.partition.PartitionMetaData;
|
import org.datavec.api.split.partition.PartitionMetaData;
|
||||||
import org.datavec.api.split.partition.Partitioner;
|
import org.datavec.api.split.partition.Partitioner;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.junit.Assert.assertNotNull;
|
|
||||||
|
|
||||||
public class PartitionerTests extends BaseND4JTest {
|
public class PartitionerTests extends BaseND4JTest {
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -29,12 +29,12 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestTransformProcess extends BaseND4JTest {
|
public class TestTransformProcess extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -27,13 +27,13 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.transform.transform.TestTransforms;
|
import org.datavec.api.transform.transform.TestTransforms;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestConditions extends BaseND4JTest {
|
public class TestConditions extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -36,8 +36,8 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestFilters extends BaseND4JTest {
|
public class TestFilters extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -26,19 +26,22 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.NullWritable;
|
import org.datavec.api.writable.NullWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
|
||||||
public class TestJoin extends BaseND4JTest {
|
public class TestJoin extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testJoin() {
|
public void testJoin(@TempDir Path testDir) {
|
||||||
|
|
||||||
Schema firstSchema =
|
Schema firstSchema =
|
||||||
new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build();
|
new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build();
|
||||||
|
@ -46,20 +49,20 @@ public class TestJoin extends BaseND4JTest {
|
||||||
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build();
|
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build();
|
||||||
|
|
||||||
List<List<Writable>> first = new ArrayList<>();
|
List<List<Writable>> first = new ArrayList<>();
|
||||||
first.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1)));
|
first.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1)));
|
||||||
first.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11)));
|
first.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11)));
|
||||||
|
|
||||||
List<List<Writable>> second = new ArrayList<>();
|
List<List<Writable>> second = new ArrayList<>();
|
||||||
second.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(100)));
|
second.add(Arrays.asList(new Text("key0"), new IntWritable(100)));
|
||||||
second.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(110)));
|
second.add(Arrays.asList(new Text("key1"), new IntWritable(110)));
|
||||||
|
|
||||||
Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn")
|
Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn")
|
||||||
.setSchemas(firstSchema, secondSchema).build();
|
.setSchemas(firstSchema, secondSchema).build();
|
||||||
|
|
||||||
List<List<Writable>> expected = new ArrayList<>();
|
List<List<Writable>> expected = new ArrayList<>();
|
||||||
expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1),
|
expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1),
|
||||||
new IntWritable(100)));
|
new IntWritable(100)));
|
||||||
expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11),
|
expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11),
|
||||||
new IntWritable(110)));
|
new IntWritable(110)));
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,27 +97,31 @@ public class TestJoin extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testJoinValidation() {
|
public void testJoinValidation() {
|
||||||
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
||||||
|
.build();
|
||||||
|
|
||||||
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
|
||||||
.build();
|
|
||||||
|
|
||||||
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
|
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
|
||||||
|
.setSchemas(firstSchema, secondSchema).build();
|
||||||
|
});
|
||||||
|
|
||||||
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
|
|
||||||
.setSchemas(firstSchema, secondSchema).build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testJoinValidation2() {
|
public void testJoinValidation2() {
|
||||||
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
||||||
|
.build();
|
||||||
|
|
||||||
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
|
||||||
.build();
|
|
||||||
|
|
||||||
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
|
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
|
||||||
|
.build();
|
||||||
|
});
|
||||||
|
|
||||||
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
|
|
||||||
.build();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,32 +17,25 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.transform.ops;
|
package org.datavec.api.transform.ops;
|
||||||
|
|
||||||
import com.tngtech.archunit.core.importer.ImportOption;
|
import com.tngtech.archunit.core.importer.ImportOption;
|
||||||
import com.tngtech.archunit.junit.AnalyzeClasses;
|
import com.tngtech.archunit.junit.AnalyzeClasses;
|
||||||
import com.tngtech.archunit.junit.ArchTest;
|
import com.tngtech.archunit.junit.ArchTest;
|
||||||
import com.tngtech.archunit.junit.ArchUnitRunner;
|
|
||||||
import com.tngtech.archunit.lang.ArchRule;
|
import com.tngtech.archunit.lang.ArchRule;
|
||||||
|
import com.tngtech.archunit.lang.extension.ArchUnitExtension;
|
||||||
|
import com.tngtech.archunit.lang.extension.ArchUnitExtensions;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes;
|
import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
@RunWith(ArchUnitRunner.class)
|
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = { ImportOption.DoNotIncludeTests.class })
|
||||||
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = {ImportOption.DoNotIncludeTests.class})
|
@DisplayName("Aggregable Multi Op Arch Test")
|
||||||
public class AggregableMultiOpArchTest extends BaseND4JTest {
|
class AggregableMultiOpArchTest extends BaseND4JTest {
|
||||||
|
|
||||||
@ArchTest
|
@ArchTest
|
||||||
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes()
|
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes().that().resideInAPackage("org.datavec.api.transform.ops").and().doNotHaveSimpleName("AggregatorImpls").and().doNotHaveSimpleName("IAggregableReduceOp").and().doNotHaveSimpleName("StringAggregatorImpls").and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1").should().implement(Serializable.class).because("All aggregate ops must be serializable.");
|
||||||
.that().resideInAPackage("org.datavec.api.transform.ops")
|
|
||||||
.and().doNotHaveSimpleName("AggregatorImpls")
|
|
||||||
.and().doNotHaveSimpleName("IAggregableReduceOp")
|
|
||||||
.and().doNotHaveSimpleName("StringAggregatorImpls")
|
|
||||||
.and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1")
|
|
||||||
.should().implement(Serializable.class)
|
|
||||||
.because("All aggregate ops must be serializable.");
|
|
||||||
}
|
}
|
|
@ -17,52 +17,46 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.transform.ops;
|
package org.datavec.api.transform.ops;
|
||||||
|
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
@DisplayName("Aggregable Multi Op Test")
|
||||||
|
class AggregableMultiOpTest extends BaseND4JTest {
|
||||||
public class AggregableMultiOpTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMulti() throws Exception {
|
@DisplayName("Test Multi")
|
||||||
|
void testMulti() throws Exception {
|
||||||
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
|
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
|
||||||
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
|
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
|
||||||
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
|
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
|
||||||
|
|
||||||
assertTrue(multi.getOperations().size() == 2);
|
assertTrue(multi.getOperations().size() == 2);
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
multi.accept(intList.get(i));
|
multi.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
// mutablility
|
// mutablility
|
||||||
assertTrue(as.get().toDouble() == 45D);
|
assertTrue(as.get().toDouble() == 45D);
|
||||||
assertTrue(af.get().toInt() == 1);
|
assertTrue(af.get().toInt() == 1);
|
||||||
|
|
||||||
List<Writable> res = multi.get();
|
List<Writable> res = multi.get();
|
||||||
assertTrue(res.get(1).toDouble() == 45D);
|
assertTrue(res.get(1).toDouble() == 45D);
|
||||||
assertTrue(res.get(0).toInt() == 1);
|
assertTrue(res.get(0).toInt() == 1);
|
||||||
|
|
||||||
AggregatorImpls.AggregableFirst<Integer> rf = new AggregatorImpls.AggregableFirst<>();
|
AggregatorImpls.AggregableFirst<Integer> rf = new AggregatorImpls.AggregableFirst<>();
|
||||||
AggregatorImpls.AggregableSum<Integer> rs = new AggregatorImpls.AggregableSum<>();
|
AggregatorImpls.AggregableSum<Integer> rs = new AggregatorImpls.AggregableSum<>();
|
||||||
AggregableMultiOp<Integer> reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs));
|
AggregableMultiOp<Integer> reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs));
|
||||||
|
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Writable> revRes = reverse.get();
|
List<Writable> revRes = reverse.get();
|
||||||
assertTrue(revRes.get(1).toDouble() == 45D);
|
assertTrue(revRes.get(1).toDouble() == 45D);
|
||||||
assertTrue(revRes.get(0).toInt() == 9);
|
assertTrue(revRes.get(0).toInt() == 9);
|
||||||
|
|
||||||
multi.combine(reverse);
|
multi.combine(reverse);
|
||||||
List<Writable> combinedRes = multi.get();
|
List<Writable> combinedRes = multi.get();
|
||||||
assertTrue(combinedRes.get(1).toDouble() == 90D);
|
assertTrue(combinedRes.get(1).toDouble() == 90D);
|
||||||
|
|
|
@ -17,41 +17,39 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.transform.ops;
|
package org.datavec.api.transform.ops;
|
||||||
|
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.ExpectedException;
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
public class AggregatorImplsTest extends BaseND4JTest {
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
@DisplayName("Aggregator Impls Test")
|
||||||
|
class AggregatorImplsTest extends BaseND4JTest {
|
||||||
|
|
||||||
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||||
|
|
||||||
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
|
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableFirstTest() {
|
@DisplayName("Aggregable First Test")
|
||||||
|
void aggregableFirstTest() {
|
||||||
AggregatorImpls.AggregableFirst<Integer> first = new AggregatorImpls.AggregableFirst<>();
|
AggregatorImpls.AggregableFirst<Integer> first = new AggregatorImpls.AggregableFirst<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
first.accept(intList.get(i));
|
first.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertEquals(1, first.get().toInt());
|
assertEquals(1, first.get().toInt());
|
||||||
|
|
||||||
AggregatorImpls.AggregableFirst<String> firstS = new AggregatorImpls.AggregableFirst<>();
|
AggregatorImpls.AggregableFirst<String> firstS = new AggregatorImpls.AggregableFirst<>();
|
||||||
for (int i = 0; i < stringList.size(); i++) {
|
for (int i = 0; i < stringList.size(); i++) {
|
||||||
firstS.accept(stringList.get(i));
|
firstS.accept(stringList.get(i));
|
||||||
}
|
}
|
||||||
assertTrue(firstS.get().toString().equals("arakoa"));
|
assertTrue(firstS.get().toString().equals("arakoa"));
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableFirst<Integer> reverse = new AggregatorImpls.AggregableFirst<>();
|
AggregatorImpls.AggregableFirst<Integer> reverse = new AggregatorImpls.AggregableFirst<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
|
@ -60,22 +58,19 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
assertEquals(1, first.get().toInt());
|
assertEquals(1, first.get().toInt());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableLastTest() {
|
@DisplayName("Aggregable Last Test")
|
||||||
|
void aggregableLastTest() {
|
||||||
AggregatorImpls.AggregableLast<Integer> last = new AggregatorImpls.AggregableLast<>();
|
AggregatorImpls.AggregableLast<Integer> last = new AggregatorImpls.AggregableLast<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
last.accept(intList.get(i));
|
last.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertEquals(9, last.get().toInt());
|
assertEquals(9, last.get().toInt());
|
||||||
|
|
||||||
AggregatorImpls.AggregableLast<String> lastS = new AggregatorImpls.AggregableLast<>();
|
AggregatorImpls.AggregableLast<String> lastS = new AggregatorImpls.AggregableLast<>();
|
||||||
for (int i = 0; i < stringList.size(); i++) {
|
for (int i = 0; i < stringList.size(); i++) {
|
||||||
lastS.accept(stringList.get(i));
|
lastS.accept(stringList.get(i));
|
||||||
}
|
}
|
||||||
assertTrue(lastS.get().toString().equals("acceptance"));
|
assertTrue(lastS.get().toString().equals("acceptance"));
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableLast<Integer> reverse = new AggregatorImpls.AggregableLast<>();
|
AggregatorImpls.AggregableLast<Integer> reverse = new AggregatorImpls.AggregableLast<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
|
@ -85,20 +80,18 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableCountTest() {
|
@DisplayName("Aggregable Count Test")
|
||||||
|
void aggregableCountTest() {
|
||||||
AggregatorImpls.AggregableCount<Integer> cnt = new AggregatorImpls.AggregableCount<>();
|
AggregatorImpls.AggregableCount<Integer> cnt = new AggregatorImpls.AggregableCount<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
cnt.accept(intList.get(i));
|
cnt.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertEquals(9, cnt.get().toInt());
|
assertEquals(9, cnt.get().toInt());
|
||||||
|
|
||||||
AggregatorImpls.AggregableCount<String> lastS = new AggregatorImpls.AggregableCount<>();
|
AggregatorImpls.AggregableCount<String> lastS = new AggregatorImpls.AggregableCount<>();
|
||||||
for (int i = 0; i < stringList.size(); i++) {
|
for (int i = 0; i < stringList.size(); i++) {
|
||||||
lastS.accept(stringList.get(i));
|
lastS.accept(stringList.get(i));
|
||||||
}
|
}
|
||||||
assertEquals(4, lastS.get().toInt());
|
assertEquals(4, lastS.get().toInt());
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableCount<Integer> reverse = new AggregatorImpls.AggregableCount<>();
|
AggregatorImpls.AggregableCount<Integer> reverse = new AggregatorImpls.AggregableCount<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
|
@ -108,14 +101,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableMaxTest() {
|
@DisplayName("Aggregable Max Test")
|
||||||
|
void aggregableMaxTest() {
|
||||||
AggregatorImpls.AggregableMax<Integer> mx = new AggregatorImpls.AggregableMax<>();
|
AggregatorImpls.AggregableMax<Integer> mx = new AggregatorImpls.AggregableMax<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
mx.accept(intList.get(i));
|
mx.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertEquals(9, mx.get().toInt());
|
assertEquals(9, mx.get().toInt());
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableMax<Integer> reverse = new AggregatorImpls.AggregableMax<>();
|
AggregatorImpls.AggregableMax<Integer> reverse = new AggregatorImpls.AggregableMax<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
|
@ -124,16 +116,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
assertEquals(9, mx.get().toInt());
|
assertEquals(9, mx.get().toInt());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableRangeTest() {
|
@DisplayName("Aggregable Range Test")
|
||||||
|
void aggregableRangeTest() {
|
||||||
AggregatorImpls.AggregableRange<Integer> mx = new AggregatorImpls.AggregableRange<>();
|
AggregatorImpls.AggregableRange<Integer> mx = new AggregatorImpls.AggregableRange<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
mx.accept(intList.get(i));
|
mx.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertEquals(8, mx.get().toInt());
|
assertEquals(8, mx.get().toInt());
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableRange<Integer> reverse = new AggregatorImpls.AggregableRange<>();
|
AggregatorImpls.AggregableRange<Integer> reverse = new AggregatorImpls.AggregableRange<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1) + 9);
|
reverse.accept(intList.get(intList.size() - i - 1) + 9);
|
||||||
|
@ -143,14 +133,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableMinTest() {
|
@DisplayName("Aggregable Min Test")
|
||||||
|
void aggregableMinTest() {
|
||||||
AggregatorImpls.AggregableMin<Integer> mn = new AggregatorImpls.AggregableMin<>();
|
AggregatorImpls.AggregableMin<Integer> mn = new AggregatorImpls.AggregableMin<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
mn.accept(intList.get(i));
|
mn.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertEquals(1, mn.get().toInt());
|
assertEquals(1, mn.get().toInt());
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableMin<Integer> reverse = new AggregatorImpls.AggregableMin<>();
|
AggregatorImpls.AggregableMin<Integer> reverse = new AggregatorImpls.AggregableMin<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
|
@ -160,14 +149,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableSumTest() {
|
@DisplayName("Aggregable Sum Test")
|
||||||
|
void aggregableSumTest() {
|
||||||
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
|
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
sm.accept(intList.get(i));
|
sm.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertEquals(45, sm.get().toInt());
|
assertEquals(45, sm.get().toInt());
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableSum<Integer> reverse = new AggregatorImpls.AggregableSum<>();
|
AggregatorImpls.AggregableSum<Integer> reverse = new AggregatorImpls.AggregableSum<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
|
@ -176,17 +164,15 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
assertEquals(90, sm.get().toInt());
|
assertEquals(90, sm.get().toInt());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableMeanTest() {
|
@DisplayName("Aggregable Mean Test")
|
||||||
|
void aggregableMeanTest() {
|
||||||
AggregatorImpls.AggregableMean<Integer> mn = new AggregatorImpls.AggregableMean<>();
|
AggregatorImpls.AggregableMean<Integer> mn = new AggregatorImpls.AggregableMean<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
mn.accept(intList.get(i));
|
mn.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertEquals(9l, (long) mn.getCount());
|
assertEquals(9l, (long) mn.getCount());
|
||||||
assertEquals(5D, mn.get().toDouble(), 0.001);
|
assertEquals(5D, mn.get().toDouble(), 0.001);
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
|
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
|
@ -197,80 +183,73 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableStdDevTest() {
|
@DisplayName("Aggregable Std Dev Test")
|
||||||
|
void aggregableStdDevTest() {
|
||||||
AggregatorImpls.AggregableStdDev<Integer> sd = new AggregatorImpls.AggregableStdDev<>();
|
AggregatorImpls.AggregableStdDev<Integer> sd = new AggregatorImpls.AggregableStdDev<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
sd.accept(intList.get(i));
|
sd.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001);
|
assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001);
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableStdDev<Integer> reverse = new AggregatorImpls.AggregableStdDev<>();
|
AggregatorImpls.AggregableStdDev<Integer> reverse = new AggregatorImpls.AggregableStdDev<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
}
|
}
|
||||||
sd.combine(reverse);
|
sd.combine(reverse);
|
||||||
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8787) < 0.0001);
|
assertTrue(Math.abs(sd.get().toDouble() - 1.8787) < 0.0001,"" + sd.get().toDouble());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableVariance() {
|
@DisplayName("Aggregable Variance")
|
||||||
|
void aggregableVariance() {
|
||||||
AggregatorImpls.AggregableVariance<Integer> sd = new AggregatorImpls.AggregableVariance<>();
|
AggregatorImpls.AggregableVariance<Integer> sd = new AggregatorImpls.AggregableVariance<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
sd.accept(intList.get(i));
|
sd.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001);
|
assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001);
|
||||||
|
|
||||||
|
|
||||||
AggregatorImpls.AggregableVariance<Integer> reverse = new AggregatorImpls.AggregableVariance<>();
|
AggregatorImpls.AggregableVariance<Integer> reverse = new AggregatorImpls.AggregableVariance<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
}
|
}
|
||||||
sd.combine(reverse);
|
sd.combine(reverse);
|
||||||
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 3.5294) < 0.0001);
|
assertTrue(Math.abs(sd.get().toDouble() - 3.5294) < 0.0001,"" + sd.get().toDouble());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableUncorrectedStdDevTest() {
|
@DisplayName("Aggregable Uncorrected Std Dev Test")
|
||||||
|
void aggregableUncorrectedStdDevTest() {
|
||||||
AggregatorImpls.AggregableUncorrectedStdDev<Integer> sd = new AggregatorImpls.AggregableUncorrectedStdDev<>();
|
AggregatorImpls.AggregableUncorrectedStdDev<Integer> sd = new AggregatorImpls.AggregableUncorrectedStdDev<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
sd.accept(intList.get(i));
|
sd.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001);
|
assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001);
|
||||||
|
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse = new AggregatorImpls.AggregableUncorrectedStdDev<>();
|
||||||
|
|
||||||
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse =
|
|
||||||
new AggregatorImpls.AggregableUncorrectedStdDev<>();
|
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
}
|
}
|
||||||
sd.combine(reverse);
|
sd.combine(reverse);
|
||||||
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8257) < 0.0001);
|
assertTrue(Math.abs(sd.get().toDouble() - 1.8257) < 0.0001,"" + sd.get().toDouble());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregablePopulationVariance() {
|
@DisplayName("Aggregable Population Variance")
|
||||||
|
void aggregablePopulationVariance() {
|
||||||
AggregatorImpls.AggregablePopulationVariance<Integer> sd = new AggregatorImpls.AggregablePopulationVariance<>();
|
AggregatorImpls.AggregablePopulationVariance<Integer> sd = new AggregatorImpls.AggregablePopulationVariance<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
sd.accept(intList.get(i));
|
sd.accept(intList.get(i));
|
||||||
}
|
}
|
||||||
assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001);
|
assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001);
|
||||||
|
AggregatorImpls.AggregablePopulationVariance<Integer> reverse = new AggregatorImpls.AggregablePopulationVariance<>();
|
||||||
|
|
||||||
AggregatorImpls.AggregablePopulationVariance<Integer> reverse =
|
|
||||||
new AggregatorImpls.AggregablePopulationVariance<>();
|
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
}
|
}
|
||||||
sd.combine(reverse);
|
sd.combine(reverse);
|
||||||
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001);
|
assertTrue(Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001,"" + sd.get().toDouble());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void aggregableCountUniqueTest() {
|
@DisplayName("Aggregable Count Unique Test")
|
||||||
|
void aggregableCountUniqueTest() {
|
||||||
// at this low range, it's linear counting
|
// at this low range, it's linear counting
|
||||||
|
|
||||||
AggregatorImpls.AggregableCountUnique<Integer> cu = new AggregatorImpls.AggregableCountUnique<>();
|
AggregatorImpls.AggregableCountUnique<Integer> cu = new AggregatorImpls.AggregableCountUnique<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
cu.accept(intList.get(i));
|
cu.accept(intList.get(i));
|
||||||
|
@ -278,7 +257,6 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
assertEquals(9, cu.get().toInt());
|
assertEquals(9, cu.get().toInt());
|
||||||
cu.accept(1);
|
cu.accept(1);
|
||||||
assertEquals(9, cu.get().toInt());
|
assertEquals(9, cu.get().toInt());
|
||||||
|
|
||||||
AggregatorImpls.AggregableCountUnique<Integer> reverse = new AggregatorImpls.AggregableCountUnique<>();
|
AggregatorImpls.AggregableCountUnique<Integer> reverse = new AggregatorImpls.AggregableCountUnique<>();
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
|
@ -287,26 +265,25 @@ public class AggregatorImplsTest extends BaseND4JTest {
|
||||||
assertEquals(9, cu.get().toInt());
|
assertEquals(9, cu.get().toInt());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Rule
|
|
||||||
public final ExpectedException exception = ExpectedException.none();
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void incompatibleAggregatorTest() {
|
@DisplayName("Incompatible Aggregator Test")
|
||||||
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
|
void incompatibleAggregatorTest() {
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
assertThrows(UnsupportedOperationException.class,() -> {
|
||||||
sm.accept(intList.get(i));
|
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
|
||||||
}
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
assertEquals(45, sm.get().toInt());
|
sm.accept(intList.get(i));
|
||||||
|
}
|
||||||
|
assertEquals(45, sm.get().toInt());
|
||||||
|
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
|
||||||
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
sm.combine(reverse);
|
||||||
|
assertEquals(45, sm.get().toInt());
|
||||||
|
});
|
||||||
|
|
||||||
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
|
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
|
||||||
}
|
|
||||||
exception.expect(UnsupportedOperationException.class);
|
|
||||||
sm.combine(reverse);
|
|
||||||
assertEquals(45, sm.get().toInt());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,77 +17,65 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.transform.ops;
|
package org.datavec.api.transform.ops;
|
||||||
|
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
@DisplayName("Dispatch Op Test")
|
||||||
|
class DispatchOpTest extends BaseND4JTest {
|
||||||
public class DispatchOpTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||||
|
|
||||||
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
|
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDispatchSimple() {
|
@DisplayName("Test Dispatch Simple")
|
||||||
|
void testDispatchSimple() {
|
||||||
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
|
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
|
||||||
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
|
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
|
||||||
AggregableMultiOp<Integer> multiaf =
|
AggregableMultiOp<Integer> multiaf = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
|
||||||
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
|
AggregableMultiOp<Integer> multias = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
|
||||||
AggregableMultiOp<Integer> multias =
|
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
|
||||||
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
|
|
||||||
|
|
||||||
DispatchOp<Integer, Writable> parallel =
|
|
||||||
new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
|
|
||||||
|
|
||||||
assertTrue(multiaf.getOperations().size() == 1);
|
assertTrue(multiaf.getOperations().size() == 1);
|
||||||
assertTrue(multias.getOperations().size() == 1);
|
assertTrue(multias.getOperations().size() == 1);
|
||||||
assertTrue(parallel.getOperations().size() == 2);
|
assertTrue(parallel.getOperations().size() == 2);
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
|
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Writable> res = parallel.get();
|
List<Writable> res = parallel.get();
|
||||||
assertTrue(res.get(1).toDouble() == 45D);
|
assertTrue(res.get(1).toDouble() == 45D);
|
||||||
assertTrue(res.get(0).toInt() == 1);
|
assertTrue(res.get(0).toInt() == 1);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDispatchFlatMap() {
|
@DisplayName("Test Dispatch Flat Map")
|
||||||
|
void testDispatchFlatMap() {
|
||||||
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
|
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
|
||||||
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
|
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
|
||||||
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
|
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
|
||||||
|
|
||||||
AggregatorImpls.AggregableLast<Integer> al = new AggregatorImpls.AggregableLast<>();
|
AggregatorImpls.AggregableLast<Integer> al = new AggregatorImpls.AggregableLast<>();
|
||||||
AggregatorImpls.AggregableMax<Integer> amax = new AggregatorImpls.AggregableMax<>();
|
AggregatorImpls.AggregableMax<Integer> amax = new AggregatorImpls.AggregableMax<>();
|
||||||
AggregableMultiOp<Integer> otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax));
|
AggregableMultiOp<Integer> otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax));
|
||||||
|
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
|
||||||
|
|
||||||
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(
|
|
||||||
Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
|
|
||||||
|
|
||||||
assertTrue(multi.getOperations().size() == 2);
|
assertTrue(multi.getOperations().size() == 2);
|
||||||
assertTrue(otherMulti.getOperations().size() == 2);
|
assertTrue(otherMulti.getOperations().size() == 2);
|
||||||
assertTrue(parallel.getOperations().size() == 2);
|
assertTrue(parallel.getOperations().size() == 2);
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
|
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Writable> res = parallel.get();
|
List<Writable> res = parallel.get();
|
||||||
assertTrue(res.get(1).toDouble() == 45D);
|
assertTrue(res.get(1).toDouble() == 45D);
|
||||||
assertTrue(res.get(0).toInt() == 1);
|
assertTrue(res.get(0).toInt() == 1);
|
||||||
assertTrue(res.get(3).toDouble() == 9);
|
assertTrue(res.get(3).toDouble() == 9);
|
||||||
assertTrue(res.get(2).toInt() == 9);
|
assertTrue(res.get(2).toInt() == 9);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,13 +32,14 @@ import org.datavec.api.transform.ops.AggregableMultiOp;
|
||||||
import org.datavec.api.transform.ops.IAggregableReduceOp;
|
import org.datavec.api.transform.ops.IAggregableReduceOp;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
public class TestMultiOpReduce extends BaseND4JTest {
|
public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
|
|
||||||
|
@ -46,10 +47,10 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
public void testMultiOpReducerDouble() {
|
public void testMultiOpReducerDouble() {
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
List<List<Writable>> inputs = new ArrayList<>();
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(0)));
|
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(0)));
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(1)));
|
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(1)));
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
|
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
|
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
|
||||||
|
|
||||||
Map<ReduceOp, Double> exp = new LinkedHashMap<>();
|
Map<ReduceOp, Double> exp = new LinkedHashMap<>();
|
||||||
exp.put(ReduceOp.Min, 0.0);
|
exp.put(ReduceOp.Min, 0.0);
|
||||||
|
@ -82,7 +83,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
assertEquals(out.get(0), new Text("someKey"));
|
assertEquals(out.get(0), new Text("someKey"));
|
||||||
|
|
||||||
String msg = op.toString();
|
String msg = op.toString();
|
||||||
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
|
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,19 +127,20 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
assertEquals(out.get(0), new Text("someKey"));
|
assertEquals(out.get(0), new Text("someKey"));
|
||||||
|
|
||||||
String msg = op.toString();
|
String msg = op.toString();
|
||||||
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
|
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@Disabled
|
||||||
public void testReduceString() {
|
public void testReduceString() {
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
List<List<Writable>> inputs = new ArrayList<>();
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1")));
|
inputs.add(Arrays.asList(new Text("someKey"), new Text("1")));
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2")));
|
inputs.add(Arrays.asList(new Text("someKey"), new Text("2")));
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3")));
|
inputs.add(Arrays.asList(new Text("someKey"), new Text("3")));
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4")));
|
inputs.add(Arrays.asList(new Text("someKey"), new Text("4")));
|
||||||
|
|
||||||
Map<ReduceOp, String> exp = new LinkedHashMap<>();
|
Map<ReduceOp, String> exp = new LinkedHashMap<>();
|
||||||
exp.put(ReduceOp.Append, "1234");
|
exp.put(ReduceOp.Append, "1234");
|
||||||
|
@ -210,7 +212,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
assertEquals(out.get(0), new Text("someKey"));
|
assertEquals(out.get(0), new Text("someKey"));
|
||||||
|
|
||||||
String msg = op.toString();
|
String msg = op.toString();
|
||||||
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
|
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean,
|
for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean,
|
||||||
|
|
|
@ -24,13 +24,13 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
|
||||||
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
|
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestReductions extends BaseND4JTest {
|
public class TestReductions extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -22,10 +22,10 @@ package org.datavec.api.transform.schema;
|
||||||
|
|
||||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestJsonYaml extends BaseND4JTest {
|
public class TestJsonYaml extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -21,10 +21,10 @@
|
||||||
package org.datavec.api.transform.schema;
|
package org.datavec.api.transform.schema;
|
||||||
|
|
||||||
import org.datavec.api.transform.ColumnType;
|
import org.datavec.api.transform.ColumnType;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestSchemaMethods extends BaseND4JTest {
|
public class TestSchemaMethods extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.NullWritable;
|
import org.datavec.api.writable.NullWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -41,7 +41,7 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
|
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -35,7 +35,7 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestSequenceSplit extends BaseND4JTest {
|
public class TestSequenceSplit extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.LongWritable;
|
import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -37,7 +37,7 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestWindowFunctions extends BaseND4JTest {
|
public class TestWindowFunctions extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -26,10 +26,10 @@ import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.transform.serde.testClasses.CustomCondition;
|
import org.datavec.api.transform.serde.testClasses.CustomCondition;
|
||||||
import org.datavec.api.transform.serde.testClasses.CustomFilter;
|
import org.datavec.api.transform.serde.testClasses.CustomFilter;
|
||||||
import org.datavec.api.transform.serde.testClasses.CustomTransform;
|
import org.datavec.api.transform.serde.testClasses.CustomTransform;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestCustomTransformJsonYaml extends BaseND4JTest {
|
public class TestCustomTransformJsonYaml extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -64,13 +64,13 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
|
||||||
import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
||||||
import org.joda.time.DateTimeFieldType;
|
import org.joda.time.DateTimeFieldType;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestYamlJsonSerde extends BaseND4JTest {
|
public class TestYamlJsonSerde extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -24,12 +24,12 @@ import org.datavec.api.transform.StringReduceOp;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestReduce extends BaseND4JTest {
|
public class TestReduce extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.comparator.LongWritableComparator;
|
import org.datavec.api.writable.comparator.LongWritableComparator;
|
||||||
import org.joda.time.DateTimeFieldType;
|
import org.joda.time.DateTimeFieldType;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class RegressionTestJson extends BaseND4JTest {
|
public class RegressionTestJson extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -50,13 +50,13 @@ import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.comparator.LongWritableComparator;
|
import org.datavec.api.writable.comparator.LongWritableComparator;
|
||||||
import org.joda.time.DateTimeFieldType;
|
import org.joda.time.DateTimeFieldType;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestJsonYaml extends BaseND4JTest {
|
public class TestJsonYaml extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -58,8 +58,8 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.joda.time.DateTimeFieldType;
|
import org.joda.time.DateTimeFieldType;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -71,8 +71,8 @@ import java.io.ObjectOutputStream;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertEquals;
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestTransforms extends BaseND4JTest {
|
public class TestTransforms extends BaseND4JTest {
|
||||||
|
|
||||||
|
@ -277,22 +277,22 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
||||||
outputColumns.add(NEW_COLUMN);
|
outputColumns.add(NEW_COLUMN);
|
||||||
Schema newSchema = transform.transform(schema);
|
Schema newSchema = transform.transform(schema);
|
||||||
Assert.assertEquals(outputColumns, newSchema.getColumnNames());
|
assertEquals(outputColumns, newSchema.getColumnNames());
|
||||||
|
|
||||||
List<Writable> input = new ArrayList<>();
|
List<Writable> input = new ArrayList<>();
|
||||||
input.addAll(COLUMN_VALUES);
|
input.addAll(COLUMN_VALUES);
|
||||||
|
|
||||||
transform.setInputSchema(schema);
|
transform.setInputSchema(schema);
|
||||||
List<Writable> transformed = transform.map(input);
|
List<Writable> transformed = transform.map(input);
|
||||||
Assert.assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
|
assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
|
||||||
|
|
||||||
List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES);
|
List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES);
|
||||||
outputColumnValues.add(new Text(NEW_COLUMN_VALUE));
|
outputColumnValues.add(new Text(NEW_COLUMN_VALUE));
|
||||||
Assert.assertEquals(outputColumnValues, transformed);
|
assertEquals(outputColumnValues, transformed);
|
||||||
|
|
||||||
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
||||||
Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class);
|
Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class);
|
||||||
Assert.assertEquals(transform, transform2);
|
assertEquals(transform, transform2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -309,7 +309,7 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
transform.setInputSchema(schema);
|
transform.setInputSchema(schema);
|
||||||
Schema newSchema = transform.transform(schema);
|
Schema newSchema = transform.transform(schema);
|
||||||
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
||||||
Assert.assertEquals(outputColumns, newSchema.getColumnNames());
|
assertEquals(outputColumns, newSchema.getColumnNames());
|
||||||
|
|
||||||
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER);
|
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER);
|
||||||
transform.setInputSchema(schema);
|
transform.setInputSchema(schema);
|
||||||
|
@ -320,8 +320,8 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
output.add(new Text(TEXT_LOWER_CASE));
|
output.add(new Text(TEXT_LOWER_CASE));
|
||||||
output.add(new Text(TEXT_MIXED_CASE));
|
output.add(new Text(TEXT_MIXED_CASE));
|
||||||
List<Writable> transformed = transform.map(input);
|
List<Writable> transformed = transform.map(input);
|
||||||
Assert.assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
|
assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
|
||||||
Assert.assertEquals(transformed, output);
|
assertEquals(transformed, output);
|
||||||
|
|
||||||
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER);
|
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER);
|
||||||
transform.setInputSchema(schema);
|
transform.setInputSchema(schema);
|
||||||
|
@ -329,12 +329,12 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
output.add(new Text(TEXT_UPPER_CASE));
|
output.add(new Text(TEXT_UPPER_CASE));
|
||||||
output.add(new Text(TEXT_MIXED_CASE));
|
output.add(new Text(TEXT_MIXED_CASE));
|
||||||
transformed = transform.map(input);
|
transformed = transform.map(input);
|
||||||
Assert.assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
|
assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
|
||||||
Assert.assertEquals(transformed, output);
|
assertEquals(transformed, output);
|
||||||
|
|
||||||
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
||||||
Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class);
|
Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class);
|
||||||
Assert.assertEquals(transform, transform2);
|
assertEquals(transform, transform2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -1530,7 +1530,7 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
|
|
||||||
String json = JsonMappers.getMapper().writeValueAsString(t);
|
String json = JsonMappers.getMapper().writeValueAsString(t);
|
||||||
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);
|
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);
|
||||||
Assert.assertEquals(t, transform2);
|
assertEquals(t, transform2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1551,7 +1551,7 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
|
|
||||||
String json = JsonMappers.getMapper().writeValueAsString(t);
|
String json = JsonMappers.getMapper().writeValueAsString(t);
|
||||||
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class);
|
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class);
|
||||||
Assert.assertEquals(t, transform2);
|
assertEquals(t, transform2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.NDArrayWritable;
|
import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -39,7 +39,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestNDArrayWritableTransforms extends BaseND4JTest {
|
public class TestNDArrayWritableTransforms extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -30,13 +30,13 @@ import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.transform.serde.JsonSerializer;
|
import org.datavec.api.transform.serde.JsonSerializer;
|
||||||
import org.datavec.api.transform.serde.YamlSerializer;
|
import org.datavec.api.transform.serde.YamlSerializer;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestYamlJsonSerde extends BaseND4JTest {
|
public class TestYamlJsonSerde extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -17,29 +17,29 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.transform.transform.parse;
|
package org.datavec.api.transform.transform.parse;
|
||||||
|
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Parse Double Transform Test")
|
||||||
|
class ParseDoubleTransformTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class ParseDoubleTransformTest extends BaseND4JTest {
|
|
||||||
@Test
|
@Test
|
||||||
public void testDoubleTransform() {
|
@DisplayName("Test Double Transform")
|
||||||
|
void testDoubleTransform() {
|
||||||
List<Writable> record = new ArrayList<>();
|
List<Writable> record = new ArrayList<>();
|
||||||
record.add(new Text("0.0"));
|
record.add(new Text("0.0"));
|
||||||
List<Writable> transformed = Arrays.<Writable>asList(new DoubleWritable(0.0));
|
List<Writable> transformed = Arrays.<Writable>asList(new DoubleWritable(0.0));
|
||||||
assertEquals(transformed, new ParseDoubleTransform().map(record));
|
assertEquals(transformed, new ParseDoubleTransform().map(record));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,26 +35,26 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestUI extends BaseND4JTest {
|
public class TestUI extends BaseND4JTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUI() throws Exception {
|
public void testUI(@TempDir Path testDir) throws Exception {
|
||||||
Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn")
|
Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn")
|
||||||
.addColumnInteger("IntColumn2").addColumnInteger("IntColumn3")
|
.addColumnInteger("IntColumn2").addColumnInteger("IntColumn3")
|
||||||
.addColumnTime("TimeColumn", DateTimeZone.UTC).build();
|
.addColumnTime("TimeColumn", DateTimeZone.UTC).build();
|
||||||
|
@ -92,7 +92,7 @@ public class TestUI extends BaseND4JTest {
|
||||||
|
|
||||||
DataAnalysis da = new DataAnalysis(schema, list);
|
DataAnalysis da = new DataAnalysis(schema, list);
|
||||||
|
|
||||||
File fDir = testDir.newFolder();
|
File fDir = testDir.toFile();
|
||||||
String tempDir = fDir.getAbsolutePath();
|
String tempDir = fDir.getAbsolutePath();
|
||||||
String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html");
|
String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html");
|
||||||
System.out.println(outPath);
|
System.out.println(outPath);
|
||||||
|
@ -143,7 +143,7 @@ public class TestUI extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Disabled
|
||||||
public void testSequencePlot() throws Exception {
|
public void testSequencePlot() throws Exception {
|
||||||
|
|
||||||
Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx")
|
Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx")
|
||||||
|
|
|
@ -17,30 +17,31 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.util;
|
package org.datavec.api.util;
|
||||||
|
|
||||||
import org.junit.Before;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.io.InputStreamReader;
|
import java.io.InputStreamReader;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
import static org.hamcrest.MatcherAssert.assertThat;
|
import static org.hamcrest.MatcherAssert.assertThat;
|
||||||
import static org.hamcrest.core.AnyOf.anyOf;
|
import static org.hamcrest.core.AnyOf.anyOf;
|
||||||
import static org.hamcrest.core.IsEqual.equalTo;
|
import static org.hamcrest.core.IsEqual.equalTo;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
public class ClassPathResourceTest extends BaseND4JTest {
|
@DisplayName("Class Path Resource Test")
|
||||||
|
class ClassPathResourceTest extends BaseND4JTest {
|
||||||
|
|
||||||
private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows
|
// File sizes are reported slightly different on Linux vs. Windows
|
||||||
|
private boolean isWindows = false;
|
||||||
|
|
||||||
@Before
|
@BeforeEach
|
||||||
public void setUp() throws Exception {
|
void setUp() throws Exception {
|
||||||
String osname = System.getProperty("os.name");
|
String osname = System.getProperty("os.name");
|
||||||
if (osname != null && osname.toLowerCase().contains("win")) {
|
if (osname != null && osname.toLowerCase().contains("win")) {
|
||||||
isWindows = true;
|
isWindows = true;
|
||||||
|
@ -48,9 +49,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGetFile1() throws Exception {
|
@DisplayName("Test Get File 1")
|
||||||
|
void testGetFile1() throws Exception {
|
||||||
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
|
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
|
||||||
|
|
||||||
assertTrue(intFile.exists());
|
assertTrue(intFile.exists());
|
||||||
if (isWindows) {
|
if (isWindows) {
|
||||||
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
|
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
|
||||||
|
@ -60,9 +61,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGetFileSlash1() throws Exception {
|
@DisplayName("Test Get File Slash 1")
|
||||||
|
void testGetFileSlash1() throws Exception {
|
||||||
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
|
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
|
||||||
|
|
||||||
assertTrue(intFile.exists());
|
assertTrue(intFile.exists());
|
||||||
if (isWindows) {
|
if (isWindows) {
|
||||||
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
|
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
|
||||||
|
@ -72,11 +73,10 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGetFileWithSpace1() throws Exception {
|
@DisplayName("Test Get File With Space 1")
|
||||||
|
void testGetFileWithSpace1() throws Exception {
|
||||||
File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile();
|
File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile();
|
||||||
|
|
||||||
assertTrue(intFile.exists());
|
assertTrue(intFile.exists());
|
||||||
|
|
||||||
if (isWindows) {
|
if (isWindows) {
|
||||||
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
|
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
|
||||||
} else {
|
} else {
|
||||||
|
@ -85,16 +85,15 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testInputStream() throws Exception {
|
@DisplayName("Test Input Stream")
|
||||||
|
void testInputStream() throws Exception {
|
||||||
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
|
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
|
||||||
File intFile = resource.getFile();
|
File intFile = resource.getFile();
|
||||||
|
|
||||||
if (isWindows) {
|
if (isWindows) {
|
||||||
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
|
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
|
||||||
} else {
|
} else {
|
||||||
assertEquals(60, intFile.length());
|
assertEquals(60, intFile.length());
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream stream = resource.getInputStream();
|
InputStream stream = resource.getInputStream();
|
||||||
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
||||||
String line = "";
|
String line = "";
|
||||||
|
@ -102,21 +101,19 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
||||||
while ((line = reader.readLine()) != null) {
|
while ((line = reader.readLine()) != null) {
|
||||||
cnt++;
|
cnt++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(5, cnt);
|
assertEquals(5, cnt);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testInputStreamSlash() throws Exception {
|
@DisplayName("Test Input Stream Slash")
|
||||||
|
void testInputStreamSlash() throws Exception {
|
||||||
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
|
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
|
||||||
File intFile = resource.getFile();
|
File intFile = resource.getFile();
|
||||||
|
|
||||||
if (isWindows) {
|
if (isWindows) {
|
||||||
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
|
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
|
||||||
} else {
|
} else {
|
||||||
assertEquals(60, intFile.length());
|
assertEquals(60, intFile.length());
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream stream = resource.getInputStream();
|
InputStream stream = resource.getInputStream();
|
||||||
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
|
||||||
String line = "";
|
String line = "";
|
||||||
|
@ -124,7 +121,6 @@ public class ClassPathResourceTest extends BaseND4JTest {
|
||||||
while ((line = reader.readLine()) != null) {
|
while ((line = reader.readLine()) != null) {
|
||||||
cnt++;
|
cnt++;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(5, cnt);
|
assertEquals(5, cnt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,44 +17,41 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.util;
|
package org.datavec.api.util;
|
||||||
|
|
||||||
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
|
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
@DisplayName("Time Series Utils Test")
|
||||||
|
class TimeSeriesUtilsTest extends BaseND4JTest {
|
||||||
public class TimeSeriesUtilsTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTimeSeriesCreation() {
|
@DisplayName("Test Time Series Creation")
|
||||||
|
void testTimeSeriesCreation() {
|
||||||
List<List<List<Writable>>> test = new ArrayList<>();
|
List<List<List<Writable>>> test = new ArrayList<>();
|
||||||
List<List<Writable>> timeStep = new ArrayList<>();
|
List<List<Writable>> timeStep = new ArrayList<>();
|
||||||
for(int i = 0; i < 5; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
timeStep.add(getRecord(5));
|
timeStep.add(getRecord(5));
|
||||||
}
|
}
|
||||||
|
|
||||||
test.add(timeStep);
|
test.add(timeStep);
|
||||||
|
|
||||||
INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst();
|
INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst();
|
||||||
assertArrayEquals(new long[]{1,5,5},arr.shape());
|
assertArrayEquals(new long[] { 1, 5, 5 }, arr.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Writable> getRecord(int length) {
|
private List<Writable> getRecord(int length) {
|
||||||
List<Writable> ret = new ArrayList<>();
|
List<Writable> ret = new ArrayList<>();
|
||||||
for(int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
ret.add(new DoubleWritable(1.0));
|
ret.add(new DoubleWritable(1.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,52 +17,50 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.writable;
|
package org.datavec.api.writable;
|
||||||
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.shade.guava.collect.Lists;
|
import org.nd4j.shade.guava.collect.Lists;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.util.ndarray.RecordConverter;
|
import org.datavec.api.util.ndarray.RecordConverter;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.TimeZone;
|
import java.util.TimeZone;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Record Converter Test")
|
||||||
|
class RecordConverterTest extends BaseND4JTest {
|
||||||
|
|
||||||
public class RecordConverterTest extends BaseND4JTest {
|
|
||||||
@Test
|
@Test
|
||||||
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
|
@DisplayName("To Records _ Pass In Classification Data Set _ Expect ND Array And Int Writables")
|
||||||
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
|
void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
|
||||||
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT);
|
INDArray feature1 = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||||
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT);
|
INDArray feature2 = Nd4j.create(new double[] { 11, .7, -1.3, 4 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||||
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT);
|
INDArray label1 = Nd4j.create(new double[] { 0, 0, 1, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||||
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)),
|
INDArray label2 = Nd4j.create(new double[] { 0, 1, 0, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||||
Nd4j.vstack(Lists.newArrayList(label1, label2)));
|
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), Nd4j.vstack(Lists.newArrayList(label1, label2)));
|
||||||
|
|
||||||
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
|
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
|
||||||
|
|
||||||
assertEquals(2, writableList.size());
|
assertEquals(2, writableList.size());
|
||||||
testClassificationWritables(feature1, 2, writableList.get(0));
|
testClassificationWritables(feature1, 2, writableList.get(0));
|
||||||
testClassificationWritables(feature2, 1, writableList.get(1));
|
testClassificationWritables(feature2, 1, writableList.get(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
|
@DisplayName("To Records _ Pass In Regression Data Set _ Expect ND Array And Double Writables")
|
||||||
INDArray feature = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
|
void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
|
||||||
INDArray label = Nd4j.create(new double[]{.5, 2, 3, .5}, new long[]{1, 4}, DataType.FLOAT);
|
INDArray feature = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||||
|
INDArray label = Nd4j.create(new double[] { .5, 2, 3, .5 }, new long[] { 1, 4 }, DataType.FLOAT);
|
||||||
DataSet dataSet = new DataSet(feature, label);
|
DataSet dataSet = new DataSet(feature, label);
|
||||||
|
|
||||||
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
|
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
|
||||||
List<Writable> results = writableList.get(0);
|
List<Writable> results = writableList.get(0);
|
||||||
NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0);
|
NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0);
|
||||||
|
|
||||||
assertEquals(1, writableList.size());
|
assertEquals(1, writableList.size());
|
||||||
assertEquals(5, results.size());
|
assertEquals(5, results.size());
|
||||||
assertEquals(feature, ndArrayWritable.get());
|
assertEquals(feature, ndArrayWritable.get());
|
||||||
|
@ -72,62 +70,39 @@ public class RecordConverterTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex,
|
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, List<Writable> writables) {
|
||||||
List<Writable> writables) {
|
|
||||||
NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0);
|
NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0);
|
||||||
IntWritable intWritable = (IntWritable) writables.get(1);
|
IntWritable intWritable = (IntWritable) writables.get(1);
|
||||||
|
|
||||||
assertEquals(2, writables.size());
|
assertEquals(2, writables.size());
|
||||||
assertEquals(expectedFeatureVector, ndArrayWritable.get());
|
assertEquals(expectedFeatureVector, ndArrayWritable.get());
|
||||||
assertEquals(expectLabelIndex, intWritable.get());
|
assertEquals(expectLabelIndex, intWritable.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayWritableConcat() {
|
@DisplayName("Test ND Array Writable Concat")
|
||||||
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1),
|
void testNDArrayWritableConcat() {
|
||||||
new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1, 3}, DataType.FLOAT)), new DoubleWritable(5),
|
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 }, new long[] { 1, 3 }, DataType.FLOAT)), new IntWritable(9), new IntWritable(1));
|
||||||
new NDArrayWritable(Nd4j.create(new double[]{6, 7, 8}, new long[]{1, 3}, DataType.FLOAT)), new IntWritable(9),
|
INDArray exp = Nd4j.create(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1 }, new long[] { 1, 10 }, DataType.FLOAT);
|
||||||
new IntWritable(1));
|
|
||||||
|
|
||||||
INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT);
|
|
||||||
INDArray act = RecordConverter.toArray(DataType.FLOAT, l);
|
INDArray act = RecordConverter.toArray(DataType.FLOAT, l);
|
||||||
|
|
||||||
assertEquals(exp, act);
|
assertEquals(exp, act);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayWritableConcatToMatrix(){
|
@DisplayName("Test ND Array Writable Concat To Matrix")
|
||||||
|
void testNDArrayWritableConcatToMatrix() {
|
||||||
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5));
|
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5));
|
||||||
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[]{7, 8, 9}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(10));
|
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(10));
|
||||||
|
INDArray exp = Nd4j.create(new double[][] { { 1, 2, 3, 4, 5 }, { 6, 7, 8, 9, 10 } }).castTo(DataType.FLOAT);
|
||||||
INDArray exp = Nd4j.create(new double[][]{
|
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1, l2));
|
||||||
{1,2,3,4,5},
|
|
||||||
{6,7,8,9,10}}).castTo(DataType.FLOAT);
|
|
||||||
|
|
||||||
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2));
|
|
||||||
|
|
||||||
assertEquals(exp, act);
|
assertEquals(exp, act);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testToRecordWithListOfObject(){
|
@DisplayName("Test To Record With List Of Object")
|
||||||
final List<Object> list = Arrays.asList((Object)3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
|
void testToRecordWithListOfObject() {
|
||||||
final Schema schema = new Schema.Builder()
|
final List<Object> list = Arrays.asList((Object) 3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
|
||||||
.addColumnInteger("a")
|
final Schema schema = new Schema.Builder().addColumnInteger("a").addColumnFloat("b").addColumnString("c").addColumnCategorical("d", "Bar", "Baz").addColumnDouble("e").addColumnFloat("f").addColumnLong("g").addColumnInteger("h").addColumnTime("i", TimeZone.getDefault()).build();
|
||||||
.addColumnFloat("b")
|
|
||||||
.addColumnString("c")
|
|
||||||
.addColumnCategorical("d", "Bar", "Baz")
|
|
||||||
.addColumnDouble("e")
|
|
||||||
.addColumnFloat("f")
|
|
||||||
.addColumnLong("g")
|
|
||||||
.addColumnInteger("h")
|
|
||||||
.addColumnTime("i", TimeZone.getDefault())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
final List<Writable> record = RecordConverter.toRecord(schema, list);
|
final List<Writable> record = RecordConverter.toRecord(schema, list);
|
||||||
|
|
||||||
assertEquals(record.get(0).toInt(), 3);
|
assertEquals(record.get(0).toInt(), 3);
|
||||||
assertEquals(record.get(1).toFloat(), 7f, 1e-6);
|
assertEquals(record.get(1).toFloat(), 7f, 1e-6);
|
||||||
assertEquals(record.get(2).toString(), "Foo");
|
assertEquals(record.get(2).toString(), "Foo");
|
||||||
|
@ -137,7 +112,5 @@ public class RecordConverterTest extends BaseND4JTest {
|
||||||
assertEquals(record.get(6).toLong(), 3L);
|
assertEquals(record.get(6).toLong(), 3L);
|
||||||
assertEquals(record.get(7).toInt(), 7);
|
assertEquals(record.get(7).toInt(), 7);
|
||||||
assertEquals(record.get(8).toLong(), 0);
|
assertEquals(record.get(8).toLong(), 0);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,14 +21,14 @@
|
||||||
package org.datavec.api.writable;
|
package org.datavec.api.writable;
|
||||||
|
|
||||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
|
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -17,38 +17,38 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.writable;
|
package org.datavec.api.writable;
|
||||||
|
|
||||||
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.nio.Buffer;
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class WritableTest extends BaseND4JTest {
|
@DisplayName("Writable Test")
|
||||||
|
class WritableTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWritableEqualityReflexive() {
|
@DisplayName("Test Writable Equality Reflexive")
|
||||||
|
void testWritableEqualityReflexive() {
|
||||||
assertEquals(new IntWritable(1), new IntWritable(1));
|
assertEquals(new IntWritable(1), new IntWritable(1));
|
||||||
assertEquals(new LongWritable(1), new LongWritable(1));
|
assertEquals(new LongWritable(1), new LongWritable(1));
|
||||||
assertEquals(new DoubleWritable(1), new DoubleWritable(1));
|
assertEquals(new DoubleWritable(1), new DoubleWritable(1));
|
||||||
assertEquals(new FloatWritable(1), new FloatWritable(1));
|
assertEquals(new FloatWritable(1), new FloatWritable(1));
|
||||||
assertEquals(new Text("Hello"), new Text("Hello"));
|
assertEquals(new Text("Hello"), new Text("Hello"));
|
||||||
assertEquals(new BytesWritable("Hello".getBytes()),new BytesWritable("Hello".getBytes()));
|
assertEquals(new BytesWritable("Hello".getBytes()), new BytesWritable("Hello".getBytes()));
|
||||||
INDArray ndArray = Nd4j.rand(new int[]{1, 100});
|
INDArray ndArray = Nd4j.rand(new int[] { 1, 100 });
|
||||||
|
|
||||||
assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray));
|
assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray));
|
||||||
assertEquals(new NullWritable(), new NullWritable());
|
assertEquals(new NullWritable(), new NullWritable());
|
||||||
assertEquals(new BooleanWritable(true), new BooleanWritable(true));
|
assertEquals(new BooleanWritable(true), new BooleanWritable(true));
|
||||||
|
@ -56,9 +56,9 @@ public class WritableTest extends BaseND4JTest {
|
||||||
assertEquals(new ByteWritable(b), new ByteWritable(b));
|
assertEquals(new ByteWritable(b), new ByteWritable(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBytesWritableIndexing() {
|
@DisplayName("Test Bytes Writable Indexing")
|
||||||
|
void testBytesWritableIndexing() {
|
||||||
byte[] doubleWrite = new byte[16];
|
byte[] doubleWrite = new byte[16];
|
||||||
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
|
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
|
||||||
Buffer buffer = (Buffer) wrapped;
|
Buffer buffer = (Buffer) wrapped;
|
||||||
|
@ -66,53 +66,51 @@ public class WritableTest extends BaseND4JTest {
|
||||||
wrapped.putDouble(2.0);
|
wrapped.putDouble(2.0);
|
||||||
buffer.rewind();
|
buffer.rewind();
|
||||||
BytesWritable byteWritable = new BytesWritable(doubleWrite);
|
BytesWritable byteWritable = new BytesWritable(doubleWrite);
|
||||||
assertEquals(2,byteWritable.getDouble(1),1e-1);
|
assertEquals(2, byteWritable.getDouble(1), 1e-1);
|
||||||
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2});
|
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] { 1, 2 });
|
||||||
double[] d1 = dataBuffer.asDouble();
|
double[] d1 = dataBuffer.asDouble();
|
||||||
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE,8).asDouble();
|
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE, 8).asDouble();
|
||||||
assertArrayEquals(d1, d2, 0.0);
|
assertArrayEquals(d1, d2, 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testByteWritable() {
|
@DisplayName("Test Byte Writable")
|
||||||
|
void testByteWritable() {
|
||||||
byte b = 0xfffffffe;
|
byte b = 0xfffffffe;
|
||||||
assertEquals(new IntWritable(-2), new ByteWritable(b));
|
assertEquals(new IntWritable(-2), new ByteWritable(b));
|
||||||
assertEquals(new LongWritable(-2), new ByteWritable(b));
|
assertEquals(new LongWritable(-2), new ByteWritable(b));
|
||||||
assertEquals(new ByteWritable(b), new IntWritable(-2));
|
assertEquals(new ByteWritable(b), new IntWritable(-2));
|
||||||
assertEquals(new ByteWritable(b), new LongWritable(-2));
|
assertEquals(new ByteWritable(b), new LongWritable(-2));
|
||||||
|
|
||||||
// those would cast to the same Int
|
// those would cast to the same Int
|
||||||
byte minus126 = 0xffffff82;
|
byte minus126 = 0xffffff82;
|
||||||
assertNotEquals(new ByteWritable(minus126), new IntWritable(130));
|
assertNotEquals(new ByteWritable(minus126), new IntWritable(130));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIntLongWritable() {
|
@DisplayName("Test Int Long Writable")
|
||||||
|
void testIntLongWritable() {
|
||||||
assertEquals(new IntWritable(1), new LongWritable(1l));
|
assertEquals(new IntWritable(1), new LongWritable(1l));
|
||||||
assertEquals(new LongWritable(2l), new IntWritable(2));
|
assertEquals(new LongWritable(2l), new IntWritable(2));
|
||||||
|
|
||||||
long l = 1L << 34;
|
long l = 1L << 34;
|
||||||
// those would cast to the same Int
|
// those would cast to the same Int
|
||||||
assertNotEquals(new LongWritable(l), new IntWritable(4));
|
assertNotEquals(new LongWritable(l), new IntWritable(4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDoubleFloatWritable() {
|
@DisplayName("Test Double Float Writable")
|
||||||
|
void testDoubleFloatWritable() {
|
||||||
assertEquals(new DoubleWritable(1d), new FloatWritable(1f));
|
assertEquals(new DoubleWritable(1d), new FloatWritable(1f));
|
||||||
assertEquals(new FloatWritable(2f), new DoubleWritable(2d));
|
assertEquals(new FloatWritable(2f), new DoubleWritable(2d));
|
||||||
|
|
||||||
// we defer to Java equality for Floats
|
// we defer to Java equality for Floats
|
||||||
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f));
|
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f));
|
||||||
// same idea as above
|
// same idea as above
|
||||||
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float)1.1d));
|
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float) 1.1d));
|
||||||
|
assertNotEquals(new DoubleWritable((double) Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
|
||||||
assertNotEquals(new DoubleWritable((double)Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFuzzies() {
|
@DisplayName("Test Fuzzies")
|
||||||
|
void testFuzzies() {
|
||||||
assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d));
|
assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d));
|
||||||
assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d));
|
assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d));
|
||||||
byte b = 0xfffffffe;
|
byte b = 0xfffffffe;
|
||||||
|
@ -122,62 +120,57 @@ public class WritableTest extends BaseND4JTest {
|
||||||
assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d));
|
assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNDArrayRecordBatch(){
|
@DisplayName("Test ND Array Record Batch")
|
||||||
|
void testNDArrayRecordBatch() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
// Outer list over writables/columns, inner list over examples
|
||||||
List<List<INDArray>> orig = new ArrayList<>(); //Outer list over writables/columns, inner list over examples
|
List<List<INDArray>> orig = new ArrayList<>();
|
||||||
for( int i=0; i<3; i++ ){
|
for (int i = 0; i < 3; i++) {
|
||||||
orig.add(new ArrayList<INDArray>());
|
orig.add(new ArrayList<INDArray>());
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
for( int i=0; i<5; i++ ){
|
orig.get(0).add(Nd4j.rand(1, 10));
|
||||||
orig.get(0).add(Nd4j.rand(1,10));
|
orig.get(1).add(Nd4j.rand(new int[] { 1, 5, 6 }));
|
||||||
orig.get(1).add(Nd4j.rand(new int[]{1,5,6}));
|
orig.get(2).add(Nd4j.rand(new int[] { 1, 3, 4, 5 }));
|
||||||
orig.get(2).add(Nd4j.rand(new int[]{1,3,4,5}));
|
|
||||||
}
|
}
|
||||||
|
// Outer list over examples, inner list over writables
|
||||||
List<List<INDArray>> origByExample = new ArrayList<>(); //Outer list over examples, inner list over writables
|
List<List<INDArray>> origByExample = new ArrayList<>();
|
||||||
for( int i=0; i<5; i++ ){
|
for (int i = 0; i < 5; i++) {
|
||||||
origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i)));
|
origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<INDArray> batched = new ArrayList<>();
|
List<INDArray> batched = new ArrayList<>();
|
||||||
for(List<INDArray> l : orig){
|
for (List<INDArray> l : orig) {
|
||||||
batched.add(Nd4j.concat(0, l.toArray(new INDArray[5])));
|
batched.add(Nd4j.concat(0, l.toArray(new INDArray[5])));
|
||||||
}
|
}
|
||||||
|
|
||||||
NDArrayRecordBatch batch = new NDArrayRecordBatch(batched);
|
NDArrayRecordBatch batch = new NDArrayRecordBatch(batched);
|
||||||
assertEquals(5, batch.size());
|
assertEquals(5, batch.size());
|
||||||
for( int i=0; i<5; i++ ){
|
for (int i = 0; i < 5; i++) {
|
||||||
List<Writable> act = batch.get(i);
|
List<Writable> act = batch.get(i);
|
||||||
List<INDArray> unboxed = new ArrayList<>();
|
List<INDArray> unboxed = new ArrayList<>();
|
||||||
for(Writable w : act){
|
for (Writable w : act) {
|
||||||
unboxed.add(((NDArrayWritable)w).get());
|
unboxed.add(((NDArrayWritable) w).get());
|
||||||
}
|
}
|
||||||
List<INDArray> exp = origByExample.get(i);
|
List<INDArray> exp = origByExample.get(i);
|
||||||
assertEquals(exp.size(), unboxed.size());
|
assertEquals(exp.size(), unboxed.size());
|
||||||
for( int j=0; j<exp.size(); j++ ){
|
for (int j = 0; j < exp.size(); j++) {
|
||||||
assertEquals(exp.get(j), unboxed.get(j));
|
assertEquals(exp.get(j), unboxed.get(j));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Iterator<List<Writable>> iter = batch.iterator();
|
Iterator<List<Writable>> iter = batch.iterator();
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while(iter.hasNext()){
|
while (iter.hasNext()) {
|
||||||
List<Writable> next = iter.next();
|
List<Writable> next = iter.next();
|
||||||
List<INDArray> unboxed = new ArrayList<>();
|
List<INDArray> unboxed = new ArrayList<>();
|
||||||
for(Writable w : next){
|
for (Writable w : next) {
|
||||||
unboxed.add(((NDArrayWritable)w).get());
|
unboxed.add(((NDArrayWritable) w).get());
|
||||||
}
|
}
|
||||||
List<INDArray> exp = origByExample.get(count++);
|
List<INDArray> exp = origByExample.get(count++);
|
||||||
assertEquals(exp.size(), unboxed.size());
|
assertEquals(exp.size(), unboxed.size());
|
||||||
for( int j=0; j<exp.size(); j++ ){
|
for (int j = 0; j < exp.size(); j++) {
|
||||||
assertEquals(exp.get(j), unboxed.get(j));
|
assertEquals(exp.get(j), unboxed.get(j));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(5, count);
|
assertEquals(5, count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,10 +60,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.arrow;
|
package org.datavec.arrow;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -42,461 +41,397 @@ import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
||||||
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
|
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.io.ByteArrayOutputStream;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.FileOutputStream;
|
import java.io.FileOutputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static java.nio.channels.Channels.newChannel;
|
import static java.nio.channels.Channels.newChannel;
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import static org.junit.Assert.assertFalse;
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ArrowConverterTest extends BaseND4JTest {
|
@DisplayName("Arrow Converter Test")
|
||||||
|
class ArrowConverterTest extends BaseND4JTest {
|
||||||
|
|
||||||
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
||||||
|
|
||||||
@Rule
|
@TempDir
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public Path testDir;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testToArrayFromINDArray() {
|
@DisplayName("Test To Array From IND Array")
|
||||||
|
void testToArrayFromINDArray() {
|
||||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||||
schemaBuilder.addColumnNDArray("outputArray",new long[]{1,4});
|
schemaBuilder.addColumnNDArray("outputArray", new long[] { 1, 4 });
|
||||||
Schema schema = schemaBuilder.build();
|
Schema schema = schemaBuilder.build();
|
||||||
int numRows = 4;
|
int numRows = 4;
|
||||||
List<List<Writable>> ret = new ArrayList<>(numRows);
|
List<List<Writable>> ret = new ArrayList<>(numRows);
|
||||||
for(int i = 0; i < numRows; i++) {
|
for (int i = 0; i < numRows; i++) {
|
||||||
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1,4,4).reshape(1, 4))));
|
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4))));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret);
|
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret);
|
||||||
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors,schema);
|
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors, schema);
|
||||||
INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch);
|
INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch);
|
||||||
assertArrayEquals(new long[]{4,4},array.shape());
|
assertArrayEquals(new long[] { 4, 4 }, array.shape());
|
||||||
|
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1, 4, 4), 4).reshape(4, 4);
|
||||||
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1,4,4),4).reshape(4,4);
|
assertEquals(assertion, array);
|
||||||
assertEquals(assertion,array);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testArrowColumnINDArray() {
|
@DisplayName("Test Arrow Column IND Array")
|
||||||
|
void testArrowColumnINDArray() {
|
||||||
Schema.Builder schema = new Schema.Builder();
|
Schema.Builder schema = new Schema.Builder();
|
||||||
List<String> single = new ArrayList<>();
|
List<String> single = new ArrayList<>();
|
||||||
int numCols = 2;
|
int numCols = 2;
|
||||||
INDArray arr = Nd4j.linspace(1,4,4);
|
INDArray arr = Nd4j.linspace(1, 4, 4);
|
||||||
for(int i = 0; i < numCols; i++) {
|
for (int i = 0; i < numCols; i++) {
|
||||||
schema.addColumnNDArray(String.valueOf(i),new long[]{1,4});
|
schema.addColumnNDArray(String.valueOf(i), new long[] { 1, 4 });
|
||||||
single.add(String.valueOf(i));
|
single.add(String.valueOf(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
Schema buildSchema = schema.build();
|
Schema buildSchema = schema.build();
|
||||||
List<List<Writable>> list = new ArrayList<>();
|
List<List<Writable>> list = new ArrayList<>();
|
||||||
List<Writable> firstRow = new ArrayList<>();
|
List<Writable> firstRow = new ArrayList<>();
|
||||||
for(int i = 0 ; i < numCols; i++) {
|
for (int i = 0; i < numCols; i++) {
|
||||||
firstRow.add(new NDArrayWritable(arr));
|
firstRow.add(new NDArrayWritable(arr));
|
||||||
}
|
}
|
||||||
|
|
||||||
list.add(firstRow);
|
list.add(firstRow);
|
||||||
|
|
||||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list);
|
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list);
|
||||||
assertEquals(numCols,fieldVectors.size());
|
assertEquals(numCols, fieldVectors.size());
|
||||||
assertEquals(1,fieldVectors.get(0).getValueCount());
|
assertEquals(1, fieldVectors.get(0).getValueCount());
|
||||||
assertFalse(fieldVectors.get(0).isNull(0));
|
assertFalse(fieldVectors.get(0).isNull(0));
|
||||||
|
|
||||||
ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema);
|
ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema);
|
||||||
assertEquals(1,arrowWritableRecordBatch.size());
|
assertEquals(1, arrowWritableRecordBatch.size());
|
||||||
|
|
||||||
Writable writable = arrowWritableRecordBatch.get(0).get(0);
|
Writable writable = arrowWritableRecordBatch.get(0).get(0);
|
||||||
assertTrue(writable instanceof NDArrayWritable);
|
assertTrue(writable instanceof NDArrayWritable);
|
||||||
NDArrayWritable ndArrayWritable = (NDArrayWritable) writable;
|
NDArrayWritable ndArrayWritable = (NDArrayWritable) writable;
|
||||||
assertEquals(arr,ndArrayWritable.get());
|
assertEquals(arr, ndArrayWritable.get());
|
||||||
|
|
||||||
Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray);
|
Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray);
|
||||||
NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1;
|
NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1;
|
||||||
System.out.println(ndArrayWritablewritable1.get());
|
System.out.println(ndArrayWritablewritable1.get());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testArrowColumnString() {
|
@DisplayName("Test Arrow Column String")
|
||||||
|
void testArrowColumnString() {
|
||||||
Schema.Builder schema = new Schema.Builder();
|
Schema.Builder schema = new Schema.Builder();
|
||||||
List<String> single = new ArrayList<>();
|
List<String> single = new ArrayList<>();
|
||||||
for(int i = 0; i < 2; i++) {
|
for (int i = 0; i < 2; i++) {
|
||||||
schema.addColumnInteger(String.valueOf(i));
|
schema.addColumnInteger(String.valueOf(i));
|
||||||
single.add(String.valueOf(i));
|
single.add(String.valueOf(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single);
|
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single);
|
||||||
List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
|
List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
|
||||||
List<List<Writable>> assertion = new ArrayList<>();
|
List<List<Writable>> assertion = new ArrayList<>();
|
||||||
assertion.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)));
|
assertion.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)));
|
||||||
assertEquals(assertion,records);
|
assertEquals(assertion, records);
|
||||||
|
|
||||||
List<List<String>> batch = new ArrayList<>();
|
List<List<String>> batch = new ArrayList<>();
|
||||||
for(int i = 0; i < 2; i++) {
|
for (int i = 0; i < 2; i++) {
|
||||||
batch.add(Arrays.asList(String.valueOf(i),String.valueOf(i)));
|
batch.add(Arrays.asList(String.valueOf(i), String.valueOf(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
|
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
|
||||||
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
|
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
|
||||||
|
|
||||||
List<List<Writable>> assertionBatch = new ArrayList<>();
|
List<List<Writable>> assertionBatch = new ArrayList<>();
|
||||||
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(0)));
|
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0)));
|
||||||
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1),new IntWritable(1)));
|
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1)));
|
||||||
assertEquals(assertionBatch,batchRecords);
|
assertEquals(assertionBatch, batchRecords);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testArrowBatchSetTime() {
|
@DisplayName("Test Arrow Batch Set Time")
|
||||||
|
void testArrowBatchSetTime() {
|
||||||
Schema.Builder schema = new Schema.Builder();
|
Schema.Builder schema = new Schema.Builder();
|
||||||
List<String> single = new ArrayList<>();
|
List<String> single = new ArrayList<>();
|
||||||
for(int i = 0; i < 2; i++) {
|
for (int i = 0; i < 2; i++) {
|
||||||
schema.addColumnTime(String.valueOf(i),TimeZone.getDefault());
|
schema.addColumnTime(String.valueOf(i), TimeZone.getDefault());
|
||||||
single.add(String.valueOf(i));
|
single.add(String.valueOf(i));
|
||||||
}
|
}
|
||||||
|
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
|
||||||
List<List<Writable>> input = Arrays.asList(
|
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
|
||||||
Arrays.<Writable>asList(new LongWritable(0),new LongWritable(1)),
|
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
|
||||||
Arrays.<Writable>asList(new LongWritable(2),new LongWritable(3))
|
|
||||||
);
|
|
||||||
|
|
||||||
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
|
|
||||||
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
|
|
||||||
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5));
|
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5));
|
||||||
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4),new LongWritable(5)));
|
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)));
|
||||||
List<Writable> recordTest = writableRecordBatch.get(1);
|
List<Writable> recordTest = writableRecordBatch.get(1);
|
||||||
assertEquals(assertion,recordTest);
|
assertEquals(assertion, recordTest);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testArrowBatchSet() {
|
@DisplayName("Test Arrow Batch Set")
|
||||||
|
void testArrowBatchSet() {
|
||||||
Schema.Builder schema = new Schema.Builder();
|
Schema.Builder schema = new Schema.Builder();
|
||||||
List<String> single = new ArrayList<>();
|
List<String> single = new ArrayList<>();
|
||||||
for(int i = 0; i < 2; i++) {
|
for (int i = 0; i < 2; i++) {
|
||||||
schema.addColumnInteger(String.valueOf(i));
|
schema.addColumnInteger(String.valueOf(i));
|
||||||
single.add(String.valueOf(i));
|
single.add(String.valueOf(i));
|
||||||
}
|
}
|
||||||
|
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
|
||||||
List<List<Writable>> input = Arrays.asList(
|
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
|
||||||
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)),
|
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
|
||||||
Arrays.<Writable>asList(new IntWritable(2),new IntWritable(3))
|
|
||||||
);
|
|
||||||
|
|
||||||
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
|
|
||||||
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
|
|
||||||
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
|
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
|
||||||
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5)));
|
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)));
|
||||||
List<Writable> recordTest = writableRecordBatch.get(1);
|
List<Writable> recordTest = writableRecordBatch.get(1);
|
||||||
assertEquals(assertion,recordTest);
|
assertEquals(assertion, recordTest);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testArrowColumnsStringTimeSeries() {
|
@DisplayName("Test Arrow Columns String Time Series")
|
||||||
|
void testArrowColumnsStringTimeSeries() {
|
||||||
Schema.Builder schema = new Schema.Builder();
|
Schema.Builder schema = new Schema.Builder();
|
||||||
List<List<List<String>>> entries = new ArrayList<>();
|
List<List<List<String>>> entries = new ArrayList<>();
|
||||||
for(int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
schema.addColumnInteger(String.valueOf(i));
|
schema.addColumnInteger(String.valueOf(i));
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
for(int i = 0; i < 5; i++) {
|
|
||||||
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
|
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
|
||||||
entries.add(arr);
|
entries.add(arr);
|
||||||
}
|
}
|
||||||
|
|
||||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
|
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
|
||||||
assertEquals(3,fieldVectors.size());
|
assertEquals(3, fieldVectors.size());
|
||||||
assertEquals(5,fieldVectors.get(0).getValueCount());
|
assertEquals(5, fieldVectors.get(0).getValueCount());
|
||||||
|
|
||||||
|
|
||||||
INDArray exp = Nd4j.create(5, 3);
|
INDArray exp = Nd4j.create(5, 3);
|
||||||
for( int i = 0; i < 5; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
exp.getRow(i).assign(i);
|
exp.getRow(i).assign(i);
|
||||||
}
|
}
|
||||||
//Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
|
// Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
|
||||||
ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
|
ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
|
||||||
INDArray arr = ArrowConverter.toArray(wri);
|
INDArray arr = ArrowConverter.toArray(wri);
|
||||||
assertArrayEquals(new long[] {5,3}, arr.shape());
|
assertArrayEquals(new long[] { 5, 3 }, arr.shape());
|
||||||
|
|
||||||
|
|
||||||
assertEquals(exp, arr);
|
assertEquals(exp, arr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConvertVector() {
|
@DisplayName("Test Convert Vector")
|
||||||
|
void testConvertVector() {
|
||||||
Schema.Builder schema = new Schema.Builder();
|
Schema.Builder schema = new Schema.Builder();
|
||||||
List<List<List<String>>> entries = new ArrayList<>();
|
List<List<List<String>>> entries = new ArrayList<>();
|
||||||
for(int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
schema.addColumnInteger(String.valueOf(i));
|
schema.addColumnInteger(String.valueOf(i));
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
for(int i = 0; i < 5; i++) {
|
|
||||||
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
|
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
|
||||||
entries.add(arr);
|
entries.add(arr);
|
||||||
}
|
}
|
||||||
|
|
||||||
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
|
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
|
||||||
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0),schema.build().getType(0));
|
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0), schema.build().getType(0));
|
||||||
assertEquals(5,arr.length());
|
assertEquals(5, arr.length());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCreateNDArray() throws Exception {
|
@DisplayName("Test Create ND Array")
|
||||||
|
void testCreateNDArray() throws Exception {
|
||||||
val recordsToWrite = recordToWrite();
|
val recordsToWrite = recordToWrite();
|
||||||
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
||||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
|
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
|
||||||
|
File f = testDir.toFile();
|
||||||
File f = testDir.newFolder();
|
|
||||||
|
|
||||||
File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrorw");
|
File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrorw");
|
||||||
FileOutputStream outputStream = new FileOutputStream(tmpFile);
|
FileOutputStream outputStream = new FileOutputStream(tmpFile);
|
||||||
tmpFile.deleteOnExit();
|
tmpFile.deleteOnExit();
|
||||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),outputStream);
|
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), outputStream);
|
||||||
outputStream.flush();
|
outputStream.flush();
|
||||||
outputStream.close();
|
outputStream.close();
|
||||||
|
|
||||||
Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile);
|
Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile);
|
||||||
assertEquals(recordsToWrite.getFirst(),schemaArrowWritableRecordBatchPair.getFirst());
|
assertEquals(recordsToWrite.getFirst(), schemaArrowWritableRecordBatchPair.getFirst());
|
||||||
assertEquals(recordsToWrite.getRight(),schemaArrowWritableRecordBatchPair.getRight().toArrayList());
|
assertEquals(recordsToWrite.getRight(), schemaArrowWritableRecordBatchPair.getRight().toArrayList());
|
||||||
|
|
||||||
byte[] arr = byteArrayOutputStream.toByteArray();
|
byte[] arr = byteArrayOutputStream.toByteArray();
|
||||||
val read = ArrowConverter.readFromBytes(arr);
|
val read = ArrowConverter.readFromBytes(arr);
|
||||||
assertEquals(recordsToWrite,read);
|
assertEquals(recordsToWrite, read);
|
||||||
|
// send file
|
||||||
//send file
|
File tmp = tmpDataFile(recordsToWrite);
|
||||||
File tmp = tmpDataFile(recordsToWrite);
|
|
||||||
ArrowRecordReader recordReader = new ArrowRecordReader();
|
ArrowRecordReader recordReader = new ArrowRecordReader();
|
||||||
|
|
||||||
recordReader.initialize(new FileSplit(tmp));
|
recordReader.initialize(new FileSplit(tmp));
|
||||||
|
|
||||||
recordReader.next();
|
recordReader.next();
|
||||||
ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch();
|
ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch();
|
||||||
INDArray arr2 = ArrowConverter.toArray(currentBatch);
|
INDArray arr2 = ArrowConverter.toArray(currentBatch);
|
||||||
assertEquals(2,arr2.rows());
|
assertEquals(2, arr2.rows());
|
||||||
assertEquals(2,arr2.columns());
|
assertEquals(2, arr2.columns());
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testConvertToArrowVectors() {
|
|
||||||
INDArray matrix = Nd4j.linspace(1,4,4).reshape(2,2);
|
|
||||||
val vectors = ArrowConverter.convertToArrowVector(matrix,Arrays.asList("test","test2"), ColumnType.Double,bufferAllocator);
|
|
||||||
assertEquals(matrix.rows(),vectors.size());
|
|
||||||
|
|
||||||
INDArray vector = Nd4j.linspace(1,4,4);
|
|
||||||
val vectors2 = ArrowConverter.convertToArrowVector(vector,Arrays.asList("test"), ColumnType.Double,bufferAllocator);
|
|
||||||
assertEquals(1,vectors2.size());
|
|
||||||
assertEquals(matrix.length(),vectors2.get(0).getValueCount());
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSchemaConversionBasic() {
|
@DisplayName("Test Convert To Arrow Vectors")
|
||||||
|
void testConvertToArrowVectors() {
|
||||||
|
INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
|
||||||
|
val vectors = ArrowConverter.convertToArrowVector(matrix, Arrays.asList("test", "test2"), ColumnType.Double, bufferAllocator);
|
||||||
|
assertEquals(matrix.rows(), vectors.size());
|
||||||
|
INDArray vector = Nd4j.linspace(1, 4, 4);
|
||||||
|
val vectors2 = ArrowConverter.convertToArrowVector(vector, Arrays.asList("test"), ColumnType.Double, bufferAllocator);
|
||||||
|
assertEquals(1, vectors2.size());
|
||||||
|
assertEquals(matrix.length(), vectors2.get(0).getValueCount());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("Test Schema Conversion Basic")
|
||||||
|
void testSchemaConversionBasic() {
|
||||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||||
for(int i = 0; i < 2; i++) {
|
for (int i = 0; i < 2; i++) {
|
||||||
schemaBuilder.addColumnDouble("test-" + i);
|
schemaBuilder.addColumnDouble("test-" + i);
|
||||||
schemaBuilder.addColumnInteger("testi-" + i);
|
schemaBuilder.addColumnInteger("testi-" + i);
|
||||||
schemaBuilder.addColumnLong("testl-" + i);
|
schemaBuilder.addColumnLong("testl-" + i);
|
||||||
schemaBuilder.addColumnFloat("testf-" + i);
|
schemaBuilder.addColumnFloat("testf-" + i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Schema schema = schemaBuilder.build();
|
Schema schema = schemaBuilder.build();
|
||||||
val schema2 = ArrowConverter.toArrowSchema(schema);
|
val schema2 = ArrowConverter.toArrowSchema(schema);
|
||||||
assertEquals(8,schema2.getFields().size());
|
assertEquals(8, schema2.getFields().size());
|
||||||
val convertedSchema = ArrowConverter.toDatavecSchema(schema2);
|
val convertedSchema = ArrowConverter.toDatavecSchema(schema2);
|
||||||
assertEquals(schema,convertedSchema);
|
assertEquals(schema, convertedSchema);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReadSchemaAndRecordsFromByteArray() throws Exception {
|
@DisplayName("Test Read Schema And Records From Byte Array")
|
||||||
|
void testReadSchemaAndRecordsFromByteArray() throws Exception {
|
||||||
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
|
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
|
||||||
|
|
||||||
int valueCount = 3;
|
int valueCount = 3;
|
||||||
List<Field> fields = new ArrayList<>();
|
List<Field> fields = new ArrayList<>();
|
||||||
fields.add(ArrowConverter.field("field1",new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
|
fields.add(ArrowConverter.field("field1", new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
|
||||||
fields.add(ArrowConverter.intField("field2"));
|
fields.add(ArrowConverter.intField("field2"));
|
||||||
|
|
||||||
List<FieldVector> fieldVectors = new ArrayList<>();
|
List<FieldVector> fieldVectors = new ArrayList<>();
|
||||||
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field1",new float[] {1,2,3}));
|
fieldVectors.add(ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 }));
|
||||||
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field2",new int[] {1,2,3}));
|
fieldVectors.add(ArrowConverter.vectorFor(allocator, "field2", new int[] { 1, 2, 3 }));
|
||||||
|
|
||||||
|
|
||||||
org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields);
|
org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields);
|
||||||
|
|
||||||
VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount);
|
VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount);
|
||||||
VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
|
VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
|
||||||
vectorUnloader.getRecordBatch();
|
vectorUnloader.getRecordBatch();
|
||||||
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
||||||
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
|
try (ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1, null, newChannel(byteArrayOutputStream))) {
|
||||||
arrowFileWriter.writeBatch();
|
arrowFileWriter.writeBatch();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
log.error("",e);
|
log.error("", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
byte[] arr = byteArrayOutputStream.toByteArray();
|
byte[] arr = byteArrayOutputStream.toByteArray();
|
||||||
val arr2 = ArrowConverter.readFromBytes(arr);
|
val arr2 = ArrowConverter.readFromBytes(arr);
|
||||||
assertEquals(2,arr2.getFirst().numColumns());
|
assertEquals(2, arr2.getFirst().numColumns());
|
||||||
assertEquals(3,arr2.getRight().size());
|
assertEquals(3, arr2.getRight().size());
|
||||||
|
val arrowCols = ArrowConverter.toArrowColumns(allocator, arr2.getFirst(), arr2.getRight());
|
||||||
val arrowCols = ArrowConverter.toArrowColumns(allocator,arr2.getFirst(),arr2.getRight());
|
assertEquals(2, arrowCols.size());
|
||||||
assertEquals(2,arrowCols.size());
|
assertEquals(valueCount, arrowCols.get(0).getValueCount());
|
||||||
assertEquals(valueCount,arrowCols.get(0).getValueCount());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVectorForEdgeCases() {
|
@DisplayName("Test Vector For Edge Cases")
|
||||||
|
void testVectorForEdgeCases() {
|
||||||
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
|
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
|
||||||
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{Float.MIN_VALUE,Float.MAX_VALUE});
|
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { Float.MIN_VALUE, Float.MAX_VALUE });
|
||||||
assertEquals(Float.MIN_VALUE,vector.get(0),1e-2);
|
assertEquals(Float.MIN_VALUE, vector.get(0), 1e-2);
|
||||||
assertEquals(Float.MAX_VALUE,vector.get(1),1e-2);
|
assertEquals(Float.MAX_VALUE, vector.get(1), 1e-2);
|
||||||
|
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { Integer.MIN_VALUE, Integer.MAX_VALUE });
|
||||||
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{Integer.MIN_VALUE,Integer.MAX_VALUE});
|
assertEquals(Integer.MIN_VALUE, vectorInt.get(0), 1e-2);
|
||||||
assertEquals(Integer.MIN_VALUE,vectorInt.get(0),1e-2);
|
assertEquals(Integer.MAX_VALUE, vectorInt.get(1), 1e-2);
|
||||||
assertEquals(Integer.MAX_VALUE,vectorInt.get(1),1e-2);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVectorFor() {
|
@DisplayName("Test Vector For")
|
||||||
|
void testVectorFor() {
|
||||||
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
|
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
|
||||||
|
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 });
|
||||||
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{1,2,3});
|
assertEquals(3, vector.getValueCount());
|
||||||
assertEquals(3,vector.getValueCount());
|
assertEquals(1, vector.get(0), 1e-2);
|
||||||
assertEquals(1,vector.get(0),1e-2);
|
assertEquals(2, vector.get(1), 1e-2);
|
||||||
assertEquals(2,vector.get(1),1e-2);
|
assertEquals(3, vector.get(2), 1e-2);
|
||||||
assertEquals(3,vector.get(2),1e-2);
|
val vectorLong = ArrowConverter.vectorFor(allocator, "field1", new long[] { 1, 2, 3 });
|
||||||
|
assertEquals(3, vectorLong.getValueCount());
|
||||||
val vectorLong = ArrowConverter.vectorFor(allocator,"field1",new long[]{1,2,3});
|
assertEquals(1, vectorLong.get(0), 1e-2);
|
||||||
assertEquals(3,vectorLong.getValueCount());
|
assertEquals(2, vectorLong.get(1), 1e-2);
|
||||||
assertEquals(1,vectorLong.get(0),1e-2);
|
assertEquals(3, vectorLong.get(2), 1e-2);
|
||||||
assertEquals(2,vectorLong.get(1),1e-2);
|
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { 1, 2, 3 });
|
||||||
assertEquals(3,vectorLong.get(2),1e-2);
|
assertEquals(3, vectorInt.getValueCount());
|
||||||
|
assertEquals(1, vectorInt.get(0), 1e-2);
|
||||||
|
assertEquals(2, vectorInt.get(1), 1e-2);
|
||||||
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{1,2,3});
|
assertEquals(3, vectorInt.get(2), 1e-2);
|
||||||
assertEquals(3,vectorInt.getValueCount());
|
val vectorDouble = ArrowConverter.vectorFor(allocator, "field1", new double[] { 1, 2, 3 });
|
||||||
assertEquals(1,vectorInt.get(0),1e-2);
|
assertEquals(3, vectorDouble.getValueCount());
|
||||||
assertEquals(2,vectorInt.get(1),1e-2);
|
assertEquals(1, vectorDouble.get(0), 1e-2);
|
||||||
assertEquals(3,vectorInt.get(2),1e-2);
|
assertEquals(2, vectorDouble.get(1), 1e-2);
|
||||||
|
assertEquals(3, vectorDouble.get(2), 1e-2);
|
||||||
val vectorDouble = ArrowConverter.vectorFor(allocator,"field1",new double[]{1,2,3});
|
val vectorBool = ArrowConverter.vectorFor(allocator, "field1", new boolean[] { true, true, false });
|
||||||
assertEquals(3,vectorDouble.getValueCount());
|
assertEquals(3, vectorBool.getValueCount());
|
||||||
assertEquals(1,vectorDouble.get(0),1e-2);
|
assertEquals(1, vectorBool.get(0), 1e-2);
|
||||||
assertEquals(2,vectorDouble.get(1),1e-2);
|
assertEquals(1, vectorBool.get(1), 1e-2);
|
||||||
assertEquals(3,vectorDouble.get(2),1e-2);
|
assertEquals(0, vectorBool.get(2), 1e-2);
|
||||||
|
|
||||||
|
|
||||||
val vectorBool = ArrowConverter.vectorFor(allocator,"field1",new boolean[]{true,true,false});
|
|
||||||
assertEquals(3,vectorBool.getValueCount());
|
|
||||||
assertEquals(1,vectorBool.get(0),1e-2);
|
|
||||||
assertEquals(1,vectorBool.get(1),1e-2);
|
|
||||||
assertEquals(0,vectorBool.get(2),1e-2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRecordReaderAndWriteFile() throws Exception {
|
@DisplayName("Test Record Reader And Write File")
|
||||||
|
void testRecordReaderAndWriteFile() throws Exception {
|
||||||
val recordsToWrite = recordToWrite();
|
val recordsToWrite = recordToWrite();
|
||||||
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
||||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
|
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
|
||||||
byte[] arr = byteArrayOutputStream.toByteArray();
|
byte[] arr = byteArrayOutputStream.toByteArray();
|
||||||
val read = ArrowConverter.readFromBytes(arr);
|
val read = ArrowConverter.readFromBytes(arr);
|
||||||
assertEquals(recordsToWrite,read);
|
assertEquals(recordsToWrite, read);
|
||||||
|
// send file
|
||||||
//send file
|
File tmp = tmpDataFile(recordsToWrite);
|
||||||
File tmp = tmpDataFile(recordsToWrite);
|
|
||||||
RecordReader recordReader = new ArrowRecordReader();
|
RecordReader recordReader = new ArrowRecordReader();
|
||||||
|
|
||||||
recordReader.initialize(new FileSplit(tmp));
|
recordReader.initialize(new FileSplit(tmp));
|
||||||
|
|
||||||
List<Writable> record = recordReader.next();
|
List<Writable> record = recordReader.next();
|
||||||
assertEquals(2,record.size());
|
assertEquals(2, record.size());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRecordReaderMetaDataList() throws Exception {
|
@DisplayName("Test Record Reader Meta Data List")
|
||||||
|
void testRecordReaderMetaDataList() throws Exception {
|
||||||
val recordsToWrite = recordToWrite();
|
val recordsToWrite = recordToWrite();
|
||||||
//send file
|
// send file
|
||||||
File tmp = tmpDataFile(recordsToWrite);
|
File tmp = tmpDataFile(recordsToWrite);
|
||||||
RecordReader recordReader = new ArrowRecordReader();
|
RecordReader recordReader = new ArrowRecordReader();
|
||||||
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
|
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
|
||||||
recordReader.loadFromMetaData(Arrays.<RecordMetaData>asList(recordMetaDataIndex));
|
recordReader.loadFromMetaData(Arrays.<RecordMetaData>asList(recordMetaDataIndex));
|
||||||
|
|
||||||
Record record = recordReader.nextRecord();
|
Record record = recordReader.nextRecord();
|
||||||
assertEquals(2,record.getRecord().size());
|
assertEquals(2, record.getRecord().size());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDates() {
|
@DisplayName("Test Dates")
|
||||||
|
void testDates() {
|
||||||
Date now = new Date();
|
Date now = new Date();
|
||||||
BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
||||||
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[]{now});
|
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[] { now });
|
||||||
assertEquals(now.getTime(),timeStampMilliVector.get(0));
|
assertEquals(now.getTime(), timeStampMilliVector.get(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRecordReaderMetaData() throws Exception {
|
@DisplayName("Test Record Reader Meta Data")
|
||||||
|
void testRecordReaderMetaData() throws Exception {
|
||||||
val recordsToWrite = recordToWrite();
|
val recordsToWrite = recordToWrite();
|
||||||
//send file
|
// send file
|
||||||
File tmp = tmpDataFile(recordsToWrite);
|
File tmp = tmpDataFile(recordsToWrite);
|
||||||
RecordReader recordReader = new ArrowRecordReader();
|
RecordReader recordReader = new ArrowRecordReader();
|
||||||
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
|
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
|
||||||
recordReader.loadFromMetaData(recordMetaDataIndex);
|
recordReader.loadFromMetaData(recordMetaDataIndex);
|
||||||
|
|
||||||
Record record = recordReader.nextRecord();
|
Record record = recordReader.nextRecord();
|
||||||
assertEquals(2,record.getRecord().size());
|
assertEquals(2, record.getRecord().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
private File tmpDataFile(Pair<Schema,List<List<Writable>>> recordsToWrite) throws IOException {
|
private File tmpDataFile(Pair<Schema, List<List<Writable>>> recordsToWrite) throws IOException {
|
||||||
|
File f = testDir.toFile();
|
||||||
File f = testDir.newFolder();
|
// send file
|
||||||
|
File tmp = new File(f, "tmp-file-" + UUID.randomUUID().toString());
|
||||||
//send file
|
|
||||||
File tmp = new File(f,"tmp-file-" + UUID.randomUUID().toString());
|
|
||||||
tmp.mkdirs();
|
tmp.mkdirs();
|
||||||
File tmpFile = new File(tmp,"data.arrow");
|
File tmpFile = new File(tmp, "data.arrow");
|
||||||
tmpFile.deleteOnExit();
|
tmpFile.deleteOnExit();
|
||||||
FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile);
|
FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile);
|
||||||
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),bufferedOutputStream);
|
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), bufferedOutputStream);
|
||||||
bufferedOutputStream.flush();
|
bufferedOutputStream.flush();
|
||||||
bufferedOutputStream.close();
|
bufferedOutputStream.close();
|
||||||
return tmp;
|
return tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Pair<Schema,List<List<Writable>>> recordToWrite() {
|
private Pair<Schema, List<List<Writable>>> recordToWrite() {
|
||||||
List<List<Writable>> records = new ArrayList<>();
|
List<List<Writable>> records = new ArrayList<>();
|
||||||
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
|
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
|
||||||
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
|
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
|
||||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||||
for(int i = 0; i < 2; i++) {
|
for (int i = 0; i < 2; i++) {
|
||||||
schemaBuilder.addColumnFloat("col-" + i);
|
schemaBuilder.addColumnFloat("col-" + i);
|
||||||
}
|
}
|
||||||
|
return Pair.of(schemaBuilder.build(), records);
|
||||||
return Pair.of(schemaBuilder.build(),records);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.arrow;
|
package org.datavec.arrow;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
@ -34,132 +33,98 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
||||||
import org.datavec.arrow.recordreader.ArrowRecordWriter;
|
import org.datavec.arrow.recordreader.ArrowRecordWriter;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.primitives.Triple;
|
import org.nd4j.common.primitives.Triple;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Record Mapper Test")
|
||||||
|
class RecordMapperTest extends BaseND4JTest {
|
||||||
public class RecordMapperTest extends BaseND4JTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultiWrite() throws Exception {
|
@DisplayName("Test Multi Write")
|
||||||
|
void testMultiWrite() throws Exception {
|
||||||
val recordsPair = records();
|
val recordsPair = records();
|
||||||
|
|
||||||
Path p = Files.createTempFile("arrowwritetest", ".arrow");
|
Path p = Files.createTempFile("arrowwritetest", ".arrow");
|
||||||
FileUtils.write(p.toFile(),recordsPair.getFirst());
|
FileUtils.write(p.toFile(), recordsPair.getFirst());
|
||||||
p.toFile().deleteOnExit();
|
p.toFile().deleteOnExit();
|
||||||
|
|
||||||
int numReaders = 2;
|
int numReaders = 2;
|
||||||
RecordReader[] readers = new RecordReader[numReaders];
|
RecordReader[] readers = new RecordReader[numReaders];
|
||||||
InputSplit[] splits = new InputSplit[numReaders];
|
InputSplit[] splits = new InputSplit[numReaders];
|
||||||
for(int i = 0; i < readers.length; i++) {
|
for (int i = 0; i < readers.length; i++) {
|
||||||
FileSplit split = new FileSplit(p.toFile());
|
FileSplit split = new FileSplit(p.toFile());
|
||||||
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
|
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
|
||||||
readers[i] = arrowRecordReader;
|
readers[i] = arrowRecordReader;
|
||||||
splits[i] = split;
|
splits[i] = split;
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
|
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
|
||||||
FileSplit split = new FileSplit(p.toFile());
|
FileSplit split = new FileSplit(p.toFile());
|
||||||
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner());
|
arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
|
||||||
arrowRecordWriter.writeBatch(recordsPair.getRight());
|
arrowRecordWriter.writeBatch(recordsPair.getRight());
|
||||||
|
|
||||||
|
|
||||||
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
|
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
|
||||||
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
|
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
|
||||||
FileUtils.write(p2.toFile(),recordsPair.getFirst());
|
FileUtils.write(p2.toFile(), recordsPair.getFirst());
|
||||||
p.toFile().deleteOnExit();
|
p.toFile().deleteOnExit();
|
||||||
FileSplit outputCsv = new FileSplit(p2.toFile());
|
FileSplit outputCsv = new FileSplit(p2.toFile());
|
||||||
|
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers).splitPerReader(splits).recordWriter(csvRecordWriter).build();
|
||||||
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
|
|
||||||
.outputUrl(outputCsv)
|
|
||||||
.partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers)
|
|
||||||
.splitPerReader(splits)
|
|
||||||
.recordWriter(csvRecordWriter)
|
|
||||||
.build();
|
|
||||||
mapper.copy();
|
mapper.copy();
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCopyFromArrowToCsv() throws Exception {
|
@DisplayName("Test Copy From Arrow To Csv")
|
||||||
|
void testCopyFromArrowToCsv() throws Exception {
|
||||||
val recordsPair = records();
|
val recordsPair = records();
|
||||||
|
|
||||||
Path p = Files.createTempFile("arrowwritetest", ".arrow");
|
Path p = Files.createTempFile("arrowwritetest", ".arrow");
|
||||||
FileUtils.write(p.toFile(),recordsPair.getFirst());
|
FileUtils.write(p.toFile(), recordsPair.getFirst());
|
||||||
p.toFile().deleteOnExit();
|
p.toFile().deleteOnExit();
|
||||||
|
|
||||||
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
|
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
|
||||||
FileSplit split = new FileSplit(p.toFile());
|
FileSplit split = new FileSplit(p.toFile());
|
||||||
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner());
|
arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
|
||||||
arrowRecordWriter.writeBatch(recordsPair.getRight());
|
arrowRecordWriter.writeBatch(recordsPair.getRight());
|
||||||
|
|
||||||
|
|
||||||
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
|
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
|
||||||
arrowRecordReader.initialize(split);
|
arrowRecordReader.initialize(split);
|
||||||
|
|
||||||
|
|
||||||
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
|
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
|
||||||
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
|
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
|
||||||
FileUtils.write(p2.toFile(),recordsPair.getFirst());
|
FileUtils.write(p2.toFile(), recordsPair.getFirst());
|
||||||
p.toFile().deleteOnExit();
|
p.toFile().deleteOnExit();
|
||||||
FileSplit outputCsv = new FileSplit(p2.toFile());
|
FileSplit outputCsv = new FileSplit(p2.toFile());
|
||||||
|
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).recordReader(arrowRecordReader).recordWriter(csvRecordWriter).build();
|
||||||
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
|
|
||||||
.outputUrl(outputCsv)
|
|
||||||
.partitioner(new NumberOfRecordsPartitioner())
|
|
||||||
.recordReader(arrowRecordReader).recordWriter(csvRecordWriter)
|
|
||||||
.build();
|
|
||||||
mapper.copy();
|
mapper.copy();
|
||||||
|
|
||||||
CSVRecordReader recordReader = new CSVRecordReader();
|
CSVRecordReader recordReader = new CSVRecordReader();
|
||||||
recordReader.initialize(outputCsv);
|
recordReader.initialize(outputCsv);
|
||||||
|
|
||||||
|
|
||||||
List<List<Writable>> loadedCSvRecords = recordReader.next(10);
|
List<List<Writable>> loadedCSvRecords = recordReader.next(10);
|
||||||
assertEquals(10,loadedCSvRecords.size());
|
assertEquals(10, loadedCSvRecords.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCopyFromCsvToArrow() throws Exception {
|
@DisplayName("Test Copy From Csv To Arrow")
|
||||||
|
void testCopyFromCsvToArrow() throws Exception {
|
||||||
val recordsPair = records();
|
val recordsPair = records();
|
||||||
|
|
||||||
Path p = Files.createTempFile("csvwritetest", ".csv");
|
Path p = Files.createTempFile("csvwritetest", ".csv");
|
||||||
FileUtils.write(p.toFile(),recordsPair.getFirst());
|
FileUtils.write(p.toFile(), recordsPair.getFirst());
|
||||||
p.toFile().deleteOnExit();
|
p.toFile().deleteOnExit();
|
||||||
|
|
||||||
|
|
||||||
CSVRecordReader recordReader = new CSVRecordReader();
|
CSVRecordReader recordReader = new CSVRecordReader();
|
||||||
FileSplit fileSplit = new FileSplit(p.toFile());
|
FileSplit fileSplit = new FileSplit(p.toFile());
|
||||||
|
|
||||||
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
|
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
|
||||||
File outputFile = Files.createTempFile("outputarrow","arrow").toFile();
|
File outputFile = Files.createTempFile("outputarrow", "arrow").toFile();
|
||||||
FileSplit outputFileSplit = new FileSplit(outputFile);
|
FileSplit outputFileSplit = new FileSplit(outputFile);
|
||||||
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit)
|
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit).outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner()).recordReader(recordReader).recordWriter(arrowRecordWriter).build();
|
||||||
.outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner())
|
|
||||||
.recordReader(recordReader).recordWriter(arrowRecordWriter)
|
|
||||||
.build();
|
|
||||||
mapper.copy();
|
mapper.copy();
|
||||||
|
|
||||||
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
|
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
|
||||||
arrowRecordReader.initialize(outputFileSplit);
|
arrowRecordReader.initialize(outputFileSplit);
|
||||||
List<List<Writable>> next = arrowRecordReader.next(10);
|
List<List<Writable>> next = arrowRecordReader.next(10);
|
||||||
System.out.println(next);
|
System.out.println(next);
|
||||||
assertEquals(10,next.size());
|
assertEquals(10, next.size());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Triple<String,Schema,List<List<Writable>>> records() {
|
private Triple<String, Schema, List<List<Writable>>> records() {
|
||||||
List<List<Writable>> list = new ArrayList<>();
|
List<List<Writable>> list = new ArrayList<>();
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
int numColumns = 3;
|
int numColumns = 3;
|
||||||
|
@ -176,15 +141,10 @@ public class RecordMapperTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
list.add(temp);
|
list.add(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||||
for(int i = 0; i < numColumns; i++) {
|
for (int i = 0; i < numColumns; i++) {
|
||||||
schemaBuilder.addColumnInteger(String.valueOf(i));
|
schemaBuilder.addColumnInteger(String.valueOf(i));
|
||||||
}
|
}
|
||||||
|
return Triple.of(sb.toString(), schemaBuilder.build(), list);
|
||||||
return Triple.of(sb.toString(),schemaBuilder.build(),list);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,16 +29,16 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.arrow.ArrowConverter;
|
import org.datavec.arrow.ArrowConverter;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
|
||||||
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||||
|
|
||||||
|
@ -46,6 +46,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@Disabled
|
||||||
public void testBasicIndexing() {
|
public void testBasicIndexing() {
|
||||||
Schema.Builder schema = new Schema.Builder();
|
Schema.Builder schema = new Schema.Builder();
|
||||||
for(int i = 0; i < 3; i++) {
|
for(int i = 0; i < 3; i++) {
|
||||||
|
@ -54,9 +55,9 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
||||||
List<List<Writable>> timeStep = Arrays.asList(
|
List<List<Writable>> timeStep = Arrays.asList(
|
||||||
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)),
|
Arrays.asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)),
|
||||||
Arrays.<Writable>asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)),
|
Arrays.asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)),
|
||||||
Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5),new IntWritable(6))
|
Arrays.asList(new IntWritable(4),new IntWritable(5),new IntWritable(6))
|
||||||
);
|
);
|
||||||
|
|
||||||
int numTimeSteps = 5;
|
int numTimeSteps = 5;
|
||||||
|
@ -69,7 +70,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||||
assertEquals(3,fieldVectors.size());
|
assertEquals(3,fieldVectors.size());
|
||||||
for(FieldVector fieldVector : fieldVectors) {
|
for(FieldVector fieldVector : fieldVectors) {
|
||||||
for(int i = 0; i < fieldVector.getValueCount(); i++) {
|
for(int i = 0; i < fieldVector.getValueCount(); i++) {
|
||||||
assertFalse("Index " + i + " was null for field vector " + fieldVector, fieldVector.isNull(i));
|
assertFalse( fieldVector.isNull(i),"Index " + i + " was null for field vector " + fieldVector);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,7 +80,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
//not worried about this till after next release
|
//not worried about this till after next release
|
||||||
@Ignore
|
@Disabled
|
||||||
public void testVariableLengthTS() {
|
public void testVariableLengthTS() {
|
||||||
Schema.Builder schema = new Schema.Builder()
|
Schema.Builder schema = new Schema.Builder()
|
||||||
.addColumnString("str")
|
.addColumnString("str")
|
||||||
|
|
|
@ -119,10 +119,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -17,41 +17,39 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.image;
|
package org.datavec.image;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Label Generator Test")
|
||||||
import static org.junit.Assert.assertTrue;
|
class LabelGeneratorTest {
|
||||||
|
|
||||||
public class LabelGeneratorTest {
|
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testParentPathLabelGenerator() throws Exception {
|
@DisplayName("Test Parent Path Label Generator")
|
||||||
//https://github.com/deeplearning4j/DataVec/issues/273
|
@Disabled
|
||||||
|
void testParentPathLabelGenerator(@TempDir Path testDir) throws Exception {
|
||||||
File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
|
File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
|
||||||
|
for (String dirPrefix : new String[] { "m.", "m" }) {
|
||||||
for(String dirPrefix : new String[]{"m.", "m"}) {
|
File f = testDir.toFile();
|
||||||
File f = testDir.newFolder();
|
|
||||||
|
|
||||||
int numDirs = 3;
|
int numDirs = 3;
|
||||||
int filesPerDir = 4;
|
int filesPerDir = 4;
|
||||||
|
|
||||||
for (int i = 0; i < numDirs; i++) {
|
for (int i = 0; i < numDirs; i++) {
|
||||||
File currentLabelDir = new File(f, dirPrefix + i);
|
File currentLabelDir = new File(f, dirPrefix + i);
|
||||||
currentLabelDir.mkdirs();
|
currentLabelDir.mkdirs();
|
||||||
|
@ -61,14 +59,11 @@ public class LabelGeneratorTest {
|
||||||
assertTrue(f3.exists());
|
assertTrue(f3.exists());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
|
ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
|
||||||
rr.initialize(new FileSplit(f));
|
rr.initialize(new FileSplit(f));
|
||||||
|
|
||||||
List<String> labelsAct = rr.getLabels();
|
List<String> labelsAct = rr.getLabels();
|
||||||
List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2");
|
List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2");
|
||||||
assertEquals(labelsExp, labelsAct);
|
assertEquals(labelsExp, labelsAct);
|
||||||
|
|
||||||
int expCount = numDirs * filesPerDir;
|
int expCount = numDirs * filesPerDir;
|
||||||
int actCount = 0;
|
int actCount = 0;
|
||||||
while (rr.hasNext()) {
|
while (rr.hasNext()) {
|
||||||
|
|
|
@ -22,8 +22,8 @@ package org.datavec.image.loader;
|
||||||
|
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
import org.datavec.api.records.reader.RecordReader;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -32,9 +32,9 @@ import java.io.InputStream;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -182,7 +182,7 @@ public class LoaderTests {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Ignore // Use when confirming data is getting stored
|
@Disabled // Use when confirming data is getting stored
|
||||||
@Test
|
@Test
|
||||||
public void testProcessCifar() {
|
public void testProcessCifar() {
|
||||||
int row = 32;
|
int row = 32;
|
||||||
|
@ -208,15 +208,15 @@ public class LoaderTests {
|
||||||
int minibatch = 100;
|
int minibatch = 100;
|
||||||
int nMinibatches = 50000 / minibatch;
|
int nMinibatches = 50000 / minibatch;
|
||||||
|
|
||||||
for( int i=0; i<nMinibatches; i++ ){
|
for( int i=0; i < nMinibatches; i++) {
|
||||||
DataSet ds = loader.next(minibatch);
|
DataSet ds = loader.next(minibatch);
|
||||||
String s = String.valueOf(i);
|
String s = String.valueOf(i);
|
||||||
assertNotNull(s, ds.getFeatures());
|
assertNotNull(ds.getFeatures(),s);
|
||||||
assertNotNull(s, ds.getLabels());
|
assertNotNull(ds.getLabels(),s);
|
||||||
|
|
||||||
assertEquals(s, minibatch, ds.getFeatures().size(0));
|
assertEquals(minibatch, ds.getFeatures().size(0),s);
|
||||||
assertEquals(s, minibatch, ds.getLabels().size(0));
|
assertEquals(minibatch, ds.getLabels().size(0),s);
|
||||||
assertEquals(s, 10, ds.getLabels().size(1));
|
assertEquals(10, ds.getLabels().size(1),s);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
import org.datavec.image.data.Image;
|
import org.datavec.image.data.Image;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.resources.Resources;
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ import java.io.FileInputStream;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
|
||||||
public class TestImageLoader {
|
public class TestImageLoader {
|
||||||
|
|
|
@ -30,9 +30,10 @@ import org.bytedeco.javacv.Java2DFrameConverter;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
import org.datavec.image.data.Image;
|
import org.datavec.image.data.Image;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.resources.Resources;
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -42,16 +43,17 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import org.bytedeco.leptonica.*;
|
import org.bytedeco.leptonica.*;
|
||||||
import org.bytedeco.opencv.opencv_core.*;
|
import org.bytedeco.opencv.opencv_core.*;
|
||||||
import static org.bytedeco.leptonica.global.lept.*;
|
import static org.bytedeco.leptonica.global.lept.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertNotEquals;
|
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -62,8 +64,6 @@ public class TestNativeImageLoader {
|
||||||
static final long seed = 10;
|
static final long seed = 10;
|
||||||
static final Random rng = new Random(seed);
|
static final Random rng = new Random(seed);
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConvertPix() throws Exception {
|
public void testConvertPix() throws Exception {
|
||||||
|
@ -566,8 +566,8 @@ public class TestNativeImageLoader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNativeImageLoaderEmptyStreams() throws Exception {
|
public void testNativeImageLoaderEmptyStreams(@TempDir Path testDir) throws Exception {
|
||||||
File dir = testDir.newFolder();
|
File dir = testDir.toFile();
|
||||||
File f = new File(dir, "myFile.jpg");
|
File f = new File(dir, "myFile.jpg");
|
||||||
f.createNewFile();
|
f.createNewFile();
|
||||||
|
|
||||||
|
@ -578,7 +578,7 @@ public class TestNativeImageLoader {
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg.contains("decode image"),msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
|
@ -586,7 +586,7 @@ public class TestNativeImageLoader {
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg.contains("decode image"),msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
|
@ -594,7 +594,7 @@ public class TestNativeImageLoader {
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg.contains("decode image"),msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
|
@ -603,7 +603,7 @@ public class TestNativeImageLoader {
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue( msg.contains("decode image"),msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.image.recordreader;
|
package org.datavec.image.recordreader;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -28,61 +27,56 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.NDArrayWritable;
|
import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.image.loader.NativeImageLoader;
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.loader.FileBatch;
|
import org.nd4j.common.loader.FileBatch;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
@DisplayName("File Batch Record Reader Test")
|
||||||
|
class FileBatchRecordReaderTest {
|
||||||
|
|
||||||
public class FileBatchRecordReaderTest {
|
@TempDir
|
||||||
|
public Path testDir;
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCsv() throws Exception {
|
@DisplayName("Test Csv")
|
||||||
File extractedSourceDir = testDir.newFolder();
|
void testCsv(@TempDir Path testDir,@TempDir Path baseDirPath) throws Exception {
|
||||||
|
File extractedSourceDir = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir);
|
new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir);
|
||||||
File baseDir = testDir.newFolder();
|
File baseDir = baseDirPath.toFile();
|
||||||
|
|
||||||
|
|
||||||
List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true));
|
List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true));
|
||||||
assertEquals(6, c.size());
|
assertEquals(6, c.size());
|
||||||
|
|
||||||
Collections.sort(c, new Comparator<File>() {
|
Collections.sort(c, new Comparator<File>() {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int compare(File o1, File o2) {
|
public int compare(File o1, File o2) {
|
||||||
return o1.getPath().compareTo(o2.getPath());
|
return o1.getPath().compareTo(o2.getPath());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
FileBatch fb = FileBatch.forFiles(c);
|
FileBatch fb = FileBatch.forFiles(c);
|
||||||
File saveFile = new File(baseDir, "saved.zip");
|
File saveFile = new File(baseDir, "saved.zip");
|
||||||
fb.writeAsZip(saveFile);
|
fb.writeAsZip(saveFile);
|
||||||
fb = FileBatch.readFromZip(saveFile);
|
fb = FileBatch.readFromZip(saveFile);
|
||||||
|
|
||||||
PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
||||||
ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker);
|
ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker);
|
||||||
rr.setLabels(Arrays.asList("class0", "class1"));
|
rr.setLabels(Arrays.asList("class0", "class1"));
|
||||||
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
|
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
|
||||||
|
|
||||||
|
|
||||||
NativeImageLoader il = new NativeImageLoader(32, 32, 1);
|
NativeImageLoader il = new NativeImageLoader(32, 32, 1);
|
||||||
for( int test=0; test<3; test++) {
|
for (int test = 0; test < 3; test++) {
|
||||||
for (int i = 0; i < 6; i++) {
|
for (int i = 0; i < 6; i++) {
|
||||||
assertTrue(fbrr.hasNext());
|
assertTrue(fbrr.hasNext());
|
||||||
List<Writable> next = fbrr.next();
|
List<Writable> next = fbrr.next();
|
||||||
assertEquals(2, next.size());
|
assertEquals(2, next.size());
|
||||||
|
|
||||||
INDArray exp;
|
INDArray exp;
|
||||||
switch (i){
|
switch(i) {
|
||||||
case 0:
|
case 0:
|
||||||
exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg"));
|
exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg"));
|
||||||
break;
|
break;
|
||||||
|
@ -105,8 +99,7 @@ public class FileBatchRecordReaderTest {
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1));
|
Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1));
|
||||||
|
assertEquals(((NDArrayWritable) next.get(0)).get(), exp);
|
||||||
assertEquals(((NDArrayWritable)next.get(0)).get(), exp);
|
|
||||||
assertEquals(expLabel, next.get(1));
|
assertEquals(expLabel, next.get(1));
|
||||||
}
|
}
|
||||||
assertFalse(fbrr.hasNext());
|
assertFalse(fbrr.hasNext());
|
||||||
|
@ -114,5 +107,4 @@ public class FileBatchRecordReaderTest {
|
||||||
fbrr.reset();
|
fbrr.reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,9 +36,10 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.NDArrayWritable;
|
import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -46,28 +47,30 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestImageRecordReader {
|
public class TestImageRecordReader {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testEmptySplit() throws IOException {
|
public void testEmptySplit() throws IOException {
|
||||||
InputSplit data = new CollectionInputSplit(new ArrayList<URI>());
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
new ImageRecordReader().initialize(data, null);
|
InputSplit data = new CollectionInputSplit(new ArrayList<>());
|
||||||
|
new ImageRecordReader().initialize(data, null);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMetaData() throws IOException {
|
public void testMetaData(@TempDir Path testDir) throws IOException {
|
||||||
|
|
||||||
File parentDir = testDir.newFolder();
|
File parentDir = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir);
|
||||||
// System.out.println(f.getAbsolutePath());
|
// System.out.println(f.getAbsolutePath());
|
||||||
// System.out.println(f.getParentFile().getParentFile().getAbsolutePath());
|
// System.out.println(f.getParentFile().getParentFile().getAbsolutePath());
|
||||||
|
@ -104,11 +107,11 @@ public class TestImageRecordReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testImageRecordReaderLabelsOrder() throws Exception {
|
public void testImageRecordReaderLabelsOrder(@TempDir Path testDir) throws Exception {
|
||||||
//Labels order should be consistent, regardless of file iteration order
|
//Labels order should be consistent, regardless of file iteration order
|
||||||
|
|
||||||
//Idea: labels order should be consistent regardless of input file order
|
//Idea: labels order should be consistent regardless of input file order
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
||||||
File f0 = new File(f, "/class0/0.jpg");
|
File f0 = new File(f, "/class0/0.jpg");
|
||||||
File f1 = new File(f, "/class1/A.jpg");
|
File f1 = new File(f, "/class1/A.jpg");
|
||||||
|
@ -135,11 +138,11 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testImageRecordReaderRandomization() throws Exception {
|
public void testImageRecordReaderRandomization(@TempDir Path testDir) throws Exception {
|
||||||
//Order of FileSplit+ImageRecordReader should be different after reset
|
//Order of FileSplit+ImageRecordReader should be different after reset
|
||||||
|
|
||||||
//Idea: labels order should be consistent regardless of input file order
|
//Idea: labels order should be consistent regardless of input file order
|
||||||
File f0 = testDir.newFolder();
|
File f0 = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||||
|
|
||||||
FileSplit fs = new FileSplit(f0, new Random(12345));
|
FileSplit fs = new FileSplit(f0, new Random(12345));
|
||||||
|
@ -189,13 +192,13 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testImageRecordReaderRegression() throws Exception {
|
public void testImageRecordReaderRegression(@TempDir Path testDir) throws Exception {
|
||||||
|
|
||||||
PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen();
|
PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen();
|
||||||
|
|
||||||
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen);
|
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen);
|
||||||
|
|
||||||
File rootDir = testDir.newFolder();
|
File rootDir = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
||||||
FileSplit fs = new FileSplit(rootDir);
|
FileSplit fs = new FileSplit(rootDir);
|
||||||
rr.initialize(fs);
|
rr.initialize(fs);
|
||||||
|
@ -244,10 +247,10 @@ public class TestImageRecordReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testListenerInvocationBatch() throws IOException {
|
public void testListenerInvocationBatch(@TempDir Path testDir) throws IOException {
|
||||||
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
||||||
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
||||||
|
|
||||||
File parent = f;
|
File parent = f;
|
||||||
|
@ -260,10 +263,10 @@ public class TestImageRecordReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testListenerInvocationSingle() throws IOException {
|
public void testListenerInvocationSingle(@TempDir Path testDir) throws IOException {
|
||||||
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
||||||
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
||||||
File parent = testDir.newFolder();
|
File parent = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent);
|
new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent);
|
||||||
int numFiles = parent.list().length;
|
int numFiles = parent.list().length;
|
||||||
rr.initialize(new FileSplit(parent));
|
rr.initialize(new FileSplit(parent));
|
||||||
|
@ -315,7 +318,7 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testImageRecordReaderPathMultiLabelGenerator() throws Exception {
|
public void testImageRecordReaderPathMultiLabelGenerator(@TempDir Path testDir) throws Exception {
|
||||||
Nd4j.setDataType(DataType.FLOAT);
|
Nd4j.setDataType(DataType.FLOAT);
|
||||||
//Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively
|
//Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively
|
||||||
// PLUS single value (Writable) regression label
|
// PLUS single value (Writable) regression label
|
||||||
|
@ -324,7 +327,7 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen);
|
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen);
|
||||||
|
|
||||||
File rootDir = testDir.newFolder();
|
File rootDir = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
||||||
FileSplit fs = new FileSplit(rootDir);
|
FileSplit fs = new FileSplit(rootDir);
|
||||||
rr.initialize(fs);
|
rr.initialize(fs);
|
||||||
|
@ -471,9 +474,9 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNCHW_NCHW() throws Exception {
|
public void testNCHW_NCHW(@TempDir Path testDir) throws Exception {
|
||||||
//Idea: labels order should be consistent regardless of input file order
|
//Idea: labels order should be consistent regardless of input file order
|
||||||
File f0 = testDir.newFolder();
|
File f0 = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||||
|
|
||||||
FileSplit fs0 = new FileSplit(f0, new Random(12345));
|
FileSplit fs0 = new FileSplit(f0, new Random(12345));
|
||||||
|
|
|
@ -35,9 +35,10 @@ import org.datavec.image.transform.FlipImageTransform;
|
||||||
import org.datavec.image.transform.ImageTransform;
|
import org.datavec.image.transform.ImageTransform;
|
||||||
import org.datavec.image.transform.PipelineImageTransform;
|
import org.datavec.image.transform.PipelineImageTransform;
|
||||||
import org.datavec.image.transform.ResizeImageTransform;
|
import org.datavec.image.transform.ResizeImageTransform;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||||
|
@ -46,24 +47,24 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestObjectDetectionRecordReader {
|
public class TestObjectDetectionRecordReader {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test() throws Exception {
|
public void test(@TempDir Path testDir) throws Exception {
|
||||||
for(boolean nchw : new boolean[]{true, false}) {
|
for(boolean nchw : new boolean[]{true, false}) {
|
||||||
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
|
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
|
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
|
||||||
|
|
||||||
String path = new File(f, "000012.jpg").getParent();
|
String path = new File(f, "000012.jpg").getParent();
|
||||||
|
|
|
@ -21,27 +21,27 @@
|
||||||
package org.datavec.image.recordreader.objdetect;
|
package org.datavec.image.recordreader.objdetect;
|
||||||
|
|
||||||
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
|
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestVocLabelProvider {
|
public class TestVocLabelProvider {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVocLabelProvider() throws Exception {
|
public void testVocLabelProvider(@TempDir Path testDir) throws Exception {
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f);
|
new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f);
|
||||||
|
|
||||||
String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent();
|
String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent();
|
||||||
|
|
|
@ -17,106 +17,70 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Json Yaml Test")
|
||||||
import static org.junit.Assert.assertTrue;
|
class JsonYamlTest {
|
||||||
|
|
||||||
public class JsonYamlTest {
|
|
||||||
@Test
|
@Test
|
||||||
public void testJsonYamlImageTransformProcess() throws IOException {
|
@DisplayName("Test Json Yaml Image Transform Process")
|
||||||
|
void testJsonYamlImageTransformProcess() throws IOException {
|
||||||
int seed = 12345;
|
int seed = 12345;
|
||||||
Random random = new Random(seed);
|
Random random = new Random(seed);
|
||||||
|
// from org.bytedeco.javacpp.opencv_imgproc
|
||||||
//from org.bytedeco.javacpp.opencv_imgproc
|
|
||||||
int COLOR_BGR2Luv = 50;
|
int COLOR_BGR2Luv = 50;
|
||||||
int CV_BGR2GRAY = 6;
|
int CV_BGR2GRAY = 6;
|
||||||
|
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv).cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0).resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3).warpImageTransform((float) 0.5).build();
|
||||||
|
|
||||||
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv)
|
|
||||||
.cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0)
|
|
||||||
.resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3)
|
|
||||||
.warpImageTransform((float) 0.5)
|
|
||||||
|
|
||||||
// Note : since randomCropTransform use random value
|
|
||||||
// the results from each case(json, yaml, ImageTransformProcess)
|
|
||||||
// can be different
|
|
||||||
// don't use the below line
|
|
||||||
// if you uncomment it, you will get fail from below assertions
|
|
||||||
// .randomCropTransform(seed, 50, 50)
|
|
||||||
|
|
||||||
// Note : you will get "java.lang.NoClassDefFoundError: Could not initialize class org.bytedeco.javacpp.avutil"
|
|
||||||
// it needs to add the below dependency
|
|
||||||
// <dependency>
|
|
||||||
// <groupId>org.bytedeco</groupId>
|
|
||||||
// <artifactId>ffmpeg-platform</artifactId>
|
|
||||||
// </dependency>
|
|
||||||
// FFmpeg has license issues, be careful to use it
|
|
||||||
//.filterImageTransform("noise=alls=20:allf=t+u,format=rgba", 100, 100, 4)
|
|
||||||
|
|
||||||
.build();
|
|
||||||
|
|
||||||
String asJson = itp.toJson();
|
String asJson = itp.toJson();
|
||||||
String asYaml = itp.toYaml();
|
String asYaml = itp.toYaml();
|
||||||
|
// System.out.println(asJson);
|
||||||
// System.out.println(asJson);
|
// System.out.println("\n\n\n");
|
||||||
// System.out.println("\n\n\n");
|
// System.out.println(asYaml);
|
||||||
// System.out.println(asYaml);
|
|
||||||
|
|
||||||
ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
|
ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
|
||||||
ImageWritable imgJson = new ImageWritable(img.getFrame().clone());
|
ImageWritable imgJson = new ImageWritable(img.getFrame().clone());
|
||||||
ImageWritable imgYaml = new ImageWritable(img.getFrame().clone());
|
ImageWritable imgYaml = new ImageWritable(img.getFrame().clone());
|
||||||
ImageWritable imgAll = new ImageWritable(img.getFrame().clone());
|
ImageWritable imgAll = new ImageWritable(img.getFrame().clone());
|
||||||
|
|
||||||
ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson);
|
ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson);
|
||||||
ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml);
|
ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml);
|
||||||
|
|
||||||
List<ImageTransform> transformList = itp.getTransformList();
|
List<ImageTransform> transformList = itp.getTransformList();
|
||||||
List<ImageTransform> transformListJson = itpFromJson.getTransformList();
|
List<ImageTransform> transformListJson = itpFromJson.getTransformList();
|
||||||
List<ImageTransform> transformListYaml = itpFromYaml.getTransformList();
|
List<ImageTransform> transformListYaml = itpFromYaml.getTransformList();
|
||||||
|
|
||||||
for (int i = 0; i < transformList.size(); i++) {
|
for (int i = 0; i < transformList.size(); i++) {
|
||||||
ImageTransform it = transformList.get(i);
|
ImageTransform it = transformList.get(i);
|
||||||
ImageTransform itJson = transformListJson.get(i);
|
ImageTransform itJson = transformListJson.get(i);
|
||||||
ImageTransform itYaml = transformListYaml.get(i);
|
ImageTransform itYaml = transformListYaml.get(i);
|
||||||
|
|
||||||
System.out.println(i + "\t" + it);
|
System.out.println(i + "\t" + it);
|
||||||
|
|
||||||
img = it.transform(img);
|
img = it.transform(img);
|
||||||
imgJson = itJson.transform(imgJson);
|
imgJson = itJson.transform(imgJson);
|
||||||
imgYaml = itYaml.transform(imgYaml);
|
imgYaml = itYaml.transform(imgYaml);
|
||||||
|
|
||||||
if (it instanceof RandomCropTransform) {
|
if (it instanceof RandomCropTransform) {
|
||||||
assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight);
|
assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight);
|
||||||
assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth);
|
assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth);
|
||||||
|
|
||||||
assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight);
|
assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight);
|
||||||
assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth);
|
assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth);
|
||||||
} else if (it instanceof FilterImageTransform) {
|
} else if (it instanceof FilterImageTransform) {
|
||||||
assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight);
|
assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight);
|
||||||
assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth);
|
assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth);
|
||||||
assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels);
|
assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels);
|
||||||
|
|
||||||
assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight);
|
assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight);
|
||||||
assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth);
|
assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth);
|
||||||
assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels);
|
assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels);
|
||||||
} else {
|
} else {
|
||||||
assertEquals(img, imgJson);
|
assertEquals(img, imgJson);
|
||||||
|
|
||||||
assertEquals(img, imgYaml);
|
assertEquals(img, imgYaml);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
imgAll = itp.execute(imgAll);
|
imgAll = itp.execute(imgAll);
|
||||||
|
|
||||||
assertEquals(imgAll, img);
|
assertEquals(imgAll, img);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,56 +17,50 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import org.bytedeco.javacv.Frame;
|
import org.bytedeco.javacv.Frame;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.junit.Before;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Resize Image Transform Test")
|
||||||
|
class ResizeImageTransformTest {
|
||||||
public class ResizeImageTransformTest {
|
|
||||||
@Before
|
|
||||||
public void setUp() throws Exception {
|
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws Exception {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testResizeUpscale1() throws Exception {
|
@DisplayName("Test Resize Upscale 1")
|
||||||
|
void testResizeUpscale1() throws Exception {
|
||||||
ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3);
|
ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3);
|
||||||
|
|
||||||
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
|
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
|
||||||
|
|
||||||
ImageWritable dstImg = transform.transform(srcImg);
|
ImageWritable dstImg = transform.transform(srcImg);
|
||||||
|
|
||||||
Frame f = dstImg.getFrame();
|
Frame f = dstImg.getFrame();
|
||||||
assertEquals(f.imageWidth, 200);
|
assertEquals(f.imageWidth, 200);
|
||||||
assertEquals(f.imageHeight, 200);
|
assertEquals(f.imageHeight, 200);
|
||||||
|
float[] coordinates = { 100, 200 };
|
||||||
float[] coordinates = {100, 200};
|
|
||||||
float[] transformed = transform.query(coordinates);
|
float[] transformed = transform.query(coordinates);
|
||||||
assertEquals(200f * 100 / 32, transformed[0], 0);
|
assertEquals(200f * 100 / 32, transformed[0], 0);
|
||||||
assertEquals(200f * 200 / 32, transformed[1], 0);
|
assertEquals(200f * 200 / 32, transformed[1], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testResizeDownscale() throws Exception {
|
@DisplayName("Test Resize Downscale")
|
||||||
|
void testResizeDownscale() throws Exception {
|
||||||
ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3);
|
ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3);
|
||||||
|
|
||||||
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
|
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
|
||||||
|
|
||||||
ImageWritable dstImg = transform.transform(srcImg);
|
ImageWritable dstImg = transform.transform(srcImg);
|
||||||
|
|
||||||
Frame f = dstImg.getFrame();
|
Frame f = dstImg.getFrame();
|
||||||
assertEquals(f.imageWidth, 200);
|
assertEquals(f.imageWidth, 200);
|
||||||
assertEquals(f.imageHeight, 200);
|
assertEquals(f.imageHeight, 200);
|
||||||
|
float[] coordinates = { 300, 400 };
|
||||||
float[] coordinates = {300, 400};
|
|
||||||
float[] transformed = transform.query(coordinates);
|
float[] transformed = transform.query(coordinates);
|
||||||
assertEquals(200f * 300 / 443, transformed[0], 0);
|
assertEquals(200f * 300 / 443, transformed[0], 0);
|
||||||
assertEquals(200f * 400 / 571, transformed[1], 0);
|
assertEquals(200f * 400 / 571, transformed[1], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,8 +28,8 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.datavec.image.loader.NativeImageLoader;
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
import java.util.LinkedList;
|
import java.util.LinkedList;
|
||||||
|
@ -40,7 +40,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
||||||
|
|
||||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -255,7 +255,7 @@ public class TestImageTransform {
|
||||||
assertEquals(22, transformed[1], 0);
|
assertEquals(22, transformed[1], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Ignore
|
@Disabled
|
||||||
@Test
|
@Test
|
||||||
public void testFilterImageTransform() throws Exception {
|
public void testFilterImageTransform() throws Exception {
|
||||||
ImageWritable writable = makeRandomImage(0, 0, 4);
|
ImageWritable writable = makeRandomImage(0, 0, 4);
|
||||||
|
|
|
@ -59,10 +59,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -57,10 +57,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -17,37 +17,34 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.poi.excel;
|
package org.datavec.poi.excel;
|
||||||
|
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
import org.datavec.api.records.reader.RecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Excel Record Reader Test")
|
||||||
import static org.junit.Assert.assertTrue;
|
class ExcelRecordReaderTest {
|
||||||
|
|
||||||
public class ExcelRecordReaderTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSimple() throws Exception {
|
@DisplayName("Test Simple")
|
||||||
|
void testSimple() throws Exception {
|
||||||
RecordReader excel = new ExcelRecordReader();
|
RecordReader excel = new ExcelRecordReader();
|
||||||
excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile()));
|
excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile()));
|
||||||
assertTrue(excel.hasNext());
|
assertTrue(excel.hasNext());
|
||||||
List<Writable> next = excel.next();
|
List<Writable> next = excel.next();
|
||||||
assertEquals(3,next.size());
|
assertEquals(3, next.size());
|
||||||
|
|
||||||
RecordReader headerReader = new ExcelRecordReader(1);
|
RecordReader headerReader = new ExcelRecordReader(1);
|
||||||
headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile()));
|
headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile()));
|
||||||
assertTrue(excel.hasNext());
|
assertTrue(excel.hasNext());
|
||||||
List<Writable> next2 = excel.next();
|
List<Writable> next2 = excel.next();
|
||||||
assertEquals(3,next2.size());
|
assertEquals(3, next2.size());
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.poi.excel;
|
package org.datavec.poi.excel;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
@ -26,44 +25,45 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.common.primitives.Triple;
|
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.nd4j.common.primitives.Triple;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Excel Record Writer Test")
|
||||||
|
class ExcelRecordWriterTest {
|
||||||
|
|
||||||
public class ExcelRecordWriterTest {
|
@TempDir
|
||||||
|
public Path testDir;
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWriter() throws Exception {
|
@DisplayName("Test Writer")
|
||||||
|
void testWriter() throws Exception {
|
||||||
ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter();
|
ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter();
|
||||||
val records = records();
|
val records = records();
|
||||||
File tmpDir = testDir.newFolder();
|
File tmpDir = testDir.toFile();
|
||||||
File outputFile = new File(tmpDir,"testexcel.xlsx");
|
File outputFile = new File(tmpDir, "testexcel.xlsx");
|
||||||
outputFile.deleteOnExit();
|
outputFile.deleteOnExit();
|
||||||
FileSplit fileSplit = new FileSplit(outputFile);
|
FileSplit fileSplit = new FileSplit(outputFile);
|
||||||
excelRecordWriter.initialize(fileSplit,new NumberOfRecordsPartitioner());
|
excelRecordWriter.initialize(fileSplit, new NumberOfRecordsPartitioner());
|
||||||
excelRecordWriter.writeBatch(records.getRight());
|
excelRecordWriter.writeBatch(records.getRight());
|
||||||
excelRecordWriter.close();
|
excelRecordWriter.close();
|
||||||
File parentFile = outputFile.getParentFile();
|
File parentFile = outputFile.getParentFile();
|
||||||
assertEquals(1,parentFile.list().length);
|
assertEquals(1, parentFile.list().length);
|
||||||
|
|
||||||
ExcelRecordReader excelRecordReader = new ExcelRecordReader();
|
ExcelRecordReader excelRecordReader = new ExcelRecordReader();
|
||||||
excelRecordReader.initialize(fileSplit);
|
excelRecordReader.initialize(fileSplit);
|
||||||
List<List<Writable>> next = excelRecordReader.next(10);
|
List<List<Writable>> next = excelRecordReader.next(10);
|
||||||
assertEquals(10,next.size());
|
assertEquals(10, next.size());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Triple<String,Schema,List<List<Writable>>> records() {
|
private Triple<String, Schema, List<List<Writable>>> records() {
|
||||||
List<List<Writable>> list = new ArrayList<>();
|
List<List<Writable>> list = new ArrayList<>();
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
int numColumns = 3;
|
int numColumns = 3;
|
||||||
|
@ -80,13 +80,10 @@ public class ExcelRecordWriterTest {
|
||||||
}
|
}
|
||||||
list.add(temp);
|
list.add(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Schema.Builder schemaBuilder = new Schema.Builder();
|
Schema.Builder schemaBuilder = new Schema.Builder();
|
||||||
for(int i = 0; i < numColumns; i++) {
|
for (int i = 0; i < numColumns; i++) {
|
||||||
schemaBuilder.addColumnInteger(String.valueOf(i));
|
schemaBuilder.addColumnInteger(String.valueOf(i));
|
||||||
}
|
}
|
||||||
|
return Triple.of(sb.toString(), schemaBuilder.build(), list);
|
||||||
return Triple.of(sb.toString(),schemaBuilder.build(),list);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,10 +65,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -17,14 +17,12 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.api.records.reader.impl;
|
package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.sql.Connection;
|
import java.sql.Connection;
|
||||||
|
@ -49,53 +47,57 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.LongWritable;
|
import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.After;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.Before;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
|
||||||
public class JDBCRecordReaderTest {
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
|
||||||
@Rule
|
@DisplayName("Jdbc Record Reader Test")
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
class JDBCRecordReaderTest {
|
||||||
|
|
||||||
|
@TempDir
|
||||||
|
public Path testDir;
|
||||||
|
|
||||||
Connection conn;
|
Connection conn;
|
||||||
|
|
||||||
EmbeddedDataSource dataSource;
|
EmbeddedDataSource dataSource;
|
||||||
|
|
||||||
private final String dbName = "datavecTests";
|
private final String dbName = "datavecTests";
|
||||||
|
|
||||||
private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver";
|
private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver";
|
||||||
|
|
||||||
@Before
|
@BeforeEach
|
||||||
public void setUp() throws Exception {
|
void setUp() throws Exception {
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
System.setProperty("derby.system.home", f.getAbsolutePath());
|
System.setProperty("derby.system.home", f.getAbsolutePath());
|
||||||
|
|
||||||
dataSource = new EmbeddedDataSource();
|
dataSource = new EmbeddedDataSource();
|
||||||
dataSource.setDatabaseName(dbName);
|
dataSource.setDatabaseName(dbName);
|
||||||
dataSource.setCreateDatabase("create");
|
dataSource.setCreateDatabase("create");
|
||||||
conn = dataSource.getConnection();
|
conn = dataSource.getConnection();
|
||||||
|
|
||||||
TestDb.dropTables(conn);
|
TestDb.dropTables(conn);
|
||||||
TestDb.buildCoffeeTable(conn);
|
TestDb.buildCoffeeTable(conn);
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@AfterEach
|
||||||
public void tearDown() throws Exception {
|
void tearDown() throws Exception {
|
||||||
DbUtils.closeQuietly(conn);
|
DbUtils.closeQuietly(conn);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSimpleIter() throws Exception {
|
@DisplayName("Test Simple Iter")
|
||||||
|
void testSimpleIter() throws Exception {
|
||||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||||
List<List<Writable>> records = new ArrayList<>();
|
List<List<Writable>> records = new ArrayList<>();
|
||||||
while (reader.hasNext()) {
|
while (reader.hasNext()) {
|
||||||
List<Writable> values = reader.next();
|
List<Writable> values = reader.next();
|
||||||
records.add(values);
|
records.add(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertFalse(records.isEmpty());
|
assertFalse(records.isEmpty());
|
||||||
|
|
||||||
List<Writable> first = records.get(0);
|
List<Writable> first = records.get(0);
|
||||||
assertEquals(new Text("Bolivian Dark"), first.get(0));
|
assertEquals(new Text("Bolivian Dark"), first.get(0));
|
||||||
assertEquals(new Text("14-001"), first.get(1));
|
assertEquals(new Text("14-001"), first.get(1));
|
||||||
|
@ -104,39 +106,43 @@ public class JDBCRecordReaderTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSimpleWithListener() throws Exception {
|
@DisplayName("Test Simple With Listener")
|
||||||
|
void testSimpleWithListener() throws Exception {
|
||||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||||
RecordListener recordListener = new LogRecordListener();
|
RecordListener recordListener = new LogRecordListener();
|
||||||
reader.setListeners(recordListener);
|
reader.setListeners(recordListener);
|
||||||
reader.next();
|
reader.next();
|
||||||
|
|
||||||
assertTrue(recordListener.invoked());
|
assertTrue(recordListener.invoked());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReset() throws Exception {
|
@DisplayName("Test Reset")
|
||||||
|
void testReset() throws Exception {
|
||||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||||
List<List<Writable>> records = new ArrayList<>();
|
List<List<Writable>> records = new ArrayList<>();
|
||||||
records.add(reader.next());
|
records.add(reader.next());
|
||||||
reader.reset();
|
reader.reset();
|
||||||
records.add(reader.next());
|
records.add(reader.next());
|
||||||
|
|
||||||
assertEquals(2, records.size());
|
assertEquals(2, records.size());
|
||||||
assertEquals(new Text("Bolivian Dark"), records.get(0).get(0));
|
assertEquals(new Text("Bolivian Dark"), records.get(0).get(0));
|
||||||
assertEquals(new Text("Bolivian Dark"), records.get(1).get(0));
|
assertEquals(new Text("Bolivian Dark"), records.get(1).get(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalStateException.class)
|
@Test
|
||||||
public void testLackingDataSourceShouldFail() throws Exception {
|
@DisplayName("Test Lacking Data Source Should Fail")
|
||||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
|
void testLackingDataSourceShouldFail() {
|
||||||
reader.initialize(null);
|
assertThrows(IllegalStateException.class, () -> {
|
||||||
}
|
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
|
||||||
|
reader.initialize(null);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConfigurationDataSourceInitialization() throws Exception {
|
@DisplayName("Test Configuration Data Source Initialization")
|
||||||
|
void testConfigurationDataSourceInitialization() throws Exception {
|
||||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
|
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
|
||||||
Configuration conf = new Configuration();
|
Configuration conf = new Configuration();
|
||||||
conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true");
|
conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true");
|
||||||
|
@ -146,28 +152,33 @@ public class JDBCRecordReaderTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test
|
||||||
public void testInitConfigurationMissingParametersShouldFail() throws Exception {
|
@DisplayName("Test Init Configuration Missing Parameters Should Fail")
|
||||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
|
void testInitConfigurationMissingParametersShouldFail() {
|
||||||
Configuration conf = new Configuration();
|
assertThrows(IllegalArgumentException.class, () -> {
|
||||||
conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway");
|
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
|
||||||
reader.initialize(conf, null);
|
Configuration conf = new Configuration();
|
||||||
}
|
conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway");
|
||||||
}
|
reader.initialize(conf, null);
|
||||||
|
}
|
||||||
@Test(expected = UnsupportedOperationException.class)
|
});
|
||||||
public void testRecordDataInputStreamShouldFail() throws Exception {
|
|
||||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
|
||||||
reader.record(null, null);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLoadFromMetaData() throws Exception {
|
@DisplayName("Test Record Data Input Stream Should Fail")
|
||||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
void testRecordDataInputStreamShouldFail() {
|
||||||
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()),
|
assertThrows(UnsupportedOperationException.class, () -> {
|
||||||
"SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass());
|
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||||
|
reader.record(null, null);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("Test Load From Meta Data")
|
||||||
|
void testLoadFromMetaData() throws Exception {
|
||||||
|
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||||
|
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), "SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass());
|
||||||
Record res = reader.loadFromMetaData(rmd);
|
Record res = reader.loadFromMetaData(rmd);
|
||||||
assertNotNull(res);
|
assertNotNull(res);
|
||||||
assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0));
|
assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0));
|
||||||
|
@ -177,7 +188,8 @@ public class JDBCRecordReaderTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNextRecord() throws Exception {
|
@DisplayName("Test Next Record")
|
||||||
|
void testNextRecord() throws Exception {
|
||||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||||
Record r = reader.nextRecord();
|
Record r = reader.nextRecord();
|
||||||
List<Writable> fields = r.getRecord();
|
List<Writable> fields = r.getRecord();
|
||||||
|
@ -193,7 +205,8 @@ public class JDBCRecordReaderTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNextRecordAndRecover() throws Exception {
|
@DisplayName("Test Next Record And Recover")
|
||||||
|
void testNextRecordAndRecover() throws Exception {
|
||||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||||
Record r = reader.nextRecord();
|
Record r = reader.nextRecord();
|
||||||
List<Writable> fields = r.getRecord();
|
List<Writable> fields = r.getRecord();
|
||||||
|
@ -208,67 +221,89 @@ public class JDBCRecordReaderTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resetting the record reader when initialized as forward only should fail
|
// Resetting the record reader when initialized as forward only should fail
|
||||||
@Test(expected = RuntimeException.class)
|
@Test
|
||||||
public void testResetForwardOnlyShouldFail() throws Exception {
|
@DisplayName("Test Reset Forward Only Should Fail")
|
||||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) {
|
void testResetForwardOnlyShouldFail() {
|
||||||
Configuration conf = new Configuration();
|
assertThrows(RuntimeException.class, () -> {
|
||||||
conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY);
|
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) {
|
||||||
reader.initialize(conf, null);
|
Configuration conf = new Configuration();
|
||||||
reader.next();
|
conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY);
|
||||||
reader.reset();
|
reader.initialize(conf, null);
|
||||||
}
|
reader.next();
|
||||||
|
reader.reset();
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReadAllTypes() throws Exception {
|
@DisplayName("Test Read All Types")
|
||||||
|
void testReadAllTypes() throws Exception {
|
||||||
TestDb.buildAllTypesTable(conn);
|
TestDb.buildAllTypesTable(conn);
|
||||||
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM AllTypes", dataSource)) {
|
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM AllTypes", dataSource)) {
|
||||||
reader.initialize(null);
|
reader.initialize(null);
|
||||||
List<Writable> item = reader.next();
|
List<Writable> item = reader.next();
|
||||||
|
|
||||||
assertEquals(item.size(), 15);
|
assertEquals(item.size(), 15);
|
||||||
assertEquals(BooleanWritable.class, item.get(0).getClass()); // boolean to boolean
|
// boolean to boolean
|
||||||
assertEquals(Text.class, item.get(1).getClass()); // date to text
|
assertEquals(BooleanWritable.class, item.get(0).getClass());
|
||||||
assertEquals(Text.class, item.get(2).getClass()); // time to text
|
// date to text
|
||||||
assertEquals(Text.class, item.get(3).getClass()); // timestamp to text
|
assertEquals(Text.class, item.get(1).getClass());
|
||||||
assertEquals(Text.class, item.get(4).getClass()); // char to text
|
// time to text
|
||||||
assertEquals(Text.class, item.get(5).getClass()); // long varchar to text
|
assertEquals(Text.class, item.get(2).getClass());
|
||||||
assertEquals(Text.class, item.get(6).getClass()); // varchar to text
|
// timestamp to text
|
||||||
assertEquals(DoubleWritable.class,
|
assertEquals(Text.class, item.get(3).getClass());
|
||||||
item.get(7).getClass()); // float to double (derby's float is an alias of double by default)
|
// char to text
|
||||||
assertEquals(FloatWritable.class, item.get(8).getClass()); // real to float
|
assertEquals(Text.class, item.get(4).getClass());
|
||||||
assertEquals(DoubleWritable.class, item.get(9).getClass()); // decimal to double
|
// long varchar to text
|
||||||
assertEquals(DoubleWritable.class, item.get(10).getClass()); // numeric to double
|
assertEquals(Text.class, item.get(5).getClass());
|
||||||
assertEquals(DoubleWritable.class, item.get(11).getClass()); // double to double
|
// varchar to text
|
||||||
assertEquals(IntWritable.class, item.get(12).getClass()); // integer to integer
|
assertEquals(Text.class, item.get(6).getClass());
|
||||||
assertEquals(IntWritable.class, item.get(13).getClass()); // small int to integer
|
assertEquals(DoubleWritable.class, // float to double (derby's float is an alias of double by default)
|
||||||
assertEquals(LongWritable.class, item.get(14).getClass()); // bigint to long
|
item.get(7).getClass());
|
||||||
|
// real to float
|
||||||
|
assertEquals(FloatWritable.class, item.get(8).getClass());
|
||||||
|
// decimal to double
|
||||||
|
assertEquals(DoubleWritable.class, item.get(9).getClass());
|
||||||
|
// numeric to double
|
||||||
|
assertEquals(DoubleWritable.class, item.get(10).getClass());
|
||||||
|
// double to double
|
||||||
|
assertEquals(DoubleWritable.class, item.get(11).getClass());
|
||||||
|
// integer to integer
|
||||||
|
assertEquals(IntWritable.class, item.get(12).getClass());
|
||||||
|
// small int to integer
|
||||||
|
assertEquals(IntWritable.class, item.get(13).getClass());
|
||||||
|
// bigint to long
|
||||||
|
assertEquals(LongWritable.class, item.get(14).getClass());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = RuntimeException.class)
|
@Test
|
||||||
public void testNextNoMoreShouldFail() throws Exception {
|
@DisplayName("Test Next No More Should Fail")
|
||||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
void testNextNoMoreShouldFail() {
|
||||||
while (reader.hasNext()) {
|
assertThrows(RuntimeException.class, () -> {
|
||||||
|
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||||
|
while (reader.hasNext()) {
|
||||||
|
reader.next();
|
||||||
|
}
|
||||||
reader.next();
|
reader.next();
|
||||||
}
|
}
|
||||||
reader.next();
|
});
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test
|
||||||
public void testInvalidMetadataShouldFail() throws Exception {
|
@DisplayName("Test Invalid Metadata Should Fail")
|
||||||
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
void testInvalidMetadataShouldFail() {
|
||||||
RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class);
|
assertThrows(IllegalArgumentException.class, () -> {
|
||||||
reader.loadFromMetaData(md);
|
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
|
||||||
}
|
RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class);
|
||||||
|
reader.loadFromMetaData(md);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private JDBCRecordReader getInitializedReader(String query) throws Exception {
|
private JDBCRecordReader getInitializedReader(String query) throws Exception {
|
||||||
int[] indices = {1}; // ProdNum column
|
// ProdNum column
|
||||||
JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?",
|
int[] indices = { 1 };
|
||||||
indices);
|
JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?", indices);
|
||||||
reader.setTrimStrings(true);
|
reader.setTrimStrings(true);
|
||||||
reader.initialize(null);
|
reader.initialize(null);
|
||||||
return reader;
|
return reader;
|
||||||
|
|
|
@ -61,25 +61,18 @@
|
||||||
<artifactId>nd4j-common</artifactId>
|
<artifactId>nd4j-common</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.datavec</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>datavec-geo</artifactId>
|
<artifactId>python4j-numpy</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.datavec</groupId>
|
|
||||||
<artifactId>datavec-python</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -36,14 +36,14 @@ import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class LocalTransformProcessRecordReaderTests {
|
public class LocalTransformProcessRecordReaderTests {
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,9 @@ import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.util.ndarray.RecordConverter;
|
import org.datavec.api.util.ndarray.RecordConverter;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.local.transforms.AnalyzeLocal;
|
import org.datavec.local.transforms.AnalyzeLocal;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
@ -39,12 +39,11 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestAnalyzeLocal {
|
public class TestAnalyzeLocal {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAnalysisBasic() throws Exception {
|
public void testAnalysisBasic() throws Exception {
|
||||||
|
@ -72,7 +71,7 @@ public class TestAnalyzeLocal {
|
||||||
INDArray mean = arr.mean(0);
|
INDArray mean = arr.mean(0);
|
||||||
INDArray std = arr.std(0);
|
INDArray std = arr.std(0);
|
||||||
|
|
||||||
for( int i=0; i<5; i++ ){
|
for( int i = 0; i < 5; i++) {
|
||||||
double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
|
double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
|
||||||
double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
|
double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
|
||||||
assertEquals(mean.getDouble(i), m, 1e-3);
|
assertEquals(mean.getDouble(i), m, 1e-3);
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -36,8 +36,8 @@ import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestLineRecordReaderFunction {
|
public class TestLineRecordReaderFunction {
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
|
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestNDArrayToWritablesFunction {
|
public class TestNDArrayToWritablesFunction {
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
|
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestWritablesToNDArrayFunction {
|
public class TestWritablesToNDArrayFunction {
|
||||||
|
|
||||||
|
|
|
@ -30,12 +30,12 @@ import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
|
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
|
||||||
import org.datavec.local.transforms.misc.WritablesToStringFunction;
|
import org.datavec.local.transforms.misc.WritablesToStringFunction;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestWritablesToStringFunctions {
|
public class TestWritablesToStringFunctions {
|
||||||
|
|
||||||
|
|
|
@ -17,10 +17,8 @@
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
* *****************************************************************************
|
* *****************************************************************************
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.datavec.local.transforms.transform;
|
package org.datavec.local.transforms.transform;
|
||||||
|
|
||||||
|
|
||||||
import org.datavec.api.transform.MathFunction;
|
import org.datavec.api.transform.MathFunction;
|
||||||
import org.datavec.api.transform.MathOp;
|
import org.datavec.api.transform.MathOp;
|
||||||
import org.datavec.api.transform.ReduceOp;
|
import org.datavec.api.transform.ReduceOp;
|
||||||
|
@ -31,108 +29,85 @@ import org.datavec.api.transform.reduce.Reducer;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.transform.schema.SequenceSchema;
|
import org.datavec.api.transform.schema.SequenceSchema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.datavec.python.PythonTransform;
|
|
||||||
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
import static java.time.Duration.ofMillis;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTimeout;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
@DisplayName("Execution Test")
|
||||||
|
class ExecutionTest {
|
||||||
public class ExecutionTest {
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testExecutionNdarray() {
|
@DisplayName("Test Execution Ndarray")
|
||||||
Schema schema = new Schema.Builder()
|
void testExecutionNdarray() {
|
||||||
.addColumnNDArray("first",new long[]{1,32577})
|
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
|
||||||
.addColumnNDArray("second",new long[]{1,32577}).build();
|
TransformProcess transformProcess = new TransformProcess.Builder(schema).ndArrayMathFunctionTransform("first", MathFunction.SIN).ndArrayMathFunctionTransform("second", MathFunction.COS).build();
|
||||||
|
|
||||||
TransformProcess transformProcess = new TransformProcess.Builder(schema)
|
|
||||||
.ndArrayMathFunctionTransform("first", MathFunction.SIN)
|
|
||||||
.ndArrayMathFunctionTransform("second",MathFunction.COS)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> functions = new ArrayList<>();
|
List<List<Writable>> functions = new ArrayList<>();
|
||||||
List<Writable> firstRow = new ArrayList<>();
|
List<Writable> firstRow = new ArrayList<>();
|
||||||
INDArray firstArr = Nd4j.linspace(1,4,4);
|
INDArray firstArr = Nd4j.linspace(1, 4, 4);
|
||||||
INDArray secondArr = Nd4j.linspace(1,4,4);
|
INDArray secondArr = Nd4j.linspace(1, 4, 4);
|
||||||
firstRow.add(new NDArrayWritable(firstArr));
|
firstRow.add(new NDArrayWritable(firstArr));
|
||||||
firstRow.add(new NDArrayWritable(secondArr));
|
firstRow.add(new NDArrayWritable(secondArr));
|
||||||
functions.add(firstRow);
|
functions.add(firstRow);
|
||||||
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
|
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
|
||||||
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
|
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
|
||||||
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
|
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
|
||||||
|
|
||||||
INDArray expected = Transforms.sin(firstArr);
|
INDArray expected = Transforms.sin(firstArr);
|
||||||
INDArray secondExpected = Transforms.cos(secondArr);
|
INDArray secondExpected = Transforms.cos(secondArr);
|
||||||
assertEquals(expected,firstResult);
|
assertEquals(expected, firstResult);
|
||||||
assertEquals(secondExpected,secondResult);
|
assertEquals(secondExpected, secondResult);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testExecutionSimple() {
|
@DisplayName("Test Execution Simple")
|
||||||
Schema schema = new Schema.Builder().addColumnInteger("col0")
|
void testExecutionSimple() {
|
||||||
.addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").
|
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build();
|
||||||
addColumnFloat("col3").build();
|
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
|
|
||||||
.doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputData = new ArrayList<>();
|
List<List<Writable>> inputData = new ArrayList<>();
|
||||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
|
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
|
||||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
|
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
|
||||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
|
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
|
||||||
|
|
||||||
List<List<Writable>> rdd = (inputData);
|
List<List<Writable>> rdd = (inputData);
|
||||||
|
|
||||||
List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
|
List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
|
||||||
|
|
||||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
Collections.sort(out, new Comparator<List<Writable>>() {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
List<List<Writable>> expected = new ArrayList<>();
|
List<List<Writable>> expected = new ArrayList<>();
|
||||||
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
|
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
|
||||||
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
|
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
|
||||||
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f)));
|
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f)));
|
||||||
|
|
||||||
assertEquals(expected, out);
|
assertEquals(expected, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFilter() {
|
@DisplayName("Test Filter")
|
||||||
Schema filterSchema = new Schema.Builder()
|
void testFilter() {
|
||||||
.addColumnDouble("col1").addColumnDouble("col2")
|
Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build();
|
||||||
.addColumnDouble("col3").build();
|
|
||||||
List<List<Writable>> inputData = new ArrayList<>();
|
List<List<Writable>> inputData = new ArrayList<>();
|
||||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
|
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
|
||||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
|
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
|
||||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
|
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
|
||||||
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema)
|
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build();
|
||||||
.filter(new DoubleColumnCondition("col1",ConditionOp.LessThan,1)).build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
|
List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
|
||||||
assertEquals(2,execute.size());
|
assertEquals(2, execute.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testExecutionSequence() {
|
@DisplayName("Test Execution Sequence")
|
||||||
|
void testExecutionSequence() {
|
||||||
Schema schema = new SequenceSchema.Builder().addColumnInteger("col0")
|
Schema schema = new SequenceSchema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
|
||||||
.addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
|
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
|
|
||||||
.doubleMathOp("col2", MathOp.Add, 10.0).build();
|
|
||||||
|
|
||||||
List<List<List<Writable>>> inputSequences = new ArrayList<>();
|
List<List<List<Writable>>> inputSequences = new ArrayList<>();
|
||||||
List<List<Writable>> seq1 = new ArrayList<>();
|
List<List<Writable>> seq1 = new ArrayList<>();
|
||||||
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||||
|
@ -141,21 +116,17 @@ public class ExecutionTest {
|
||||||
List<List<Writable>> seq2 = new ArrayList<>();
|
List<List<Writable>> seq2 = new ArrayList<>();
|
||||||
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
|
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
|
||||||
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
|
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
|
||||||
|
|
||||||
inputSequences.add(seq1);
|
inputSequences.add(seq1);
|
||||||
inputSequences.add(seq2);
|
inputSequences.add(seq2);
|
||||||
|
List<List<List<Writable>>> rdd = (inputSequences);
|
||||||
List<List<List<Writable>>> rdd = (inputSequences);
|
|
||||||
|
|
||||||
List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
|
List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
|
||||||
|
|
||||||
Collections.sort(out, new Comparator<List<List<Writable>>>() {
|
Collections.sort(out, new Comparator<List<List<Writable>>>() {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
|
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
|
||||||
return -Integer.compare(o1.size(), o2.size());
|
return -Integer.compare(o1.size(), o2.size());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
|
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
|
||||||
List<List<Writable>> seq1e = new ArrayList<>();
|
List<List<Writable>> seq1e = new ArrayList<>();
|
||||||
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||||
|
@ -164,121 +135,37 @@ public class ExecutionTest {
|
||||||
List<List<Writable>> seq2e = new ArrayList<>();
|
List<List<Writable>> seq2e = new ArrayList<>();
|
||||||
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
|
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
|
||||||
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
|
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
|
||||||
|
|
||||||
expectedSequence.add(seq1e);
|
expectedSequence.add(seq1e);
|
||||||
expectedSequence.add(seq2e);
|
expectedSequence.add(seq2e);
|
||||||
|
|
||||||
assertEquals(expectedSequence, out);
|
assertEquals(expectedSequence, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReductionGlobal() {
|
@DisplayName("Test Reduction Global")
|
||||||
|
void testReductionGlobal() {
|
||||||
List<List<Writable>> in = Arrays.asList(
|
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
|
||||||
Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)),
|
|
||||||
Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))
|
|
||||||
);
|
|
||||||
|
|
||||||
List<List<Writable>> inData = in;
|
List<List<Writable>> inData = in;
|
||||||
|
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
|
||||||
Schema s = new Schema.Builder()
|
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
|
||||||
.addColumnString("textCol")
|
|
||||||
.addColumnDouble("doubleCol")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(s)
|
|
||||||
.reduce(new Reducer.Builder(ReduceOp.TakeFirst)
|
|
||||||
.takeFirstColumns("textCol")
|
|
||||||
.meanColumns("doubleCol").build())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
|
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
|
||||||
|
|
||||||
List<List<Writable>> out = outRdd;
|
List<List<Writable>> out = outRdd;
|
||||||
|
|
||||||
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
|
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
|
||||||
|
|
||||||
assertEquals(expOut, out);
|
assertEquals(expOut, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReductionByKey(){
|
@DisplayName("Test Reduction By Key")
|
||||||
|
void testReductionByKey() {
|
||||||
List<List<Writable>> in = Arrays.asList(
|
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
|
||||||
Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)),
|
|
||||||
Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)),
|
|
||||||
Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)),
|
|
||||||
Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))
|
|
||||||
);
|
|
||||||
|
|
||||||
List<List<Writable>> inData = in;
|
List<List<Writable>> inData = in;
|
||||||
|
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
|
||||||
Schema s = new Schema.Builder()
|
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
|
||||||
.addColumnInteger("intCol")
|
|
||||||
.addColumnString("textCol")
|
|
||||||
.addColumnDouble("doubleCol")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(s)
|
|
||||||
.reduce(new Reducer.Builder(ReduceOp.TakeFirst)
|
|
||||||
.keyColumns("intCol")
|
|
||||||
.takeFirstColumns("textCol")
|
|
||||||
.meanColumns("doubleCol").build())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
|
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
|
||||||
|
|
||||||
List<List<Writable>> out = outRdd;
|
List<List<Writable>> out = outRdd;
|
||||||
|
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||||
List<List<Writable>> expOut = Arrays.asList(
|
|
||||||
Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)),
|
|
||||||
Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
|
||||||
|
|
||||||
out = new ArrayList<>(out);
|
out = new ArrayList<>(out);
|
||||||
Collections.sort(
|
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
|
||||||
out, new Comparator<List<Writable>>() {
|
|
||||||
@Override
|
|
||||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
|
||||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
assertEquals(expOut, out);
|
assertEquals(expOut, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
|
||||||
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
|
|
||||||
public void testPythonExecutionNdarray()throws Exception{
|
|
||||||
Schema schema = new Schema.Builder()
|
|
||||||
.addColumnNDArray("first",new long[]{1,32577})
|
|
||||||
.addColumnNDArray("second",new long[]{1,32577}).build();
|
|
||||||
|
|
||||||
TransformProcess transformProcess = new TransformProcess.Builder(schema)
|
|
||||||
.transform(
|
|
||||||
PythonTransform.builder().code(
|
|
||||||
"first = np.sin(first)\nsecond = np.cos(second)")
|
|
||||||
.outputSchema(schema).build())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> functions = new ArrayList<>();
|
|
||||||
List<Writable> firstRow = new ArrayList<>();
|
|
||||||
INDArray firstArr = Nd4j.linspace(1,4,4);
|
|
||||||
INDArray secondArr = Nd4j.linspace(1,4,4);
|
|
||||||
firstRow.add(new NDArrayWritable(firstArr));
|
|
||||||
firstRow.add(new NDArrayWritable(secondArr));
|
|
||||||
functions.add(firstRow);
|
|
||||||
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
|
|
||||||
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
|
|
||||||
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
|
|
||||||
|
|
||||||
INDArray expected = Transforms.sin(firstArr);
|
|
||||||
INDArray secondExpected = Transforms.cos(secondArr);
|
|
||||||
assertEquals(expected,firstResult);
|
|
||||||
assertEquals(secondExpected,secondResult);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,154 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.datavec.local.transforms.transform;
|
|
||||||
|
|
||||||
import org.datavec.api.transform.ColumnType;
|
|
||||||
import org.datavec.api.transform.Transform;
|
|
||||||
import org.datavec.api.transform.geo.LocationType;
|
|
||||||
import org.datavec.api.transform.schema.Schema;
|
|
||||||
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
|
|
||||||
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
|
|
||||||
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
|
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
|
||||||
import org.datavec.api.writable.Text;
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
import org.junit.AfterClass;
|
|
||||||
import org.junit.BeforeClass;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author saudet
|
|
||||||
*/
|
|
||||||
public class TestGeoTransforms {
|
|
||||||
|
|
||||||
@BeforeClass
|
|
||||||
public static void beforeClass() throws Exception {
|
|
||||||
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
|
|
||||||
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
|
|
||||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
|
|
||||||
}
|
|
||||||
|
|
||||||
@AfterClass
|
|
||||||
public static void afterClass(){
|
|
||||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCoordinatesDistanceTransform() throws Exception {
|
|
||||||
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
|
|
||||||
transform.setInputSchema(schema);
|
|
||||||
|
|
||||||
Schema out = transform.transform(schema);
|
|
||||||
assertEquals(4, out.numColumns());
|
|
||||||
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
|
|
||||||
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
|
|
||||||
out.getColumnTypes());
|
|
||||||
|
|
||||||
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
|
|
||||||
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
|
|
||||||
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
|
|
||||||
new DoubleWritable(Math.sqrt(160))),
|
|
||||||
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
|
|
||||||
new Text("10|5"))));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testIPAddressToCoordinatesTransform() throws Exception {
|
|
||||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
|
||||||
|
|
||||||
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
|
|
||||||
transform.setInputSchema(schema);
|
|
||||||
|
|
||||||
Schema out = transform.transform(schema);
|
|
||||||
|
|
||||||
assertEquals(1, out.getColumnMetaData().size());
|
|
||||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
|
||||||
|
|
||||||
String in = "81.2.69.160";
|
|
||||||
double latitude = 51.5142;
|
|
||||||
double longitude = -0.0931;
|
|
||||||
|
|
||||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
|
||||||
assertEquals(1, writables.size());
|
|
||||||
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
|
||||||
assertEquals(2, coordinates.length);
|
|
||||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
|
||||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
|
||||||
|
|
||||||
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
|
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
|
||||||
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
|
||||||
oos.writeObject(transform);
|
|
||||||
|
|
||||||
byte[] bytes = baos.toByteArray();
|
|
||||||
|
|
||||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
|
||||||
ObjectInputStream ois = new ObjectInputStream(bais);
|
|
||||||
|
|
||||||
Transform deserialized = (Transform) ois.readObject();
|
|
||||||
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
|
|
||||||
assertEquals(1, writables.size());
|
|
||||||
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
|
||||||
//System.out.println(Arrays.toString(coordinates));
|
|
||||||
assertEquals(2, coordinates.length);
|
|
||||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
|
||||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testIPAddressToLocationTransform() throws Exception {
|
|
||||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
|
||||||
LocationType[] locationTypes = LocationType.values();
|
|
||||||
String in = "81.2.69.160";
|
|
||||||
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
|
|
||||||
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
|
|
||||||
|
|
||||||
for (int i = 0; i < locationTypes.length; i++) {
|
|
||||||
LocationType locationType = locationTypes[i];
|
|
||||||
String location = locations[i];
|
|
||||||
|
|
||||||
Transform transform = new IPAddressToLocationTransform("column", locationType);
|
|
||||||
transform.setInputSchema(schema);
|
|
||||||
|
|
||||||
Schema out = transform.transform(schema);
|
|
||||||
|
|
||||||
assertEquals(1, out.getColumnMetaData().size());
|
|
||||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
|
||||||
|
|
||||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
|
||||||
assertEquals(1, writables.size());
|
|
||||||
assertEquals(location, writables.get(0).toString());
|
|
||||||
//System.out.println(location);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,379 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.datavec.local.transforms.transform;
|
|
||||||
|
|
||||||
import org.datavec.api.transform.TransformProcess;
|
|
||||||
import org.datavec.api.transform.condition.Condition;
|
|
||||||
import org.datavec.api.transform.filter.ConditionFilter;
|
|
||||||
import org.datavec.api.transform.filter.Filter;
|
|
||||||
import org.datavec.api.transform.schema.Schema;
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
|
||||||
|
|
||||||
import org.datavec.api.writable.*;
|
|
||||||
import org.datavec.python.PythonCondition;
|
|
||||||
import org.datavec.python.PythonTransform;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import javax.annotation.concurrent.NotThreadSafe;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.datavec.api.transform.schema.Schema.Builder;
|
|
||||||
import static org.junit.Assert.*;
|
|
||||||
|
|
||||||
@NotThreadSafe
|
|
||||||
public class TestPythonTransformProcess {
|
|
||||||
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
public void testStringConcat() throws Exception{
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnString("col1")
|
|
||||||
.addColumnString("col2");
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnString("col3");
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build()
|
|
||||||
).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals((outputs.get(0)).toString(), "Hello ");
|
|
||||||
assertEquals((outputs.get(1)).toString(), "World!");
|
|
||||||
assertEquals((outputs.get(2)).toString(), "Hello World!");
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
|
||||||
public void testMixedTypes() throws Exception{
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnInteger("col1")
|
|
||||||
.addColumnFloat("col2")
|
|
||||||
.addColumnString("col3")
|
|
||||||
.addColumnDouble("col4");
|
|
||||||
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnInteger("col5");
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.inputSchema(initialSchema)
|
|
||||||
.build() ).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
|
|
||||||
new FloatWritable(3.5f),
|
|
||||||
new Text("5"),
|
|
||||||
new DoubleWritable(2.0)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
|
||||||
public void testNDArray() throws Exception{
|
|
||||||
long[] shape = new long[]{3, 2};
|
|
||||||
INDArray arr1 = Nd4j.rand(shape);
|
|
||||||
INDArray arr2 = Nd4j.rand(shape);
|
|
||||||
|
|
||||||
INDArray expectedOutput = arr1.add(arr2);
|
|
||||||
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnNDArray("col1", shape)
|
|
||||||
.addColumnNDArray("col2", shape);
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnNDArray("col3", shape);
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build() ).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new NDArrayWritable(arr1),
|
|
||||||
new NDArrayWritable(arr2)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
|
||||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
|
||||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
|
||||||
public void testNDArray2() throws Exception{
|
|
||||||
long[] shape = new long[]{3, 2};
|
|
||||||
INDArray arr1 = Nd4j.rand(shape);
|
|
||||||
INDArray arr2 = Nd4j.rand(shape);
|
|
||||||
|
|
||||||
INDArray expectedOutput = arr1.add(arr2);
|
|
||||||
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnNDArray("col1", shape)
|
|
||||||
.addColumnNDArray("col2", shape);
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnNDArray("col3", shape);
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build() ).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new NDArrayWritable(arr1),
|
|
||||||
new NDArrayWritable(arr2)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
|
||||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
|
||||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
|
||||||
public void testNDArrayMixed() throws Exception{
|
|
||||||
long[] shape = new long[]{3, 2};
|
|
||||||
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
|
|
||||||
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
|
|
||||||
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
|
|
||||||
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnNDArray("col1", shape)
|
|
||||||
.addColumnNDArray("col2", shape);
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnNDArray("col3", shape);
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build()
|
|
||||||
).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new NDArrayWritable(arr1),
|
|
||||||
new NDArrayWritable(arr2)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
|
||||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
|
||||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
|
||||||
public void testPythonFilter() {
|
|
||||||
Schema schema = new Builder().addColumnInteger("column").build();
|
|
||||||
|
|
||||||
Condition condition = new PythonCondition(
|
|
||||||
"f = lambda: column < 0"
|
|
||||||
);
|
|
||||||
|
|
||||||
condition.setInputSchema(schema);
|
|
||||||
|
|
||||||
Filter filter = new ConditionFilter(condition);
|
|
||||||
|
|
||||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
|
|
||||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
|
|
||||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
|
|
||||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
|
|
||||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
|
||||||
public void testPythonFilterAndTransform() throws Exception{
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnInteger("col1")
|
|
||||||
.addColumnFloat("col2")
|
|
||||||
.addColumnString("col3")
|
|
||||||
.addColumnDouble("col4");
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnString("col6");
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
Condition condition = new PythonCondition(
|
|
||||||
"f = lambda: col1 < 0 and col2 > 10.0"
|
|
||||||
);
|
|
||||||
|
|
||||||
condition.setInputSchema(initialSchema);
|
|
||||||
|
|
||||||
Filter filter = new ConditionFilter(condition);
|
|
||||||
|
|
||||||
String pythonCode = "col6 = str(col1 + col2)";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build()
|
|
||||||
).filter(
|
|
||||||
filter
|
|
||||||
).build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(
|
|
||||||
Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new IntWritable(5),
|
|
||||||
new FloatWritable(3.0f),
|
|
||||||
new Text("abcd"),
|
|
||||||
new DoubleWritable(2.1))
|
|
||||||
);
|
|
||||||
inputs.add(
|
|
||||||
Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new IntWritable(-3),
|
|
||||||
new FloatWritable(3.0f),
|
|
||||||
new Text("abcd"),
|
|
||||||
new DoubleWritable(2.1))
|
|
||||||
);
|
|
||||||
inputs.add(
|
|
||||||
Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new IntWritable(5),
|
|
||||||
new FloatWritable(11.2f),
|
|
||||||
new Text("abcd"),
|
|
||||||
new DoubleWritable(2.1))
|
|
||||||
);
|
|
||||||
|
|
||||||
LocalTransformExecutor.execute(inputs,tp);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testPythonTransformNoOutputSpecified() throws Exception {
|
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
|
||||||
.code("a += 2; b = 'hello world'")
|
|
||||||
.returnAllInputs(true)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
|
|
||||||
Schema inputSchema = new Builder()
|
|
||||||
.addColumnInteger("a")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
|
||||||
.transform(pythonTransform)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
|
||||||
assertEquals(3,execute.get(0).get(0).toInt());
|
|
||||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testNumpyTransform() {
|
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
|
||||||
.code("a += 2; b = 'hello world'")
|
|
||||||
.returnAllInputs(true)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
|
|
||||||
Schema inputSchema = new Builder()
|
|
||||||
.addColumnNDArray("a",new long[]{1,1})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
|
||||||
.transform(pythonTransform)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
|
||||||
assertFalse(execute.isEmpty());
|
|
||||||
assertNotNull(execute.get(0));
|
|
||||||
assertNotNull(execute.get(0).get(0));
|
|
||||||
assertNotNull(execute.get(0).get(1));
|
|
||||||
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
|
|
||||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testWithSetupRun() throws Exception {
|
|
||||||
|
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
|
||||||
.code("five=None\n" +
|
|
||||||
"def setup():\n" +
|
|
||||||
" global five\n"+
|
|
||||||
" five = 5\n\n" +
|
|
||||||
"def run(a, b):\n" +
|
|
||||||
" c = a + b + five\n"+
|
|
||||||
" return {'c':c}\n\n")
|
|
||||||
.returnAllInputs(true)
|
|
||||||
.setupAndRun(true)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
|
|
||||||
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
|
|
||||||
Schema inputSchema = new Builder()
|
|
||||||
.addColumnNDArray("a",new long[]{1,1})
|
|
||||||
.addColumnNDArray("b", new long[]{1, 1})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
|
||||||
.transform(pythonTransform)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
|
||||||
assertFalse(execute.isEmpty());
|
|
||||||
assertNotNull(execute.get(0));
|
|
||||||
assertNotNull(execute.get(0).get(0));
|
|
||||||
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -28,11 +28,11 @@ import org.datavec.api.writable.*;
|
||||||
|
|
||||||
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestJoin {
|
public class TestJoin {
|
||||||
|
|
||||||
|
|
|
@ -31,13 +31,13 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
||||||
|
|
||||||
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestCalculateSortedRank {
|
public class TestCalculateSortedRank {
|
||||||
|
|
||||||
|
|
|
@ -31,14 +31,14 @@ import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
|
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestConvertToSequence {
|
public class TestConvertToSequence {
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,12 @@
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.tdunning</groupId>
|
||||||
|
<artifactId>t-digest</artifactId>
|
||||||
|
<version>3.2</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.scala-lang</groupId>
|
<groupId>org.scala-lang</groupId>
|
||||||
<artifactId>scala-library</artifactId>
|
<artifactId>scala-library</artifactId>
|
||||||
|
@ -122,10 +128,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue