Upgrade dl4j to junit 5

master
agibsonccc 2021-03-15 13:02:01 +09:00
parent 552c5d0b72
commit fa1a31c877
259 changed files with 10586 additions and 21460 deletions

View File

@@ -15,7 +15,7 @@
         <commons.dbutils.version>1.7</commons.dbutils.version>
         <lombok.version>1.18.8</lombok.version>
         <logback.version>1.1.7</logback.version>
-        <junit.version>4.12</junit.version>
+        <junit.version>5.8.0-M1</junit.version>
         <junit-jupiter.version>5.4.2</junit-jupiter.version>
         <java.version>1.8</java.version>
         <maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
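The property bump above moves junit.version from the JUnit 4 line to a JUnit 5 milestone (5.8.0-M1), while the existing junit-jupiter.version property stays at 5.4.2. The rest of the commit mechanically rewrites test sources to the Jupiter style. As a minimal sketch of that style (class and method names here are hypothetical, not from the commit):

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;

// Jupiter-style test: package-private class and method, @DisplayName carrying
// the human-readable name. This mirrors the rewrite applied in the files below.
@DisplayName("Example Migrated Test")
class ExampleMigratedTest {

    @Test
    @DisplayName("Adds two numbers")
    void addsTwoNumbers() {
        assertEquals(4, 2 + 2);
    }
}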

View File

@@ -17,13 +17,14 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
 package org.nd4j.codegen.ir;
-public class SerializationTest {
-    public static void main(String...args) {
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
+@DisplayName("Serialization Test")
+class SerializationTest {
+    public static void main(String... args) {
     }
 }

View File

@@ -17,29 +17,23 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
 package org.nd4j.codegen.dsl;
 import org.apache.commons.lang3.StringUtils;
 import org.junit.jupiter.api.Test;
 import org.nd4j.codegen.impl.java.DocsGenerator;
 import static org.junit.jupiter.api.Assertions.assertEquals;
-public class DocsGeneratorTest {
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
+@DisplayName("Docs Generator Test")
+class DocsGeneratorTest {
     @Test
-    public void testJDtoMDAdapter() {
-        String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" +
-                " eye:\n" +
-                " [ 1, 0]\n" +
-                " [ 0, 1]\n" +
-                " [ 0, 0]}";
-        String expected = "{ INDArray eye = eye(3,2)\n" +
-                " eye:\n" +
-                " [ 1, 0]\n" +
-                " [ 0, 1]\n" +
-                " [ 0, 0]}";
+    @DisplayName("Test J Dto MD Adapter")
+    void testJDtoMDAdapter() {
+        String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
+        String expected = "{ INDArray eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
         DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original);
         String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString();
         assertEquals(out, expected);

View File

@@ -34,6 +34,14 @@
     <artifactId>datavec-api</artifactId>
     <dependencies>
+        <dependency>
+            <groupId>org.junit.jupiter</groupId>
+            <artifactId>junit-jupiter-api</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.junit.vintage</groupId>
+            <artifactId>junit-vintage-engine</artifactId>
+        </dependency>
         <dependency>
             <groupId>org.apache.commons</groupId>
             <artifactId>commons-lang3</artifactId>
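Adding junit-jupiter-api brings in the new annotations and assertions, while junit-vintage-engine keeps any not-yet-migrated JUnit 4 tests running on the JUnit 5 platform during the transition. A sketch of the kind of legacy test the vintage engine covers (the class itself is hypothetical):

import org.junit.Test;
import static org.junit.Assert.assertEquals;

// Plain JUnit 4 test: with junit-vintage-engine on the test classpath,
// the JUnit Platform still discovers and runs it alongside Jupiter tests.
public class LegacyStillRunsTest {

    @Test
    public void runsUnderVintageEngine() {
        assertEquals(3, 1 + 2);
    }
}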

View File

@@ -17,7 +17,6 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
 package org.datavec.api.records.reader.impl;
 import org.apache.commons.io.FileUtils;
@@ -27,46 +26,37 @@ import org.datavec.api.split.FileSplit;
 import org.datavec.api.writable.Text;
 import org.datavec.api.writable.Writable;
 import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
 import org.nd4j.common.tests.BaseND4JTest;
 import java.io.File;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import static org.junit.Assert.assertEquals;
-public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
-    @Rule
-    public TemporaryFolder testDir = new TemporaryFolder();
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
+@DisplayName("Csv Line Sequence Record Reader Test")
+class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
+    @TempDir
+    public Path testDir;
     @Test
-    public void test() throws Exception {
-        File f = testDir.newFolder();
+    @DisplayName("Test")
+    void test(@TempDir Path testDir) throws Exception {
+        File f = testDir.toFile();
         File source = new File(f, "temp.csv");
         String str = "a,b,c\n1,2,3,4";
         FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
         SequenceRecordReader rr = new CSVLineSequenceRecordReader();
         rr.initialize(new FileSplit(source));
-        List<List<Writable>> exp0 = Arrays.asList(
-                Collections.<Writable>singletonList(new Text("a")),
-                Collections.<Writable>singletonList(new Text("b")),
-                Collections.<Writable>singletonList(new Text("c")));
-        List<List<Writable>> exp1 = Arrays.asList(
-                Collections.<Writable>singletonList(new Text("1")),
-                Collections.<Writable>singletonList(new Text("2")),
-                Collections.<Writable>singletonList(new Text("3")),
-                Collections.<Writable>singletonList(new Text("4")));
-        for( int i=0; i<3; i++ ) {
+        List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
+        List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
+        for (int i = 0; i < 3; i++) {
             int count = 0;
             while (rr.hasNext()) {
                 List<List<Writable>> next = rr.sequenceRecord();
@@ -76,9 +66,7 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
                 assertEquals(exp1, next);
             }
         }
         assertEquals(2, count);
         rr.reset();
     }
 }
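The recurring pattern in these test classes: the JUnit 4 @Rule TemporaryFolder is replaced by Jupiter's @TempDir, which can inject a java.nio.file.Path either as a field or as a test-method parameter. A minimal sketch of the before/after (the class name is hypothetical):

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

import java.io.File;
import java.nio.file.Path;

import static org.junit.jupiter.api.Assertions.assertTrue;

class TempDirPatternTest {

    // JUnit 4 equivalent:
    //   @Rule
    //   public TemporaryFolder testDir = new TemporaryFolder();
    //   ...
    //   File f = testDir.newFolder();

    // JUnit 5: the platform creates the directory and cleans it up afterwards.
    @Test
    void writesIntoManagedTempDir(@TempDir Path testDir) {
        File f = testDir.toFile();
        assertTrue(f.isDirectory());
        // use f as scratch space; it is deleted after the test
    }
}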

View File

@@ -17,7 +17,6 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
 package org.datavec.api.records.reader.impl;
 import org.apache.commons.io.FileUtils;
@@ -27,32 +26,34 @@ import org.datavec.api.split.FileSplit;
 import org.datavec.api.writable.Text;
 import org.datavec.api.writable.Writable;
 import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
 import org.nd4j.common.tests.BaseND4JTest;
 import java.io.File;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
-    @Rule
-    public TemporaryFolder testDir = new TemporaryFolder();
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
+@DisplayName("Csv Multi Sequence Record Reader Test")
+class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
+    @TempDir
+    public Path testDir;
     @Test
-    public void testConcatMode() throws Exception {
-        for( int i=0; i<3; i++ ) {
+    @DisplayName("Test Concat Mode")
+    void testConcatMode() throws Exception {
+        for (int i = 0; i < 3; i++) {
             String seqSep;
             String seqSepRegex;
-            switch (i){
+            switch(i) {
                 case 0:
                     seqSep = "";
                     seqSepRegex = "^$";
@@ -68,31 +69,23 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
                 default:
                     throw new RuntimeException();
             }
             String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C";
-            File f = testDir.newFile();
+            File f = testDir.toFile();
             FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
             SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT);
             seqRR.initialize(new FileSplit(f));
             List<List<Writable>> exp0 = new ArrayList<>();
             for (String s : "a,b,c,1,2,3,4,x,y".split(",")) {
                 exp0.add(Collections.<Writable>singletonList(new Text(s)));
             }
             List<List<Writable>> exp1 = new ArrayList<>();
             for (String s : "A,B,C".split(",")) {
                 exp1.add(Collections.<Writable>singletonList(new Text(s)));
             }
             assertEquals(exp0, seqRR.sequenceRecord());
             assertEquals(exp1, seqRR.sequenceRecord());
             assertFalse(seqRR.hasNext());
             seqRR.reset();
             assertEquals(exp0, seqRR.sequenceRecord());
             assertEquals(exp1, seqRR.sequenceRecord());
             assertFalse(seqRR.hasNext());
@@ -100,13 +93,12 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
     }
     @Test
-    public void testEqualLength() throws Exception {
-        for( int i=0; i<3; i++ ) {
+    @DisplayName("Test Equal Length")
+    void testEqualLength() throws Exception {
+        for (int i = 0; i < 3; i++) {
             String seqSep;
             String seqSepRegex;
-            switch (i) {
+            switch(i) {
                 case 0:
                     seqSep = "";
                     seqSepRegex = "^$";
@@ -122,27 +114,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
                 default:
                     throw new RuntimeException();
             }
             String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC";
-            File f = testDir.newFile();
+            File f = testDir.toFile();
             FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
             SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH);
             seqRR.initialize(new FileSplit(f));
-            List<List<Writable>> exp0 = Arrays.asList(
-                    Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
-                    Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
+            List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
             List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
             assertEquals(exp0, seqRR.sequenceRecord());
             assertEquals(exp1, seqRR.sequenceRecord());
             assertFalse(seqRR.hasNext());
             seqRR.reset();
             assertEquals(exp0, seqRR.sequenceRecord());
             assertEquals(exp1, seqRR.sequenceRecord());
             assertFalse(seqRR.hasNext());
@@ -150,13 +132,12 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
     }
     @Test
-    public void testPadding() throws Exception {
-        for( int i=0; i<3; i++ ) {
+    @DisplayName("Test Padding")
+    void testPadding() throws Exception {
+        for (int i = 0; i < 3; i++) {
             String seqSep;
             String seqSepRegex;
-            switch (i) {
+            switch(i) {
                 case 0:
                     seqSep = "";
                     seqSepRegex = "^$";
@@ -172,27 +153,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
                 default:
                     throw new RuntimeException();
             }
             String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC";
-            File f = testDir.newFile();
+            File f = testDir.toFile();
             FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
             SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD"));
             seqRR.initialize(new FileSplit(f));
-            List<List<Writable>> exp0 = Arrays.asList(
-                    Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
-                    Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
+            List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
             List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
             assertEquals(exp0, seqRR.sequenceRecord());
             assertEquals(exp1, seqRR.sequenceRecord());
             assertFalse(seqRR.hasNext());
             seqRR.reset();
             assertEquals(exp0, seqRR.sequenceRecord());
             assertEquals(exp1, seqRR.sequenceRecord());
             assertFalse(seqRR.hasNext());

View File

@@ -17,7 +17,6 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
 package org.datavec.api.records.reader.impl;
 import org.datavec.api.records.SequenceRecord;
@@ -27,61 +26,53 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
 import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
 import org.datavec.api.split.FileSplit;
 import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 import org.nd4j.common.tests.BaseND4JTest;
 import org.nd4j.common.io.ClassPathResource;
 import java.util.ArrayList;
 import java.util.List;
-import static org.junit.Assert.assertEquals;
-public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
+@DisplayName("Csvn Lines Sequence Record Reader Test")
+class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
     @Test
-    public void testCSVNLinesSequenceRecordReader() throws Exception {
+    @DisplayName("Test CSVN Lines Sequence Record Reader")
+    void testCSVNLinesSequenceRecordReader() throws Exception {
         int nLinesPerSequence = 10;
         SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
         seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         CSVRecordReader rr = new CSVRecordReader();
         rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         int count = 0;
         while (seqRR.hasNext()) {
             List<List<Writable>> next = seqRR.sequenceRecord();
             List<List<Writable>> expected = new ArrayList<>();
             for (int i = 0; i < nLinesPerSequence; i++) {
                 expected.add(rr.next());
             }
             assertEquals(10, next.size());
             assertEquals(expected, next);
             count++;
         }
         assertEquals(150 / nLinesPerSequence, count);
     }
     @Test
-    public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
+    @DisplayName("Test CSV Nlines Sequence Record Reader Meta Data")
+    void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
         int nLinesPerSequence = 10;
         SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
         seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         CSVRecordReader rr = new CSVRecordReader();
         rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         List<List<List<Writable>>> out = new ArrayList<>();
         while (seqRR.hasNext()) {
             List<List<Writable>> next = seqRR.sequenceRecord();
             out.add(next);
         }
         seqRR.reset();
         List<List<List<Writable>>> out2 = new ArrayList<>();
         List<SequenceRecord> out3 = new ArrayList<>();
@@ -92,11 +83,8 @@ public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
             meta.add(seq.getMetaData());
             out3.add(seq);
         }
         assertEquals(out, out2);
         List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
         assertEquals(out3, out4);
     }
 }
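The assertion imports in these files move from org.junit.Assert to org.junit.jupiter.api.Assertions. One subtlety worth flagging, visible in the next file of this commit: the optional failure message moves from the first argument to the last. A sketch of the two signatures (the class and method here are hypothetical):

import static org.junit.jupiter.api.Assertions.assertEquals;

class AssertionMessageOrderExample {

    void check(int actualSize) {
        // JUnit 4:  assertEquals("Entry count", 23, actualSize);
        // JUnit 5:  the message is the trailing argument.
        assertEquals(23, actualSize, "Entry count");
    }
}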

View File

@@ -17,7 +17,6 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
 package org.datavec.api.records.reader.impl;
 import org.apache.commons.io.FileUtils;
@@ -34,10 +33,10 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
 import org.datavec.api.writable.IntWritable;
 import org.datavec.api.writable.Text;
 import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
 import org.nd4j.common.tests.BaseND4JTest;
 import org.nd4j.common.io.ClassPathResource;
 import java.io.File;
 import java.io.IOException;
 import java.nio.file.Files;
@@ -47,41 +46,44 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.NoSuchElementException;
-import static org.junit.Assert.*;
-public class CSVRecordReaderTest extends BaseND4JTest {
+import static org.junit.jupiter.api.Assertions.*;
+@DisplayName("Csv Record Reader Test")
+class CSVRecordReaderTest extends BaseND4JTest {
     @Test
-    public void testNext() throws Exception {
+    @DisplayName("Test Next")
+    void testNext() throws Exception {
         CSVRecordReader reader = new CSVRecordReader();
         reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1"));
         while (reader.hasNext()) {
             List<Writable> vals = reader.next();
             List<Writable> arr = new ArrayList<>(vals);
-            assertEquals("Entry count", 23, vals.size());
+            assertEquals(23, vals.size(), "Entry count");
             Text lastEntry = (Text) arr.get(arr.size() - 1);
-            assertEquals("Last entry garbage", 1, lastEntry.getLength());
+            assertEquals(1, lastEntry.getLength(), "Last entry garbage");
         }
     }
     @Test
-    public void testEmptyEntries() throws Exception {
+    @DisplayName("Test Empty Entries")
+    void testEmptyEntries() throws Exception {
         CSVRecordReader reader = new CSVRecordReader();
         reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,"));
         while (reader.hasNext()) {
             List<Writable> vals = reader.next();
-            assertEquals("Entry count", 23, vals.size());
+            assertEquals(23, vals.size(), "Entry count");
         }
     }
     @Test
-    public void testReset() throws Exception {
+    @DisplayName("Test Reset")
+    void testReset() throws Exception {
         CSVRecordReader rr = new CSVRecordReader(0, ',');
         rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         int nResets = 5;
         for (int i = 0; i < nResets; i++) {
             int lineCount = 0;
             while (rr.hasNext()) {
                 List<Writable> line = rr.next();
@@ -95,7 +97,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
     }
     @Test
-    public void testResetWithSkipLines() throws Exception {
+    @DisplayName("Test Reset With Skip Lines")
+    void testResetWithSkipLines() throws Exception {
         CSVRecordReader rr = new CSVRecordReader(10, ',');
         rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         int lineCount = 0;
@@ -114,7 +117,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
     }
     @Test
-    public void testWrite() throws Exception {
+    @DisplayName("Test Write")
+    void testWrite() throws Exception {
         List<List<Writable>> list = new ArrayList<>();
         StringBuilder sb = new StringBuilder();
         for (int i = 0; i < 10; i++) {
@@ -130,81 +134,72 @@ public class CSVRecordReaderTest extends BaseND4JTest {
             }
             list.add(temp);
         }
         String expected = sb.toString();
         Path p = Files.createTempFile("csvwritetest", "csv");
         p.toFile().deleteOnExit();
         FileRecordWriter writer = new CSVRecordWriter();
         FileSplit fileSplit = new FileSplit(p.toFile());
-        writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
+        writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
         for (List<Writable> c : list) {
             writer.write(c);
         }
         writer.close();
-        //Read file back in; compare
+        // Read file back in; compare
         String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name());
         // System.out.println(expected);
         // System.out.println("----------");
         // System.out.println(fileContents);
         assertEquals(expected, fileContents);
     }
     @Test
-    public void testTabsAsSplit1() throws Exception {
+    @DisplayName("Test Tabs As Split 1")
+    void testTabsAsSplit1() throws Exception {
         CSVRecordReader reader = new CSVRecordReader(0, '\t');
         reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile()));
         while (reader.hasNext()) {
             List<Writable> list = new ArrayList<>(reader.next());
             assertEquals(2, list.size());
         }
     }
     @Test
-    public void testPipesAsSplit() throws Exception {
+    @DisplayName("Test Pipes As Split")
+    void testPipesAsSplit() throws Exception {
         CSVRecordReader reader = new CSVRecordReader(0, '|');
         reader.initialize(new FileSplit(new ClassPathResource("datavec-api/issue414.csv").getFile()));
         int lineidx = 0;
         List<Integer> sixthColumn = Arrays.asList(13, 95, 15, 25);
         while (reader.hasNext()) {
             List<Writable> list = new ArrayList<>(reader.next());
             assertEquals(10, list.size());
-            assertEquals((long)sixthColumn.get(lineidx), list.get(5).toInt());
+            assertEquals((long) sixthColumn.get(lineidx), list.get(5).toInt());
             lineidx++;
         }
     }
     @Test
-    public void testWithQuotes() throws Exception {
+    @DisplayName("Test With Quotes")
+    void testWithQuotes() throws Exception {
         CSVRecordReader reader = new CSVRecordReader(0, ',', '\"');
         reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\""));
         while (reader.hasNext()) {
             List<Writable> vals = reader.next();
-            assertEquals("Entry count", 6, vals.size());
-            assertEquals("1", vals.get(0).toString());
-            assertEquals("0", vals.get(1).toString());
-            assertEquals("3", vals.get(2).toString());
-            assertEquals("Braund, Mr. Owen Harris", vals.get(3).toString());
-            assertEquals("male", vals.get(4).toString());
-            assertEquals("\"", vals.get(5).toString());
+            assertEquals(6, vals.size(), "Entry count");
+            assertEquals(vals.get(0).toString(), "1");
+            assertEquals(vals.get(1).toString(), "0");
+            assertEquals(vals.get(2).toString(), "3");
+            assertEquals(vals.get(3).toString(), "Braund, Mr. Owen Harris");
+            assertEquals(vals.get(4).toString(), "male");
+            assertEquals(vals.get(5).toString(), "\"");
         }
     }
     @Test
-    public void testMeta() throws Exception {
+    @DisplayName("Test Meta")
+    void testMeta() throws Exception {
         CSVRecordReader rr = new CSVRecordReader(0, ',');
         rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         int lineCount = 0;
         List<RecordMetaData> metaList = new ArrayList<>();
         List<List<Writable>> writables = new ArrayList<>();
@@ -213,30 +208,25 @@ public class CSVRecordReaderTest extends BaseND4JTest {
             assertEquals(5, r.getRecord().size());
             lineCount++;
             RecordMetaData meta = r.getMetaData();
             // System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
             metaList.add(meta);
             writables.add(r.getRecord());
         }
         assertFalse(rr.hasNext());
         assertEquals(150, lineCount);
         rr.reset();
         System.out.println("\n\n\n--------------------------------");
         List<Record> contents = rr.loadFromMetaData(metaList);
         assertEquals(150, contents.size());
         // for(Record r : contents ){
         // System.out.println(r);
         // }
         List<RecordMetaData> meta2 = new ArrayList<>();
         meta2.add(metaList.get(100));
         meta2.add(metaList.get(90));
         meta2.add(metaList.get(80));
         meta2.add(metaList.get(70));
         meta2.add(metaList.get(60));
         List<Record> contents2 = rr.loadFromMetaData(meta2);
         assertEquals(writables.get(100), contents2.get(0).getRecord());
         assertEquals(writables.get(90), contents2.get(1).getRecord());
@@ -246,50 +236,49 @@ public class CSVRecordReaderTest extends BaseND4JTest {
     }
     @Test
-    public void testRegex() throws Exception {
-        CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] {null, "(.+) (.+) (.+)"});
+    @DisplayName("Test Regex")
+    void testRegex() throws Exception {
+        CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] { null, "(.+) (.+) (.+)" });
         reader.initialize(new StringSplit("normal,1.2.3.4 space separator"));
         while (reader.hasNext()) {
             List<Writable> vals = reader.next();
-            assertEquals("Entry count", 4, vals.size());
-            assertEquals("normal", vals.get(0).toString());
-            assertEquals("1.2.3.4", vals.get(1).toString());
-            assertEquals("space", vals.get(2).toString());
-            assertEquals("separator", vals.get(3).toString());
+            assertEquals(4, vals.size(), "Entry count");
+            assertEquals(vals.get(0).toString(), "normal");
+            assertEquals(vals.get(1).toString(), "1.2.3.4");
+            assertEquals(vals.get(2).toString(), "space");
+            assertEquals(vals.get(3).toString(), "separator");
         }
     }
-    @Test(expected = NoSuchElementException.class)
-    public void testCsvSkipAllLines() throws IOException, InterruptedException {
-        final int numLines = 4;
-        final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1),
-                (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
-        String header = ",one,two,three";
-        List<String> lines = new ArrayList<>();
-        for (int i = 0; i < numLines; i++)
-            lines.add(Integer.toString(i) + header);
+    @Test
+    @DisplayName("Test Csv Skip All Lines")
+    void testCsvSkipAllLines() {
+        assertThrows(NoSuchElementException.class, () -> {
+            final int numLines = 4;
+            final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
+            String header = ",one,two,three";
+            List<String> lines = new ArrayList<>();
+            for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
             File tempFile = File.createTempFile("csvSkipLines", ".csv");
             FileUtils.writeLines(tempFile, lines);
-        CSVRecordReader rr = new CSVRecordReader(numLines, ',');
-        rr.initialize(new FileSplit(tempFile));
-        rr.reset();
-        assertTrue(!rr.hasNext());
-        rr.next();
+            CSVRecordReader rr = new CSVRecordReader(numLines, ',');
+            rr.initialize(new FileSplit(tempFile));
+            rr.reset();
+            assertTrue(!rr.hasNext());
+            rr.next();
+        });
     }
     @Test
-    public void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
+    @DisplayName("Test Csv Skip All But One Line")
+    void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
         final int numLines = 4;
-        final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)),
-                new Text("one"), new Text("two"), new Text("three"));
+        final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)), new Text("one"), new Text("two"), new Text("three"));
         String header = ",one,two,three";
         List<String> lines = new ArrayList<>();
-        for (int i = 0; i < numLines; i++)
-            lines.add(Integer.toString(i) + header);
+        for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
         File tempFile = File.createTempFile("csvSkipLines", ".csv");
         FileUtils.writeLines(tempFile, lines);
         CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
         rr.initialize(new FileSplit(tempFile));
         rr.reset();
@@ -297,50 +286,45 @@ public class CSVRecordReaderTest extends BaseND4JTest {
         assertEquals(rr.next(), lineList);
     }
     @Test
-    public void testStreamReset() throws Exception {
+    @DisplayName("Test Stream Reset")
+    void testStreamReset() throws Exception {
         CSVRecordReader rr = new CSVRecordReader(0, ',');
         rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));
         int count = 0;
-        while(rr.hasNext()){
+        while (rr.hasNext()) {
             assertNotNull(rr.next());
             count++;
         }
         assertEquals(150, count);
         assertFalse(rr.resetSupported());
-        try{
+        try {
             rr.reset();
             fail("Expected exception");
-        } catch (Exception e){
+        } catch (Exception e) {
             String msg = e.getMessage();
             String msg2 = e.getCause().getMessage();
-            assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
-            assertTrue(msg2, msg2.contains("Reset not supported from streams"));
+            assertTrue(msg.contains("Error during LineRecordReader reset"), msg);
+            assertTrue(msg2.contains("Reset not supported from streams"), msg2);
             // e.printStackTrace();
         }
     }
     @Test
-    public void testUsefulExceptionNoInit(){
+    @DisplayName("Test Useful Exception No Init")
+    void testUsefulExceptionNoInit() {
        CSVRecordReader rr = new CSVRecordReader(0, ',');
-        try{
+        try {
            rr.hasNext();
            fail("Expected exception");
-        } catch (Exception e){
-            assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
+        } catch (Exception e) {
+            assertTrue(e.getMessage().contains("initialized"), e.getMessage());
        }
-        try{
+        try {
            rr.next();
            fail("Expected exception");
-        } catch (Exception e){
-            assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
+        } catch (Exception e) {
+            assertTrue(e.getMessage().contains("initialized"), e.getMessage());
        }
    }
 }
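The @Test(expected = ...) attribute seen in the old testCsvSkipAllLines no longer exists in Jupiter; the replacement, as the migrated method shows, is Assertions.assertThrows with the throwing code wrapped in a lambda. A minimal self-contained sketch of the pattern (the test class itself is hypothetical):

import org.junit.jupiter.api.Test;

import java.util.Collections;
import java.util.NoSuchElementException;

import static org.junit.jupiter.api.Assertions.assertThrows;

class AssertThrowsPatternTest {

    @Test
    void nextOnEmptyIteratorThrows() {
        // JUnit 4:  @Test(expected = NoSuchElementException.class)
        // JUnit 5:  assert the exception explicitly, scoped to the lambda body.
        assertThrows(NoSuchElementException.class,
                () -> Collections.emptyIterator().next());
    }
}

A side benefit of this form is precision: only the code inside the lambda may throw, so a setup failure earlier in the test no longer counts as a pass.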

View File

@@ -17,7 +17,6 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
 package org.datavec.api.records.reader.impl;
 import org.datavec.api.records.SequenceRecord;
@@ -28,11 +27,10 @@ import org.datavec.api.split.InputSplit;
 import org.datavec.api.split.NumberedFileInputSplit;
 import org.datavec.api.writable.Writable;
 import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
 import org.nd4j.common.tests.BaseND4JTest;
 import org.nd4j.common.io.ClassPathResource;
 import java.io.File;
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -41,25 +39,27 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
-import static org.junit.Assert.assertEquals;
-public class CSVSequenceRecordReaderTest extends BaseND4JTest {
-    @Rule
-    public TemporaryFolder tempDir = new TemporaryFolder();
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
+@DisplayName("Csv Sequence Record Reader Test")
+class CSVSequenceRecordReaderTest extends BaseND4JTest {
+    @TempDir
+    public Path tempDir;
     @Test
-    public void test() throws Exception {
+    @DisplayName("Test")
+    void test() throws Exception {
         CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
         seqReader.initialize(new TestInputSplit());
         int sequenceCount = 0;
         while (seqReader.hasNext()) {
             List<List<Writable>> sequence = seqReader.sequenceRecord();
-            assertEquals(4, sequence.size()); //4 lines, plus 1 header line
+            // 4 lines, plus 1 header line
+            assertEquals(4, sequence.size());
             Iterator<List<Writable>> timeStepIter = sequence.iterator();
             int lineCount = 0;
             while (timeStepIter.hasNext()) {
@@ -80,19 +80,18 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
     }
     @Test
-    public void testReset() throws Exception {
+    @DisplayName("Test Reset")
+    void testReset() throws Exception {
         CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
         seqReader.initialize(new TestInputSplit());
         int nTests = 5;
         for (int i = 0; i < nTests; i++) {
             seqReader.reset();
             int sequenceCount = 0;
             while (seqReader.hasNext()) {
                 List<List<Writable>> sequence = seqReader.sequenceRecord();
-                assertEquals(4, sequence.size()); //4 lines, plus 1 header line
+                // 4 lines, plus 1 header line
+                assertEquals(4, sequence.size());
                 Iterator<List<Writable>> timeStepIter = sequence.iterator();
                 int lineCount = 0;
                 while (timeStepIter.hasNext()) {
@@ -107,15 +106,15 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
     }
     @Test
-    public void testMetaData() throws Exception {
+    @DisplayName("Test Meta Data")
+    void testMetaData() throws Exception {
         CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
         seqReader.initialize(new TestInputSplit());
         List<List<List<Writable>>> l = new ArrayList<>();
         while (seqReader.hasNext()) {
             List<List<Writable>> sequence = seqReader.sequenceRecord();
-            assertEquals(4, sequence.size()); //4 lines, plus 1 header line
+            // 4 lines, plus 1 header line
+            assertEquals(4, sequence.size());
             Iterator<List<Writable>> timeStepIter = sequence.iterator();
             int lineCount = 0;
             while (timeStepIter.hasNext()) {
@@ -123,10 +122,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
                 lineCount++;
             }
             assertEquals(4, lineCount);
             l.add(sequence);
         }
         List<SequenceRecord> l2 = new ArrayList<>();
         List<RecordMetaData> meta = new ArrayList<>();
         seqReader.reset();
@@ -136,7 +133,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
             meta.add(sr.getMetaData());
         }
         assertEquals(3, l2.size());
         List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta);
         for (int i = 0; i < 3; i++) {
             assertEquals(l.get(i), l2.get(i).getSequenceRecord());
@@ -144,8 +140,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
         }
     }
-    private static class TestInputSplit implements InputSplit {
+    @DisplayName("Test Input Split")
+    private static class TestInputSplit implements InputSplit {
         @Override
         public boolean canWriteToLocation(URI location) {
@@ -164,7 +160,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
         @Override
         public void updateSplitLocations(boolean reset) {
         }
         @Override
@@ -174,7 +169,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
         @Override
         public void bootStrapForWrite() {
         }
         @Override
@@ -222,38 +216,30 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
         @Override
         public void reset() {
-            //No op
+            // No op
         }
         @Override
         public boolean resetSupported() {
             return true;
         }
     }
     @Test
-    public void testCsvSeqAndNumberedFileSplit() throws Exception {
-        File baseDir = tempDir.newFolder();
-        //Simple sanity check unit test
+    @DisplayName("Test Csv Seq And Numbered File Split")
+    void testCsvSeqAndNumberedFileSplit(@TempDir Path tempDir) throws Exception {
+        File baseDir = tempDir.toFile();
+        // Simple sanity check unit test
         for (int i = 0; i < 3; i++) {
             new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
         }
-        //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
+        // Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
         ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
         String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();
         SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
         featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
-        while(featureReader.hasNext()){
+        while (featureReader.hasNext()) {
             featureReader.nextSequence();
         }
     }
 }

View File

@@ -17,7 +17,6 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
 package org.datavec.api.records.reader.impl;
 import org.datavec.api.records.reader.SequenceRecordReader;
@@ -25,94 +24,87 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
 import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
 import org.datavec.api.split.FileSplit;
 import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 import org.nd4j.common.tests.BaseND4JTest;
 import org.nd4j.common.io.ClassPathResource;
 import java.util.LinkedList;
 import java.util.List;
-import static org.junit.Assert.assertEquals;
-public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
+@DisplayName("Csv Variable Sliding Window Record Reader Test")
+class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
     @Test
-    public void testCSVVariableSlidingWindowRecordReader() throws Exception {
+    @DisplayName("Test CSV Variable Sliding Window Record Reader")
+    void testCSVVariableSlidingWindowRecordReader() throws Exception {
         int maxLinesPerSequence = 3;
         SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
         seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         CSVRecordReader rr = new CSVRecordReader();
         rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         int count = 0;
         while (seqRR.hasNext()) {
             List<List<Writable>> next = seqRR.sequenceRecord();
-            if(count==maxLinesPerSequence-1) {
+            if (count == maxLinesPerSequence - 1) {
                 LinkedList<List<Writable>> expected = new LinkedList<>();
                 for (int i = 0; i < maxLinesPerSequence; i++) {
                     expected.addFirst(rr.next());
                 }
                 assertEquals(expected, next);
             }
-            if(count==maxLinesPerSequence) {
+            if (count == maxLinesPerSequence) {
                 assertEquals(maxLinesPerSequence, next.size());
             }
-            if(count==0) { // first seq should be length 1
+            if (count == 0) {
+                // first seq should be length 1
                 assertEquals(1, next.size());
             }
-            if(count>151) { // last seq should be length 1
+            if (count > 151) {
+                // last seq should be length 1
                 assertEquals(1, next.size());
             }
             count++;
         }
         assertEquals(152, count);
     }
     @Test
-    public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
+    @DisplayName("Test CSV Variable Sliding Window Record Reader Stride")
+    void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
         int maxLinesPerSequence = 3;
         int stride = 2;
         SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
         seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         CSVRecordReader rr = new CSVRecordReader();
         rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
         int count = 0;
         while (seqRR.hasNext()) {
             List<List<Writable>> next = seqRR.sequenceRecord();
-            if(count==maxLinesPerSequence-1) {
+            if (count == maxLinesPerSequence - 1) {
                 LinkedList<List<Writable>> expected = new LinkedList<>();
-                for(int s = 0; s < stride; s++) {
+                for (int s = 0; s < stride; s++) {
                     expected = new LinkedList<>();
                     for (int i = 0; i < maxLinesPerSequence; i++) {
                         expected.addFirst(rr.next());
                     }
                 }
                 assertEquals(expected, next);
             }
-            if(count==maxLinesPerSequence) {
+            if (count == maxLinesPerSequence) {
                 assertEquals(maxLinesPerSequence, next.size());
             }
-            if(count==0) { // first seq should be length 2
+            if (count == 0) {
+                // first seq should be length 2
                 assertEquals(2, next.size());
             }
-            if(count>151) { // last seq should be length 1
+            if (count > 151) {
+                // last seq should be length 1
                 assertEquals(1, next.size());
             }
             count++;
         }
         assertEquals(76, count);
     }
 }

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -29,44 +28,38 @@ import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.loader.FileBatch; import org.nd4j.common.loader.FileBatch;
import java.io.File; import java.io.File;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*; @DisplayName("File Batch Record Reader Test")
class FileBatchRecordReaderTest extends BaseND4JTest {
public class FileBatchRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testCsv() throws Exception { @DisplayName("Test Csv")
void testCsv(@TempDir Path testDir) throws Exception {
//This is an unrealistic use case - one line/record per CSV // This is an unrealistic use case - one line/record per CSV
File baseDir = testDir.newFolder(); File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>(); List<File> fileList = new ArrayList<>();
for( int i=0; i<10; i++ ){ for (int i = 0; i < 10; i++) {
String s = "file_" + i + "," + i + "," + i; String s = "file_" + i + "," + i + "," + i;
File f = new File(baseDir, "origFile" + i + ".csv"); File f = new File(baseDir, "origFile" + i + ".csv");
FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8); FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8);
fileList.add(f); fileList.add(f);
} }
FileBatch fb = FileBatch.forFiles(fileList); FileBatch fb = FileBatch.forFiles(fileList);
RecordReader rr = new CSVRecordReader(); RecordReader rr = new CSVRecordReader();
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb); FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
for (int test = 0; test < 3; test++) {
for( int test=0; test<3; test++) {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext()); assertTrue(fbrr.hasNext());
List<Writable> next = fbrr.next(); List<Writable> next = fbrr.next();
@ -83,15 +76,15 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testCsvSequence() throws Exception { @DisplayName("Test Csv Sequence")
//CSV sequence - 3 lines per file, 10 files void testCsvSequence(@TempDir Path testDir) throws Exception {
File baseDir = testDir.newFolder(); // CSV sequence - 3 lines per file, 10 files
File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>(); List<File> fileList = new ArrayList<>();
for( int i=0; i<10; i++ ){ for (int i = 0; i < 10; i++) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
for( int j=0; j<3; j++ ){ for (int j = 0; j < 3; j++) {
if(j > 0) if (j > 0)
sb.append("\n"); sb.append("\n");
sb.append("file_" + i + "," + i + "," + j); sb.append("file_" + i + "," + i + "," + j);
} }
@ -99,19 +92,16 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8); FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8);
fileList.add(f); fileList.add(f);
} }
FileBatch fb = FileBatch.forFiles(fileList); FileBatch fb = FileBatch.forFiles(fileList);
SequenceRecordReader rr = new CSVSequenceRecordReader(); SequenceRecordReader rr = new CSVSequenceRecordReader();
FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb); FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb);
for (int test = 0; test < 3; test++) {
for( int test=0; test<3; test++) {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext()); assertTrue(fbrr.hasNext());
List<List<Writable>> next = fbrr.sequenceRecord(); List<List<Writable>> next = fbrr.sequenceRecord();
assertEquals(3, next.size()); assertEquals(3, next.size());
int count = 0; int count = 0;
for(List<Writable> step : next ){ for (List<Writable> step : next) {
String s1 = "file_" + i; String s1 = "file_" + i;
assertEquals(s1, step.get(0).toString()); assertEquals(s1, step.get(0).toString());
assertEquals(String.valueOf(i), step.get(1).toString()); assertEquals(String.valueOf(i), step.get(1).toString());
@ -123,5 +113,4 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
fbrr.reset(); fbrr.reset();
} }
} }
} }
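FileBatchRecordReaderTest above shows the commit's temp-directory pattern: the JUnit 4 @Rule TemporaryFolder plus testDir.newFolder() becomes a @TempDir parameter that Jupiter injects per test. A minimal sketch of just that pattern; the class and file names are illustrative, not from the repository.

import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

import static org.junit.jupiter.api.Assertions.assertTrue;

class TempDirSketch {

    // JUnit 4: @Rule public TemporaryFolder testDir = new TemporaryFolder();
    //          File baseDir = testDir.newFolder();
    // JUnit 5: the framework creates the directory before the test runs
    //          and deletes it recursively afterwards; no rule object needed.
    @Test
    void writesIntoInjectedDir(@TempDir Path testDir) throws Exception {
        File f = Files.createFile(testDir.resolve("origFile0.csv")).toFile();
        assertTrue(f.exists());
    }
}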

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record; import org.datavec.api.records.Record;
@ -26,28 +25,28 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("File Record Reader Test")
import static org.junit.Assert.assertFalse; class FileRecordReaderTest extends BaseND4JTest {
public class FileRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testReset() throws Exception { @DisplayName("Test Reset")
void testReset() throws Exception {
FileRecordReader rr = new FileRecordReader(); FileRecordReader rr = new FileRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int nResets = 5; int nResets = 5;
for (int i = 0; i < nResets; i++) { for (int i = 0; i < nResets; i++) {
int lineCount = 0; int lineCount = 0;
while (rr.hasNext()) { while (rr.hasNext()) {
List<Writable> line = rr.next(); List<Writable> line = rr.next();
@ -61,25 +60,20 @@ public class FileRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testMeta() throws Exception { @DisplayName("Test Meta")
void testMeta() throws Exception {
FileRecordReader rr = new FileRecordReader(); FileRecordReader rr = new FileRecordReader();
URI[] arr = new URI[3]; URI[] arr = new URI[3];
arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI(); arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI();
arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI(); arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI();
arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI(); arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI();
InputSplit is = new CollectionInputSplit(Arrays.asList(arr)); InputSplit is = new CollectionInputSplit(Arrays.asList(arr));
rr.initialize(is); rr.initialize(is);
List<List<Writable>> out = new ArrayList<>(); List<List<Writable>> out = new ArrayList<>();
while (rr.hasNext()) { while (rr.hasNext()) {
out.add(rr.next()); out.add(rr.next());
} }
assertEquals(3, out.size()); assertEquals(3, out.size());
rr.reset(); rr.reset();
List<List<Writable>> out2 = new ArrayList<>(); List<List<Writable>> out2 = new ArrayList<>();
List<Record> out3 = new ArrayList<>(); List<Record> out3 = new ArrayList<>();
@ -90,13 +84,10 @@ public class FileRecordReaderTest extends BaseND4JTest {
out2.add(r.getRecord()); out2.add(r.getRecord());
out3.add(r); out3.add(r);
meta.add(r.getMetaData()); meta.add(r.getMetaData());
assertEquals(arr[count++], r.getMetaData().getURI()); assertEquals(arr[count++], r.getMetaData().getURI());
} }
assertEquals(out, out2); assertEquals(out, out2);
List<Record> fromMeta = rr.loadFromMetaData(meta); List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(out3, fromMeta); assertEquals(out3, fromMeta);
} }
} }
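FileRecordReaderTest needs nothing beyond the import swap because none of its assertions carry a failure message. Where one does, the signatures differ between the frameworks: JUnit 4 puts the message first, Jupiter puts it last. A short sketch of the difference; the method and message text are illustrative.

import static org.junit.jupiter.api.Assertions.assertEquals;

class AssertionMessageSketch {

    static void checkCount(int expected, int actual) {
        // JUnit 4 (org.junit.Assert): assertEquals("record count", expected, actual);
        // JUnit 5 (Assertions): the message moves to the trailing position.
        assertEquals(expected, actual, "record count");
    }
}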

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
@ -29,96 +28,80 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.File; import java.io.File;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Jackson Line Record Reader Test")
class JacksonLineRecordReaderTest extends BaseND4JTest {
public class JacksonLineRecordReaderTest extends BaseND4JTest { @TempDir
public Path testDir;
@Rule public JacksonLineRecordReaderTest() {
public TemporaryFolder testDir = new TemporaryFolder(); }
public JacksonLineRecordReaderTest() {
}
private static FieldSelection getFieldSelection() { private static FieldSelection getFieldSelection() {
return new FieldSelection.Builder().addField("value1"). return new FieldSelection.Builder().addField("value1").addField("value2").addField("value3").addField("value4").addField("value5").addField("value6").addField("value7").addField("value8").addField("value9").addField("value10").build();
addField("value2").
addField("value3").
addField("value4").
addField("value5").
addField("value6").
addField("value7").
addField("value8").
addField("value9").
addField("value10").build();
} }
@Test @Test
public void testReadJSON() throws Exception { @DisplayName("Test Read JSON")
void testReadJSON() throws Exception {
RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory())); RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile()));
testJacksonRecordReader(rr); testJacksonRecordReader(rr);
}
private static void testJacksonRecordReader(RecordReader rr) {
while (rr.hasNext()) {
List<Writable> json0 = rr.next();
//System.out.println(json0);
assert(json0.size() > 0);
}
} }
private static void testJacksonRecordReader(RecordReader rr) {
while (rr.hasNext()) {
List<Writable> json0 = rr.next();
// System.out.println(json0);
assert (json0.size() > 0);
}
}
@Test @Test
public void testJacksonLineSequenceRecordReader() throws Exception { @DisplayName("Test Jackson Line Sequence Record Reader")
File dir = testDir.newFolder(); void testJacksonLineSequenceRecordReader(@TempDir Path testDir) throws Exception {
new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir); File dir = testDir.toFile();
new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b") FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
.addField(new Text("MISSING_CX"), "c", "x").build(); JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
File[] files = dir.listFiles();
JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory())); Arrays.sort(files);
File[] files = dir.listFiles(); URI[] u = new URI[files.length];
Arrays.sort(files); for (int i = 0; i < files.length; i++) {
URI[] u = new URI[files.length]; u[i] = files[i].toURI();
for( int i=0; i<files.length; i++ ){ }
u[i] = files[i].toURI(); rr.initialize(new CollectionInputSplit(u));
} List<List<Writable>> expSeq0 = new ArrayList<>();
rr.initialize(new CollectionInputSplit(u)); expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
List<List<Writable>> expSeq0 = new ArrayList<>(); expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"))); List<List<Writable>> expSeq1 = new ArrayList<>();
expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"))); expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"))); int count = 0;
while (rr.hasNext()) {
List<List<Writable>> expSeq1 = new ArrayList<>(); List<List<Writable>> next = rr.sequenceRecord();
expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3"))); if (count++ == 0) {
assertEquals(expSeq0, next);
} else {
int count = 0; assertEquals(expSeq1, next);
while(rr.hasNext()){ }
List<List<Writable>> next = rr.sequenceRecord(); }
if(count++ == 0){ assertEquals(2, count);
assertEquals(expSeq0, next); }
} else {
assertEquals(expSeq1, next);
}
}
assertEquals(2, count);
}
} }
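One thing the conversion of JacksonLineRecordReaderTest leaves as-is: testJacksonRecordReader still relies on the Java assert keyword, which is a no-op unless the JVM runs with -ea, so the check can pass silently. If a hard failure is intended, a Jupiter assertion is the safer idiom; a sketch of the equivalent check, with an illustrative helper name.

import java.util.List;

import static org.junit.jupiter.api.Assertions.assertFalse;

class NonEmptyRecordSketch {

    // Runs regardless of JVM flags, unlike: assert record.size() > 0;
    static void assertNonEmpty(List<?> record) {
        assertFalse(record.isEmpty(), "record reader returned an empty record");
    }
}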

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.io.labels.PathLabelGenerator; import org.datavec.api.io.labels.PathLabelGenerator;
@ -32,113 +31,94 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory; import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import java.io.File; import java.io.File;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Jackson Record Reader Test")
import static org.junit.Assert.assertFalse; class JacksonRecordReaderTest extends BaseND4JTest {
public class JacksonRecordReaderTest extends BaseND4JTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testReadingJson() throws Exception { @DisplayName("Test Reading Json")
//Load 3 values from 3 JSON files void testReadingJson(@TempDir Path testDir) throws Exception {
//structure: a:value, b:value, c:x:value, c:y:value // Load 3 values from 3 JSON files
//And we want to load only a:value, b:value and c:x:value // structure: a:value, b:value, c:x:value, c:y:value
//For first JSON file: all values are present // And we want to load only a:value, b:value and c:x:value
//For second JSON file: b:value is missing // For first JSON file: all values are present
//For third JSON file: c:x:value is missing // For second JSON file: b:value is missing
// For third JSON file: c:x:value is missing
ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath(); String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory())); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(is); rr.initialize(is);
testJacksonRecordReader(rr); testJacksonRecordReader(rr);
} }
@Test @Test
public void testReadingYaml() throws Exception { @DisplayName("Test Reading Yaml")
//Exact same information as JSON format, but in YAML format void testReadingYaml(@TempDir Path testDir) throws Exception {
// Exact same information as JSON format, but in YAML format
ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/"); ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "yaml_test_%d.txt").getAbsolutePath(); String path = new File(f, "yaml_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory())); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory()));
rr.initialize(is); rr.initialize(is);
testJacksonRecordReader(rr); testJacksonRecordReader(rr);
} }
@Test @Test
public void testReadingXml() throws Exception { @DisplayName("Test Reading Xml")
//Exact same information as JSON format, but in XML format void testReadingXml(@TempDir Path testDir) throws Exception {
// Exact same information as JSON format, but in XML format
ClassPathResource cpr = new ClassPathResource("datavec-api/xml/"); ClassPathResource cpr = new ClassPathResource("datavec-api/xml/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "xml_test_%d.txt").getAbsolutePath(); String path = new File(f, "xml_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory())); RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory()));
rr.initialize(is); rr.initialize(is);
testJacksonRecordReader(rr); testJacksonRecordReader(rr);
} }
private static FieldSelection getFieldSelection() { private static FieldSelection getFieldSelection() {
return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b") return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
.addField(new Text("MISSING_CX"), "c", "x").build();
} }
private static void testJacksonRecordReader(RecordReader rr) { private static void testJacksonRecordReader(RecordReader rr) {
List<Writable> json0 = rr.next(); List<Writable> json0 = rr.next();
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")); List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
assertEquals(exp0, json0); assertEquals(exp0, json0);
List<Writable> json1 = rr.next(); List<Writable> json1 = rr.next();
List<Writable> exp1 = List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
assertEquals(exp1, json1); assertEquals(exp1, json1);
List<Writable> json2 = rr.next(); List<Writable> json2 = rr.next();
List<Writable> exp2 = List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
assertEquals(exp2, json2); assertEquals(exp2, json2);
assertFalse(rr.hasNext()); assertFalse(rr.hasNext());
// Test reset
//Test reset
rr.reset(); rr.reset();
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
@ -147,72 +127,50 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testAppendingLabels() throws Exception { @DisplayName("Test Appending Labels")
void testAppendingLabels(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath(); String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
// Insert at the end:
//Insert at the end: RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen());
rr.initialize(is); rr.initialize(is);
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0));
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"),
new IntWritable(0));
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1));
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"),
new IntWritable(1));
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2));
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"),
new IntWritable(2));
assertEquals(exp2, rr.next()); assertEquals(exp2, rr.next());
// Insert at position 0:
//Insert at position 0: rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0);
rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen(), 0);
rr.initialize(is); rr.initialize(is);
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"),
new Text("cxValue0"));
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"),
new Text("cxValue1"));
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"),
new Text("MISSING_CX"));
assertEquals(exp2, rr.next()); assertEquals(exp2, rr.next());
} }
@Test @Test
public void testAppendingLabelsMetaData() throws Exception { @DisplayName("Test Appending Labels Meta Data")
void testAppendingLabelsMetaData(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath(); String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2); InputSplit is = new NumberedFileInputSplit(path, 0, 2);
// Insert at the end:
//Insert at the end: RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen());
rr.initialize(is); rr.initialize(is);
List<List<Writable>> out = new ArrayList<>(); List<List<Writable>> out = new ArrayList<>();
while (rr.hasNext()) { while (rr.hasNext()) {
out.add(rr.next()); out.add(rr.next());
} }
assertEquals(3, out.size()); assertEquals(3, out.size());
rr.reset(); rr.reset();
List<List<Writable>> out2 = new ArrayList<>(); List<List<Writable>> out2 = new ArrayList<>();
List<Record> outRecord = new ArrayList<>(); List<Record> outRecord = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>();
@ -222,14 +180,12 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
outRecord.add(r); outRecord.add(r);
meta.add(r.getMetaData()); meta.add(r.getMetaData());
} }
assertEquals(out, out2); assertEquals(out, out2);
List<Record> fromMeta = rr.loadFromMetaData(meta); List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(outRecord, fromMeta); assertEquals(outRecord, fromMeta);
} }
@DisplayName("Label Gen")
private static class LabelGen implements PathLabelGenerator { private static class LabelGen implements PathLabelGenerator {
@Override @Override
@ -252,5 +208,4 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
return true; return true;
} }
} }
} }
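JacksonRecordReaderTest ends up with both Jupiter injection styles: a class-level @TempDir Path testDir field and same-named method parameters, which shadow the field inside each test body. Both forms are valid; whether they resolve to the same directory depends on the Jupiter version and the junit.jupiter.tempdir.scope setting, so the sketch below only assumes that each injected path exists. Names are illustrative.

import java.nio.file.Files;
import java.nio.file.Path;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

import static org.junit.jupiter.api.Assertions.assertTrue;

class TempDirStylesSketch {

    // Field injection: populated before each test method runs.
    @TempDir
    Path fieldDir;

    // Parameter injection: scoped to this method; a parameter named like
    // the field would shadow it in the method body.
    @Test
    void bothStylesYieldDirectories(@TempDir Path methodDir) {
        assertTrue(Files.isDirectory(fieldDir));
        assertTrue(Files.isDirectory(methodDir));
    }
}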

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*; import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class LibSvmRecordReaderTest extends BaseND4JTest { @DisplayName("Lib Svm Record Reader Test")
class LibSvmRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testBasicRecord() throws IOException, InterruptedException { @DisplayName("Test Basic Record")
void testBasicRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5 // 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2)));
// 33 // 33
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33)));
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -80,27 +66,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNoAppendLabel() throws IOException, InterruptedException { @DisplayName("Test No Append Label")
void testNoAppendLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5 // 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 33 // 33
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -117,33 +91,17 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNoLabel() throws IOException, InterruptedException { @DisplayName("Test No Label")
void testNoLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 2:1 4:2 6:3 8:4 10:5 // 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
ZERO, new DoubleWritable(2), // qid:42 1:0.1 2:2 6:6.6 8:80
ZERO, new DoubleWritable(3), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
ZERO, new DoubleWritable(4), // 1:1.0
ZERO, new DoubleWritable(5))); correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
// qid:42 1:0.1 2:2 6:6.6 8:80 //
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -160,33 +118,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testMultioutputRecord() throws IOException, InterruptedException { @DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5 // 7 2.45,9 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7), new DoubleWritable(2.45),
new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80 // 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2), new IntWritable(3),
new IntWritable(4)));
// 33,32.0,31.9 // 33,32.0,31.9
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33), new DoubleWritable(32.0),
new DoubleWritable(31.9)));
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -202,51 +142,20 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size()); assertEquals(i, correct.size());
} }
@Test @Test
public void testMultilabelRecord() throws IOException, InterruptedException { @DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5 // 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
// 1,2,4 // 1,2,4
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
ZERO, ZERO, // 1:1.0
ZERO, ZERO, correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO, //
ZERO, ZERO, correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -265,63 +174,24 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testZeroBasedIndexing() throws IOException, InterruptedException { @DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5 // 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ZERO,
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(ZERO, correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
// 1,2,4 // 1,2,4
correct.put(2, Arrays.asList(ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
ZERO, ZERO, // 1:1.0
ZERO, ZERO, correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO, //
ZERO, ZERO, correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
LABEL_ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO,
new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
// Zero-based indexing is default // Zero-based indexing is default
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! // NOT STANDARD!
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11); config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true); config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
@ -336,87 +206,107 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size()); assertEquals(i, correct.size());
} }
@Test(expected = NoSuchElementException.class) @Test
public void testNoSuchElementException() throws Exception { @DisplayName("Test No Such Element Exception")
LibSvmRecordReader rr = new LibSvmRecordReader(); void testNoSuchElementException() {
Configuration config = new Configuration(); assertThrows(NoSuchElementException.class, () -> {
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11); LibSvmRecordReader rr = new LibSvmRecordReader();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); Configuration config = new Configuration();
while (rr.hasNext()) config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
rr.next(); rr.next();
rr.next(); });
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void failedToSetNumFeaturesException() throws Exception { @DisplayName("Failed To Set Num Features Exception")
LibSvmRecordReader rr = new LibSvmRecordReader(); void failedToSetNumFeaturesException() {
Configuration config = new Configuration(); assertThrows(UnsupportedOperationException.class, () -> {
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); LibSvmRecordReader rr = new LibSvmRecordReader();
while (rr.hasNext()) Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Inconsistent Num Labels Exception")
void testInconsistentNumLabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Inconsistent Num Multiabels Exception")
void testInconsistentNumMultiabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Feature Index Exceeds Num Features")
void testFeatureIndexExceedsNumFeatures() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void testInconsistentNumLabelsException() throws Exception { @DisplayName("Test Label Index Exceeds Num Labels")
LibSvmRecordReader rr = new LibSvmRecordReader(); void testLabelIndexExceedsNumLabels() {
Configuration config = new Configuration(); assertThrows(IndexOutOfBoundsException.class, () -> {
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); LibSvmRecordReader rr = new LibSvmRecordReader();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile())); Configuration config = new Configuration();
while (rr.hasNext()) config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void testInconsistentNumMultilabelsException() throws Exception { @DisplayName("Test Zero Index Feature Without Using Zero Indexing")
LibSvmRecordReader rr = new LibSvmRecordReader(); void testZeroIndexFeatureWithoutUsingZeroIndexing() {
Configuration config = new Configuration(); assertThrows(IndexOutOfBoundsException.class, () -> {
config.setBoolean(LibSvmRecordReader.MULTILABEL, false); LibSvmRecordReader rr = new LibSvmRecordReader();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile())); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
while (rr.hasNext()) config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testFeatureIndexExceedsNumFeatures() throws Exception { @DisplayName("Test Zero Index Label Without Using Zero Indexing")
LibSvmRecordReader rr = new LibSvmRecordReader(); void testZeroIndexLabelWithoutUsingZeroIndexing() {
Configuration config = new Configuration(); assertThrows(IndexOutOfBoundsException.class, () -> {
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9); LibSvmRecordReader rr = new LibSvmRecordReader();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); Configuration config = new Configuration();
rr.next(); config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
} config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
@Test(expected = IndexOutOfBoundsException.class) config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
public void testLabelIndexExceedsNumLabels() throws Exception { rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
LibSvmRecordReader rr = new LibSvmRecordReader(); rr.next();
Configuration config = new Configuration(); });
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
} }
} }
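LibSvmRecordReaderTest carries the commit's exception-testing migration: JUnit 4's @Test(expected = ...) attribute becomes an explicit assertThrows, which scopes the expectation to one statement instead of the whole method and returns the thrown exception for further checks. A minimal sketch of the pattern; the empty iterator stands in for a record reader that has run out of input.

import java.util.Collections;
import java.util.Iterator;
import java.util.NoSuchElementException;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertThrows;

class ExpectedExceptionSketch {

    // JUnit 4: @Test(expected = NoSuchElementException.class) on the method.
    // JUnit 5: the expectation is pinned to the lambda / method reference,
    //          and the exception is returned for optional further assertions.
    @Test
    void nextPastEndThrows() {
        Iterator<String> it = Collections.emptyIterator();
        NoSuchElementException e = assertThrows(NoSuchElementException.class, it::next);
        // e.getMessage() could be asserted here if the message matters.
    }
}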

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -31,10 +30,9 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.FileOutputStream; import java.io.FileOutputStream;
@ -45,34 +43,31 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.zip.GZIPInputStream; import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream; import java.util.zip.GZIPOutputStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Line Reader Test")
class LineReaderTest extends BaseND4JTest {
public class LineReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testLineReader() throws Exception { @DisplayName("Test Line Reader")
File tmpdir = testDir.newFolder(); void testLineReader(@TempDir Path tmpDir) throws Exception {
File tmpdir = tmpDir.toFile();
if (tmpdir.exists()) if (tmpdir.exists())
tmpdir.delete(); tmpdir.delete();
tmpdir.mkdir(); tmpdir.mkdir();
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt")); File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt")); File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt")); File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3")); FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6")); FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9")); FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
InputSplit split = new FileSplit(tmpdir); InputSplit split = new FileSplit(tmpdir);
RecordReader reader = new LineRecordReader(); RecordReader reader = new LineRecordReader();
reader.initialize(split); reader.initialize(split);
int count = 0; int count = 0;
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
while (reader.hasNext()) { while (reader.hasNext()) {
@ -81,34 +76,27 @@ public class LineReaderTest extends BaseND4JTest {
list.add(l); list.add(l);
count++; count++;
} }
assertEquals(9, count); assertEquals(9, count);
} }
@Test @Test
public void testLineReaderMetaData() throws Exception { @DisplayName("Test Line Reader Meta Data")
File tmpdir = testDir.newFolder(); void testLineReaderMetaData(@TempDir Path tmpDir) throws Exception {
File tmpdir = tmpDir.toFile();
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt")); File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt")); File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt")); File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3")); FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6")); FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9")); FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
InputSplit split = new FileSplit(tmpdir); InputSplit split = new FileSplit(tmpdir);
RecordReader reader = new LineRecordReader(); RecordReader reader = new LineRecordReader();
reader.initialize(split); reader.initialize(split);
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
while (reader.hasNext()) { while (reader.hasNext()) {
list.add(reader.next()); list.add(reader.next());
} }
assertEquals(9, list.size()); assertEquals(9, list.size());
List<List<Writable>> out2 = new ArrayList<>(); List<List<Writable>> out2 = new ArrayList<>();
List<Record> out3 = new ArrayList<>(); List<Record> out3 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>();
@ -124,13 +112,10 @@ public class LineReaderTest extends BaseND4JTest {
assertEquals(uri, split.locations()[fileIdx]); assertEquals(uri, split.locations()[fileIdx]);
count++; count++;
} }
assertEquals(list, out2); assertEquals(list, out2);
List<Record> fromMeta = reader.loadFromMetaData(meta); List<Record> fromMeta = reader.loadFromMetaData(meta);
assertEquals(out3, fromMeta); assertEquals(out3, fromMeta);
// try: second line of second and third files only...
//try: second line of second and third files only...
List<RecordMetaData> subsetMeta = new ArrayList<>(); List<RecordMetaData> subsetMeta = new ArrayList<>();
subsetMeta.add(meta.get(4)); subsetMeta.add(meta.get(4));
subsetMeta.add(meta.get(7)); subsetMeta.add(meta.get(7));
@ -141,27 +126,22 @@ public class LineReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testLineReaderWithInputStreamInputSplit() throws Exception { @DisplayName("Test Line Reader With Input Stream Input Split")
File tmpdir = testDir.newFolder(); void testLineReaderWithInputStreamInputSplit(@TempDir Path testDir) throws Exception {
File tmpdir = testDir.toFile();
File tmp1 = new File(tmpdir, "tmp1.txt.gz"); File tmp1 = new File(tmpdir, "tmp1.txt.gz");
OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false)); OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os); IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
os.flush(); os.flush();
os.close(); os.close();
InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1))); InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
RecordReader reader = new LineRecordReader(); RecordReader reader = new LineRecordReader();
reader.initialize(split); reader.initialize(split);
int count = 0; int count = 0;
while (reader.hasNext()) { while (reader.hasNext()) {
assertEquals(1, reader.next().size()); assertEquals(1, reader.next().size());
count++; count++;
} }
assertEquals(9, count); assertEquals(9, count);
} }
} }
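A cleanup the conversion of LineReaderTest leaves on the table: the exists()/delete()/mkdir() dance at the top of testLineReader predates @TempDir, which already hands the test an existing, empty, writable directory and removes it afterwards. A sketch of the leaner form; the file contents are illustrative.

import java.io.File;
import java.nio.file.Path;
import java.util.Arrays;

import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

import static org.junit.jupiter.api.Assertions.assertTrue;

class LeanTempDirSketch {

    @Test
    void writeLinesDirectly(@TempDir Path tmpDir) throws Exception {
        // No exists()/delete()/mkdir() needed: the injected directory is
        // guaranteed to exist and to be empty when the test starts.
        File tmp1 = tmpDir.resolve("tmp1.txt").toFile();
        FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
        assertTrue(tmp1.length() > 0);
    }
}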

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record; import org.datavec.api.records.Record;
@ -34,43 +33,40 @@ import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Regex Record Reader Test")
import static org.junit.Assert.assertFalse; class RegexRecordReaderTest extends BaseND4JTest {
public class RegexRecordReaderTest extends BaseND4JTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testRegexLineRecordReader() throws Exception { @DisplayName("Test Regex Line Record Reader")
void testRegexLineRecordReader() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
RecordReader rr = new RegexLineRecordReader(regex, 1); RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"));
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"));
new Text("DEBUG"), new Text("First entry message!")); List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"));
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"),
new Text("INFO"), new Text("Second entry message!"));
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"),
new Text("WARN"), new Text("Third entry message!"));
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
assertEquals(exp2, rr.next()); assertEquals(exp2, rr.next());
assertFalse(rr.hasNext()); assertFalse(rr.hasNext());
// Test reset:
//Test reset:
rr.reset(); rr.reset();
assertEquals(exp0, rr.next()); assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next()); assertEquals(exp1, rr.next());
@@ -79,74 +75,57 @@ public class RegexRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testRegexLineRecordReaderMeta() throws Exception { @DisplayName("Test Regex Line Record Reader Meta")
void testRegexLineRecordReaderMeta() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
RecordReader rr = new RegexLineRecordReader(regex, 1); RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
while (rr.hasNext()) { while (rr.hasNext()) {
list.add(rr.next()); list.add(rr.next());
} }
assertEquals(3, list.size()); assertEquals(3, list.size());
List<Record> list2 = new ArrayList<>(); List<Record> list2 = new ArrayList<>();
List<List<Writable>> list3 = new ArrayList<>(); List<List<Writable>> list3 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>(); List<RecordMetaData> meta = new ArrayList<>();
rr.reset(); rr.reset();
int count = 1; //Start by skipping 1 line // Start by skipping 1 line
int count = 1;
while (rr.hasNext()) { while (rr.hasNext()) {
Record r = rr.nextRecord(); Record r = rr.nextRecord();
list2.add(r); list2.add(r);
list3.add(r.getRecord()); list3.add(r.getRecord());
meta.add(r.getMetaData()); meta.add(r.getMetaData());
assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber()); assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber());
} }
List<Record> fromMeta = rr.loadFromMetaData(meta); List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(list, list3); assertEquals(list, list3);
assertEquals(list2, fromMeta); assertEquals(list2, fromMeta);
} }
@Test @Test
public void testRegexSequenceRecordReader() throws Exception { @DisplayName("Test Regex Sequence Record Reader")
void testRegexSequenceRecordReader(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/"); ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath(); String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 1); InputSplit is = new NumberedFileInputSplit(path, 0, 1);
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1); SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is); rr.initialize(is);
List<List<Writable>> exp0 = new ArrayList<>(); List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!")));
new Text("First entry message!"))); exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!")));
new Text("Second entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"),
new Text("Third entry message!")));
List<List<Writable>> exp1 = new ArrayList<>(); List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!")));
new Text("First entry message!"))); exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!")));
new Text("Second entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"),
new Text("Third entry message!")));
assertEquals(exp0, rr.sequenceRecord()); assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord()); assertEquals(exp1, rr.sequenceRecord());
assertFalse(rr.hasNext()); assertFalse(rr.hasNext());
// Test resetting:
//Test resetting:
rr.reset(); rr.reset();
assertEquals(exp0, rr.sequenceRecord()); assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord()); assertEquals(exp1, rr.sequenceRecord());
@@ -154,24 +133,20 @@ public class RegexRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testRegexSequenceRecordReaderMeta() throws Exception { @DisplayName("Test Regex Sequence Record Reader Meta")
void testRegexSequenceRecordReaderMeta(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/"); ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
File f = testDir.newFolder(); File f = testDir.toFile();
cpr.copyDirectory(f); cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath(); String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 1); InputSplit is = new NumberedFileInputSplit(path, 0, 1);
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1); SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is); rr.initialize(is);
List<List<List<Writable>>> out = new ArrayList<>(); List<List<List<Writable>>> out = new ArrayList<>();
while (rr.hasNext()) { while (rr.hasNext()) {
out.add(rr.sequenceRecord()); out.add(rr.sequenceRecord());
} }
assertEquals(2, out.size()); assertEquals(2, out.size());
List<List<List<Writable>>> out2 = new ArrayList<>(); List<List<List<Writable>>> out2 = new ArrayList<>();
List<SequenceRecord> out3 = new ArrayList<>(); List<SequenceRecord> out3 = new ArrayList<>();
@@ -183,11 +158,8 @@ public class RegexRecordReaderTest extends BaseND4JTest {
out3.add(seqr); out3.add(seqr);
meta.add(seqr.getMetaData()); meta.add(seqr.getMetaData());
} }
List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta); List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);
assertEquals(out, out2); assertEquals(out, out2);
assertEquals(out3, fromMeta); assertEquals(out3, fromMeta);
} }
} }
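RegexRecordReaderTest above uses both @TempDir injection styles JUnit 5 supports: a class-level field and per-method parameters. A short sketch of the two declarations, assuming junit-jupiter-api; the class name is hypothetical:

import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirStylesSketch {

    @TempDir
    Path fieldDir; // field injection: resolved before each test method runs

    @Test
    void parameterStyle(@TempDir Path parameterDir) {
        // parameter injection: resolved when this method is invoked; fields and
        // parameters are both valid @TempDir targets in JUnit 5
    }
}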
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
@@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*; import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class SVMLightRecordReaderTest extends BaseND4JTest { @DisplayName("Svm Light Record Reader Test")
class SVMLightRecordReaderTest extends BaseND4JTest {
@Test @Test
public void testBasicRecord() throws IOException, InterruptedException { @DisplayName("Test Basic Record")
void testBasicRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5 // 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2)));
// 33 // 33
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33)));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -79,27 +65,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNoAppendLabel() throws IOException, InterruptedException { @DisplayName("Test No Append Label")
void testNoAppendLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5 // 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 33 // 33
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -116,33 +90,17 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNoLabel() throws IOException, InterruptedException { @DisplayName("Test No Label")
void testNoLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 2:1 4:2 6:3 8:4 10:5 // 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
ZERO, new DoubleWritable(2), // qid:42 1:0.1 2:2 6:6.6 8:80
ZERO, new DoubleWritable(3), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
ZERO, new DoubleWritable(4), // 1:1.0
ZERO, new DoubleWritable(5))); correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
// qid:42 1:0.1 2:2 6:6.6 8:80 //
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -159,33 +117,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testMultioutputRecord() throws IOException, InterruptedException { @DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5 // 7 2.45,9 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7), new DoubleWritable(2.45),
new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80 // 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2), new IntWritable(3),
new IntWritable(4)));
// 33,32.0,31.9 // 33,32.0,31.9
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33), new DoubleWritable(32.0),
new DoubleWritable(31.9)));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -200,51 +140,20 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size()); assertEquals(i, correct.size());
} }
@Test @Test
public void testMultilabelRecord() throws IOException, InterruptedException { @DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5 // 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
// 1,2,4 // 1,2,4
correct.put(2, Arrays.asList(ZERO, ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
ZERO, ZERO, // 1:1.0
ZERO, ZERO, correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO, //
ZERO, ZERO, correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -262,63 +171,24 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testZeroBasedIndexing() throws IOException, InterruptedException { @DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>(); Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5 // 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ZERO,
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80 // 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(ZERO, correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
// 1,2,4 // 1,2,4
correct.put(2, Arrays.asList(ZERO, correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
ZERO, ZERO, // 1:1.0
ZERO, ZERO, correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO, //
ZERO, ZERO, correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
ZERO, ZERO,
LABEL_ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO,
new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
// Zero-based indexing is default // Zero-based indexing is default
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! // NOT STANDARD!
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11); config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true); config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 5); config.setInt(SVMLightRecordReader.NUM_LABELS, 5);
@@ -333,20 +203,19 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
} }
@Test @Test
public void testNextRecord() throws IOException, InterruptedException { @DisplayName("Test Next Record")
void testNextRecord() throws IOException, InterruptedException {
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration(); Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false); config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
Record record = rr.nextRecord(); Record record = rr.nextRecord();
List<Writable> recordList = record.getRecord(); List<Writable> recordList = record.getRecord();
assertEquals(new DoubleWritable(1.0), recordList.get(1)); assertEquals(new DoubleWritable(1.0), recordList.get(1));
assertEquals(new DoubleWritable(3.0), recordList.get(5)); assertEquals(new DoubleWritable(3.0), recordList.get(5));
assertEquals(new DoubleWritable(4.0), recordList.get(7)); assertEquals(new DoubleWritable(4.0), recordList.get(7));
record = rr.nextRecord(); record = rr.nextRecord();
recordList = record.getRecord(); recordList = record.getRecord();
assertEquals(new DoubleWritable(0.1), recordList.get(0)); assertEquals(new DoubleWritable(0.1), recordList.get(0));
@@ -354,82 +223,102 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(new DoubleWritable(80.0), recordList.get(7)); assertEquals(new DoubleWritable(80.0), recordList.get(7));
} }
@Test(expected = NoSuchElementException.class) @Test
public void testNoSuchElementException() throws Exception { @DisplayName("Test No Such Element Exception")
SVMLightRecordReader rr = new SVMLightRecordReader(); void testNoSuchElementException() {
Configuration config = new Configuration(); assertThrows(NoSuchElementException.class, () -> {
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11); SVMLightRecordReader rr = new SVMLightRecordReader();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); Configuration config = new Configuration();
while (rr.hasNext()) config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
rr.next(); rr.next();
rr.next(); });
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void failedToSetNumFeaturesException() throws Exception { @DisplayName("Failed To Set Num Features Exception")
SVMLightRecordReader rr = new SVMLightRecordReader(); void failedToSetNumFeaturesException() {
Configuration config = new Configuration(); assertThrows(UnsupportedOperationException.class, () -> {
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); SVMLightRecordReader rr = new SVMLightRecordReader();
while (rr.hasNext()) Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Inconsistent Num Labels Exception")
void testInconsistentNumLabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Failed To Set Num Multiabels Exception")
void failedToSetNumMultiabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Feature Index Exceeds Num Features")
void testFeatureIndexExceedsNumFeatures() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void testInconsistentNumLabelsException() throws Exception { @DisplayName("Test Label Index Exceeds Num Labels")
SVMLightRecordReader rr = new SVMLightRecordReader(); void testLabelIndexExceedsNumLabels() {
Configuration config = new Configuration(); assertThrows(IndexOutOfBoundsException.class, () -> {
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); SVMLightRecordReader rr = new SVMLightRecordReader();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile())); Configuration config = new Configuration();
while (rr.hasNext()) config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = UnsupportedOperationException.class) @Test
public void failedToSetNumMultiabelsException() throws Exception { @DisplayName("Test Zero Index Feature Without Using Zero Indexing")
SVMLightRecordReader rr = new SVMLightRecordReader(); void testZeroIndexFeatureWithoutUsingZeroIndexing() {
Configuration config = new Configuration(); assertThrows(IndexOutOfBoundsException.class, () -> {
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile())); SVMLightRecordReader rr = new SVMLightRecordReader();
while (rr.hasNext()) Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next(); rr.next();
});
} }
@Test(expected = IndexOutOfBoundsException.class) @Test
public void testFeatureIndexExceedsNumFeatures() throws Exception { @DisplayName("Test Zero Index Label Without Using Zero Indexing")
SVMLightRecordReader rr = new SVMLightRecordReader(); void testZeroIndexLabelWithoutUsingZeroIndexing() {
Configuration config = new Configuration(); assertThrows(IndexOutOfBoundsException.class, () -> {
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9); SVMLightRecordReader rr = new SVMLightRecordReader();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); Configuration config = new Configuration();
rr.next(); config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
} config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
@Test(expected = IndexOutOfBoundsException.class) rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
public void testLabelIndexExceedsNumLabels() throws Exception { rr.next();
SVMLightRecordReader rr = new SVMLightRecordReader(); });
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
} }
} }
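Every expected-exception test above follows the same conversion: JUnit 4's @Test(expected = SomeException.class) attribute does not exist in JUnit 5, so the test body moves into an explicit Assertions.assertThrows lambda. A minimal sketch with the failing call stubbed out; names are hypothetical, not from this commit:

import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.NoSuchElementException;
import org.junit.jupiter.api.Test;

class AssertThrowsMigrationSketch {

    // JUnit 4 form: @Test(expected = NoSuchElementException.class)
    @Test
    void readingPastTheEndThrows() {
        assertThrows(NoSuchElementException.class, () -> {
            // stand-in for calling rr.next() on an exhausted reader
            throw new NoSuchElementException("no more records");
        });
    }
}

assertThrows would also allow pinning the exception to a single statement rather than the whole body, though the tests in this commit keep the entire original body inside the lambda.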
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.writer.impl; package org.datavec.api.records.writer.impl;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@@ -26,44 +25,42 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Csv Record Writer Test")
class CSVRecordWriterTest extends BaseND4JTest {
public class CSVRecordWriterTest extends BaseND4JTest {
@Before
public void setUp() throws Exception {
@BeforeEach
void setUp() throws Exception {
} }
@Test @Test
public void testWrite() throws Exception { @DisplayName("Test Write")
void testWrite() throws Exception {
File tempFile = File.createTempFile("datavec", "writer"); File tempFile = File.createTempFile("datavec", "writer");
tempFile.deleteOnExit(); tempFile.deleteOnExit();
FileSplit fileSplit = new FileSplit(tempFile); FileSplit fileSplit = new FileSplit(tempFile);
CSVRecordWriter writer = new CSVRecordWriter(); CSVRecordWriter writer = new CSVRecordWriter();
writer.initialize(fileSplit,new NumberOfRecordsPartitioner()); writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
List<Writable> collection = new ArrayList<>(); List<Writable> collection = new ArrayList<>();
collection.add(new Text("12")); collection.add(new Text("12"));
collection.add(new Text("13")); collection.add(new Text("13"));
collection.add(new Text("14")); collection.add(new Text("14"));
writer.write(collection); writer.write(collection);
CSVRecordReader reader = new CSVRecordReader(0); CSVRecordReader reader = new CSVRecordReader(0);
reader.initialize(new FileSplit(tempFile)); reader.initialize(new FileSplit(tempFile));
int cnt = 0; int cnt = 0;
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> line = new ArrayList<>(reader.next()); List<Writable> line = new ArrayList<>(reader.next());
assertEquals(3, line.size()); assertEquals(3, line.size());
assertEquals(12, line.get(0).toInt()); assertEquals(12, line.get(0).toInt());
assertEquals(13, line.get(1).toInt()); assertEquals(13, line.get(1).toInt());
assertEquals(14, line.get(2).toInt()); assertEquals(14, line.get(2).toInt());
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.writer.impl; package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@@ -30,93 +29,90 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertEquals; @DisplayName("Lib Svm Record Writer Test")
class LibSvmRecordWriterTest extends BaseND4JTest {
public class LibSvmRecordWriterTest extends BaseND4JTest {
@Test @Test
public void testBasic() throws Exception { @DisplayName("Test Basic")
void testBasic() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testNoLabel() throws Exception { @DisplayName("Test No Label")
void testNoLabel() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testMultioutputRecord() throws Exception { @DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testMultilabelRecord() throws Exception { @DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true); configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4); configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testZeroBasedIndexing() throws Exception { @DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true); configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5); configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@@ -127,10 +123,9 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
LibSvmRecordReader rr = new LibSvmRecordReader(); LibSvmRecordReader rr = new LibSvmRecordReader();
rr.initialize(configReader, new FileSplit(inputFile)); rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) { while (rr.hasNext()) {
@@ -138,7 +133,6 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
writer.write(record); writer.write(record);
} }
} }
Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX)); Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX));
List<String> linesOriginal = new ArrayList<>(); List<String> linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) { for (String line : FileUtils.readLines(inputFile)) {
@@ -159,7 +153,8 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
} }
@Test @Test
public void testNDArrayWritables() throws Exception { @DisplayName("Test ND Array Writables")
void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@@ -167,35 +162,28 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13); arr3.putScalar(0, 13);
arr3.putScalar(1, 14); arr3.putScalar(1, 14);
arr3.putScalar(2, 15); arr3.putScalar(2, 15);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new IntWritable(4));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNDArrayWritablesMultilabel() throws Exception { @DisplayName("Test ND Array Writables Multilabel")
void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@@ -203,36 +191,29 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0); arr3.putScalar(0, 0);
arr3.putScalar(1, 1); arr3.putScalar(1, 1);
arr3.putScalar(2, 0); arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNDArrayWritablesZeroIndex() throws Exception { @DisplayName("Test ND Array Writables Zero Index")
void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@@ -240,99 +221,91 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0); arr3.putScalar(0, 0);
arr3.putScalar(1, 1); arr3.putScalar(1, 1);
arr3.putScalar(2, 0); arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0"; String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD! // NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
// NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNonIntegerButValidMultilabel() throws Exception { @DisplayName("Test Non Integer But Valid Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), void testNonIntegerButValidMultilabel() throws Exception {
new IntWritable(2), List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
new DoubleWritable(1.0));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
} }
@Test(expected = NumberFormatException.class) @Test
public void nonIntegerMultilabel() throws Exception { @DisplayName("Non Integer Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), void nonIntegerMultilabel() {
new IntWritable(2), assertThrows(NumberFormatException.class, () -> {
new DoubleWritable(1.2)); List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration();
Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); FileSplit outputSplit = new FileSplit(tempFile);
FileSplit outputSplit = new FileSplit(tempFile); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.write(record);
writer.write(record); }
} });
} }
@Test(expected = NumberFormatException.class) @Test
public void nonBinaryMultilabel() throws Exception { @DisplayName("Non Binary Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), void nonBinaryMultilabel() {
new IntWritable(1), assertThrows(NumberFormatException.class, () -> {
new IntWritable(2)); List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration();
Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN,0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN,1); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL,true); FileSplit outputSplit = new FileSplit(tempFile);
FileSplit outputSplit = new FileSplit(tempFile); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.write(record);
writer.write(record); }
} });
} }
} }
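The two exception tests above show the one behavioral change in this conversion worth calling out: JUnit 4's @Test(expected = ...) passes if the exception is thrown anywhere in the method, while Jupiter's assertThrows only checks the lambda it is given. Keeping the whole test body inside the lambda, as this commit does, preserves the old, looser behavior. A minimal standalone sketch of the same conversion (the test name and parsed string are illustrative, not from the commit; assumes the usual static imports from org.junit.jupiter.api.Assertions):

// JUnit 4: declarative, method-scoped expectation.
@Test(expected = NumberFormatException.class)
public void rejectsNonNumericLabel() {
    Double.parseDouble("not-a-number");
}

// JUnit 5: explicit, lambda-scoped expectation; the thrown exception is
// returned and can be inspected further.
@Test
void rejectsNonNumericLabel() {
    NumberFormatException e = assertThrows(NumberFormatException.class,
            () -> Double.parseDouble("not-a-number"));
    assertTrue(e.getMessage().contains("not-a-number"));
}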

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.writer.impl; package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -27,93 +26,90 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertEquals; @DisplayName("Svm Light Record Writer Test")
class SVMLightRecordWriterTest extends BaseND4JTest {
public class SVMLightRecordWriterTest extends BaseND4JTest {
@Test @Test
public void testBasic() throws Exception { @DisplayName("Test Basic")
void testBasic() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testNoLabel() throws Exception { @DisplayName("Test No Label")
void testNoLabel() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testMultioutputRecord() throws Exception { @DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testMultilabelRecord() throws Exception { @DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true); configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4); configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@Test @Test
public void testZeroBasedIndexing() throws Exception { @DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration(); Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true); configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5); configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile); executeTest(configWriter, configReader, inputFile);
} }
@ -124,10 +120,9 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
SVMLightRecordReader rr = new SVMLightRecordReader(); SVMLightRecordReader rr = new SVMLightRecordReader();
rr.initialize(configReader, new FileSplit(inputFile)); rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) { while (rr.hasNext()) {
@ -135,7 +130,6 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
writer.write(record); writer.write(record);
} }
} }
Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX)); Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX));
List<String> linesOriginal = new ArrayList<>(); List<String> linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) { for (String line : FileUtils.readLines(inputFile)) {
@ -156,7 +150,8 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
} }
@Test @Test
public void testNDArrayWritables() throws Exception { @DisplayName("Test ND Array Writables")
void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@ -164,35 +159,28 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13); arr3.putScalar(0, 13);
arr3.putScalar(1, 14); arr3.putScalar(1, 14);
arr3.putScalar(2, 15); arr3.putScalar(2, 15);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new IntWritable(4));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNDArrayWritablesMultilabel() throws Exception { @DisplayName("Test ND Array Writables Multilabel")
void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@ -200,36 +188,29 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0); arr3.putScalar(0, 0);
arr3.putScalar(1, 1); arr3.putScalar(1, 1);
arr3.putScalar(2, 0); arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNDArrayWritablesZeroIndex() throws Exception { @DisplayName("Test ND Array Writables Zero Index")
void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2); INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11); arr2.putScalar(0, 11);
arr2.putScalar(1, 12); arr2.putScalar(1, 12);
@ -237,99 +218,91 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0); arr3.putScalar(0, 0);
arr3.putScalar(1, 1); arr3.putScalar(1, 1);
arr3.putScalar(2, 0); arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0"; String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD! // NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
// NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
String lineNew = FileUtils.readFileToString(tempFile).trim(); String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew); assertEquals(lineOriginal, lineNew);
} }
@Test @Test
public void testNonIntegerButValidMultilabel() throws Exception { @DisplayName("Test Non Integer But Valid Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), void testNonIntegerButValidMultilabel() throws Exception {
new IntWritable(2), List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
new DoubleWritable(1.0));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration(); Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile); FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record); writer.write(record);
} }
} }
@Test(expected = NumberFormatException.class) @Test
public void nonIntegerMultilabel() throws Exception { @DisplayName("Non Integer Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), void nonIntegerMultilabel() {
new IntWritable(2), assertThrows(NumberFormatException.class, () -> {
new DoubleWritable(1.2)); List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { Configuration configWriter = new Configuration();
Configuration configWriter = new Configuration(); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); FileSplit outputSplit = new FileSplit(tempFile);
FileSplit outputSplit = new FileSplit(tempFile); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.write(record);
writer.write(record); }
} });
} }
@Test(expected = NumberFormatException.class) @Test
public void nonBinaryMultilabel() throws Exception { @DisplayName("Non Binary Multilabel")
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), void nonBinaryMultilabel() {
new IntWritable(1), assertThrows(NumberFormatException.class, () -> {
new IntWritable(2)); List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true); tempFile.setWritable(true);
tempFile.deleteOnExit(); tempFile.deleteOnExit();
if (tempFile.exists()) if (tempFile.exists())
tempFile.delete(); tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { Configuration configWriter = new Configuration();
Configuration configWriter = new Configuration(); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); FileSplit outputSplit = new FileSplit(tempFile);
FileSplit outputSplit = new FileSplit(tempFile); writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.write(record);
writer.write(record); }
} });
} }
} }

View File

@ -17,44 +17,43 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.split; package org.datavec.api.split;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.Collection; import java.util.Collection;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static org.junit.Assert.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/** /**
* @author Ede Meijer * @author Ede Meijer
*/ */
public class TransformSplitTest extends BaseND4JTest { @DisplayName("Transform Split Test")
@Test class TransformSplitTest extends BaseND4JTest {
public void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
@Test
@DisplayName("Test Transform")
void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() { InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() {
@Override @Override
public URI apply(URI uri) throws URISyntaxException { public URI apply(URI uri) throws URISyntaxException {
return uri.normalize(); return uri.normalize();
} }
}); });
assertArrayEquals(new URI[] { new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv") }, SUT.locations());
assertArrayEquals(new URI[] {new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv")}, SUT.locations());
} }
@Test @Test
public void testSearchReplace() throws URISyntaxException { @DisplayName("Test Search Replace")
void testSearchReplace() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv")); Collection<URI> inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv"));
InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv"); InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv");
assertArrayEquals(new URI[] { new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv") }, SUT.locations());
assertArrayEquals(new URI[] {new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv")},
SUT.locations());
} }
} }
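A side note on the anonymous URITransform above: it implements a single abstract method, so on this project's Java 8 baseline the same test could use a lambda or method reference, assuming URITransform is declared as an interface (the anonymous implementation above is consistent with that). A sketch under that assumption, not how the commit writes it:

InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles),
        uri -> uri.normalize()); // equivalently: URI::normalize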

View File

@ -17,32 +17,25 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.ops; package org.datavec.api.transform.ops;
import com.tngtech.archunit.core.importer.ImportOption; import com.tngtech.archunit.core.importer.ImportOption;
import com.tngtech.archunit.junit.AnalyzeClasses; import com.tngtech.archunit.junit.AnalyzeClasses;
import com.tngtech.archunit.junit.ArchTest; import com.tngtech.archunit.junit.ArchTest;
import com.tngtech.archunit.junit.ArchUnitRunner;
import com.tngtech.archunit.lang.ArchRule; import com.tngtech.archunit.lang.ArchRule;
import com.tngtech.archunit.lang.extension.ArchUnitExtension;
import com.tngtech.archunit.lang.extension.ArchUnitExtensions;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.Serializable; import java.io.Serializable;
import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes; import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(ArchUnitRunner.class) @AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = { ImportOption.DoNotIncludeTests.class })
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = {ImportOption.DoNotIncludeTests.class}) @DisplayName("Aggregable Multi Op Arch Test")
public class AggregableMultiOpArchTest extends BaseND4JTest { class AggregableMultiOpArchTest extends BaseND4JTest {
@ArchTest @ArchTest
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes() public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes().that().resideInAPackage("org.datavec.api.transform.ops").and().doNotHaveSimpleName("AggregatorImpls").and().doNotHaveSimpleName("IAggregableReduceOp").and().doNotHaveSimpleName("StringAggregatorImpls").and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1").should().implement(Serializable.class).because("All aggregate ops must be serializable.");
.that().resideInAPackage("org.datavec.api.transform.ops") }
.and().doNotHaveSimpleName("AggregatorImpls")
.and().doNotHaveSimpleName("IAggregableReduceOp")
.and().doNotHaveSimpleName("StringAggregatorImpls")
.and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1")
.should().implement(Serializable.class)
.because("All aggregate ops must be serializable.");
}
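The ArchUnit test loses its @RunWith(ArchUnitRunner.class) line because ArchUnit ships its own JUnit 5 test engine that discovers @ArchTest members directly; with that engine on the test classpath (the com.tngtech.archunit:archunit-junit5 artifact; the version used is not shown in this diff), @AnalyzeClasses alone is enough, and the rule field no longer needs to be public. A condensed sketch of the resulting shape, not the commit's full rule chain:

@AnalyzeClasses(packages = "org.datavec.api.transform.ops",
        importOptions = { ImportOption.DoNotIncludeTests.class })
class AggregableMultiOpArchTest {

    @ArchTest
    static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes()
            .that().resideInAPackage("org.datavec.api.transform.ops")
            .should().implement(Serializable.class)
            .because("All aggregate ops must be serializable.");
}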

View File

@ -17,52 +17,46 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.ops; package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertTrue; @DisplayName("Aggregable Multi Op Test")
class AggregableMultiOpTest extends BaseND4JTest {
public class AggregableMultiOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@Test @Test
public void testMulti() throws Exception { @DisplayName("Test Multi")
void testMulti() throws Exception {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as)); AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
assertTrue(multi.getOperations().size() == 2); assertTrue(multi.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
multi.accept(intList.get(i)); multi.accept(intList.get(i));
} }
// mutability // mutability
assertTrue(as.get().toDouble() == 45D); assertTrue(as.get().toDouble() == 45D);
assertTrue(af.get().toInt() == 1); assertTrue(af.get().toInt() == 1);
List<Writable> res = multi.get(); List<Writable> res = multi.get();
assertTrue(res.get(1).toDouble() == 45D); assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1); assertTrue(res.get(0).toInt() == 1);
AggregatorImpls.AggregableFirst<Integer> rf = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> rf = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> rs = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> rs = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs)); AggregableMultiOp<Integer> reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs));
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
List<Writable> revRes = reverse.get(); List<Writable> revRes = reverse.get();
assertTrue(revRes.get(1).toDouble() == 45D); assertTrue(revRes.get(1).toDouble() == 45D);
assertTrue(revRes.get(0).toInt() == 9); assertTrue(revRes.get(0).toInt() == 9);
multi.combine(reverse); multi.combine(reverse);
List<Writable> combinedRes = multi.get(); List<Writable> combinedRes = multi.get();
assertTrue(combinedRes.get(1).toDouble() == 90D); assertTrue(combinedRes.get(1).toDouble() == 90D);

View File

@ -17,41 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.ops; package org.datavec.api.transform.ops;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import static org.junit.Assert.assertEquals; @DisplayName("Aggregator Impls Test")
import static org.junit.Assert.assertTrue; class AggregatorImplsTest extends BaseND4JTest {
public class AggregatorImplsTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test @Test
public void aggregableFirstTest() { @DisplayName("Aggregable First Test")
void aggregableFirstTest() {
AggregatorImpls.AggregableFirst<Integer> first = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> first = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
first.accept(intList.get(i)); first.accept(intList.get(i));
} }
assertEquals(1, first.get().toInt()); assertEquals(1, first.get().toInt());
AggregatorImpls.AggregableFirst<String> firstS = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<String> firstS = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < stringList.size(); i++) { for (int i = 0; i < stringList.size(); i++) {
firstS.accept(stringList.get(i)); firstS.accept(stringList.get(i));
} }
assertTrue(firstS.get().toString().equals("arakoa")); assertTrue(firstS.get().toString().equals("arakoa"));
AggregatorImpls.AggregableFirst<Integer> reverse = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> reverse = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -60,22 +58,19 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(1, first.get().toInt()); assertEquals(1, first.get().toInt());
} }
@Test @Test
public void aggregableLastTest() { @DisplayName("Aggregable Last Test")
void aggregableLastTest() {
AggregatorImpls.AggregableLast<Integer> last = new AggregatorImpls.AggregableLast<>(); AggregatorImpls.AggregableLast<Integer> last = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
last.accept(intList.get(i)); last.accept(intList.get(i));
} }
assertEquals(9, last.get().toInt()); assertEquals(9, last.get().toInt());
AggregatorImpls.AggregableLast<String> lastS = new AggregatorImpls.AggregableLast<>(); AggregatorImpls.AggregableLast<String> lastS = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < stringList.size(); i++) { for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i)); lastS.accept(stringList.get(i));
} }
assertTrue(lastS.get().toString().equals("acceptance")); assertTrue(lastS.get().toString().equals("acceptance"));
AggregatorImpls.AggregableLast<Integer> reverse = new AggregatorImpls.AggregableLast<>(); AggregatorImpls.AggregableLast<Integer> reverse = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -85,20 +80,18 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableCountTest() { @DisplayName("Aggregable Count Test")
void aggregableCountTest() {
AggregatorImpls.AggregableCount<Integer> cnt = new AggregatorImpls.AggregableCount<>(); AggregatorImpls.AggregableCount<Integer> cnt = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
cnt.accept(intList.get(i)); cnt.accept(intList.get(i));
} }
assertEquals(9, cnt.get().toInt()); assertEquals(9, cnt.get().toInt());
AggregatorImpls.AggregableCount<String> lastS = new AggregatorImpls.AggregableCount<>(); AggregatorImpls.AggregableCount<String> lastS = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < stringList.size(); i++) { for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i)); lastS.accept(stringList.get(i));
} }
assertEquals(4, lastS.get().toInt()); assertEquals(4, lastS.get().toInt());
AggregatorImpls.AggregableCount<Integer> reverse = new AggregatorImpls.AggregableCount<>(); AggregatorImpls.AggregableCount<Integer> reverse = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -108,14 +101,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableMaxTest() { @DisplayName("Aggregable Max Test")
void aggregableMaxTest() {
AggregatorImpls.AggregableMax<Integer> mx = new AggregatorImpls.AggregableMax<>(); AggregatorImpls.AggregableMax<Integer> mx = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i)); mx.accept(intList.get(i));
} }
assertEquals(9, mx.get().toInt()); assertEquals(9, mx.get().toInt());
AggregatorImpls.AggregableMax<Integer> reverse = new AggregatorImpls.AggregableMax<>(); AggregatorImpls.AggregableMax<Integer> reverse = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -124,16 +116,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, mx.get().toInt()); assertEquals(9, mx.get().toInt());
} }
@Test @Test
public void aggregableRangeTest() { @DisplayName("Aggregable Range Test")
void aggregableRangeTest() {
AggregatorImpls.AggregableRange<Integer> mx = new AggregatorImpls.AggregableRange<>(); AggregatorImpls.AggregableRange<Integer> mx = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i)); mx.accept(intList.get(i));
} }
assertEquals(8, mx.get().toInt()); assertEquals(8, mx.get().toInt());
AggregatorImpls.AggregableRange<Integer> reverse = new AggregatorImpls.AggregableRange<>(); AggregatorImpls.AggregableRange<Integer> reverse = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1) + 9); reverse.accept(intList.get(intList.size() - i - 1) + 9);
@ -143,14 +133,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableMinTest() { @DisplayName("Aggregable Min Test")
void aggregableMinTest() {
AggregatorImpls.AggregableMin<Integer> mn = new AggregatorImpls.AggregableMin<>(); AggregatorImpls.AggregableMin<Integer> mn = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i)); mn.accept(intList.get(i));
} }
assertEquals(1, mn.get().toInt()); assertEquals(1, mn.get().toInt());
AggregatorImpls.AggregableMin<Integer> reverse = new AggregatorImpls.AggregableMin<>(); AggregatorImpls.AggregableMin<Integer> reverse = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -160,14 +149,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableSumTest() { @DisplayName("Aggregable Sum Test")
void aggregableSumTest() {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i)); sm.accept(intList.get(i));
} }
assertEquals(45, sm.get().toInt()); assertEquals(45, sm.get().toInt());
AggregatorImpls.AggregableSum<Integer> reverse = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> reverse = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -176,17 +164,15 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(90, sm.get().toInt()); assertEquals(90, sm.get().toInt());
} }
@Test @Test
public void aggregableMeanTest() { @DisplayName("Aggregable Mean Test")
void aggregableMeanTest() {
AggregatorImpls.AggregableMean<Integer> mn = new AggregatorImpls.AggregableMean<>(); AggregatorImpls.AggregableMean<Integer> mn = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i)); mn.accept(intList.get(i));
} }
assertEquals(9l, (long) mn.getCount()); assertEquals(9l, (long) mn.getCount());
assertEquals(5D, mn.get().toDouble(), 0.001); assertEquals(5D, mn.get().toDouble(), 0.001);
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>(); AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -197,80 +183,73 @@ public class AggregatorImplsTest extends BaseND4JTest {
} }
@Test @Test
public void aggregableStdDevTest() { @DisplayName("Aggregable Std Dev Test")
void aggregableStdDevTest() {
AggregatorImpls.AggregableStdDev<Integer> sd = new AggregatorImpls.AggregableStdDev<>(); AggregatorImpls.AggregableStdDev<Integer> sd = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i)); sd.accept(intList.get(i));
} }
assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001);
AggregatorImpls.AggregableStdDev<Integer> reverse = new AggregatorImpls.AggregableStdDev<>(); AggregatorImpls.AggregableStdDev<Integer> reverse = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
sd.combine(reverse); sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8787) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 1.8787) < 0.0001,"" + sd.get().toDouble());
} }
@Test @Test
public void aggregableVariance() { @DisplayName("Aggregable Variance")
void aggregableVariance() {
AggregatorImpls.AggregableVariance<Integer> sd = new AggregatorImpls.AggregableVariance<>(); AggregatorImpls.AggregableVariance<Integer> sd = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i)); sd.accept(intList.get(i));
} }
assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001);
AggregatorImpls.AggregableVariance<Integer> reverse = new AggregatorImpls.AggregableVariance<>(); AggregatorImpls.AggregableVariance<Integer> reverse = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
sd.combine(reverse); sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 3.5294) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 3.5294) < 0.0001,"" + sd.get().toDouble());
} }
@Test @Test
public void aggregableUncorrectedStdDevTest() { @DisplayName("Aggregable Uncorrected Std Dev Test")
void aggregableUncorrectedStdDevTest() {
AggregatorImpls.AggregableUncorrectedStdDev<Integer> sd = new AggregatorImpls.AggregableUncorrectedStdDev<>(); AggregatorImpls.AggregableUncorrectedStdDev<Integer> sd = new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i)); sd.accept(intList.get(i));
} }
assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001);
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse = new AggregatorImpls.AggregableUncorrectedStdDev<>();
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse =
new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
sd.combine(reverse); sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8257) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 1.8257) < 0.0001,"" + sd.get().toDouble());
} }
@Test @Test
public void aggregablePopulationVariance() { @DisplayName("Aggregable Population Variance")
void aggregablePopulationVariance() {
AggregatorImpls.AggregablePopulationVariance<Integer> sd = new AggregatorImpls.AggregablePopulationVariance<>(); AggregatorImpls.AggregablePopulationVariance<Integer> sd = new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i)); sd.accept(intList.get(i));
} }
assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001);
AggregatorImpls.AggregablePopulationVariance<Integer> reverse = new AggregatorImpls.AggregablePopulationVariance<>();
AggregatorImpls.AggregablePopulationVariance<Integer> reverse =
new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
sd.combine(reverse); sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001); assertTrue(Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001,"" + sd.get().toDouble());
} }
@Test @Test
public void aggregableCountUniqueTest() { @DisplayName("Aggregable Count Unique Test")
void aggregableCountUniqueTest() {
// at this low range, it's linear counting // at this low range, it's linear counting
AggregatorImpls.AggregableCountUnique<Integer> cu = new AggregatorImpls.AggregableCountUnique<>(); AggregatorImpls.AggregableCountUnique<Integer> cu = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
cu.accept(intList.get(i)); cu.accept(intList.get(i));
@ -278,7 +257,6 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, cu.get().toInt()); assertEquals(9, cu.get().toInt());
cu.accept(1); cu.accept(1);
assertEquals(9, cu.get().toInt()); assertEquals(9, cu.get().toInt());
AggregatorImpls.AggregableCountUnique<Integer> reverse = new AggregatorImpls.AggregableCountUnique<>(); AggregatorImpls.AggregableCountUnique<Integer> reverse = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -290,16 +268,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
@Rule @Rule
public final ExpectedException exception = ExpectedException.none(); public final ExpectedException exception = ExpectedException.none();
@Test @Test
public void incompatibleAggregatorTest() { @DisplayName("Incompatible Aggregator Test")
void incompatibleAggregatorTest() {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i)); sm.accept(intList.get(i));
} }
assertEquals(45, sm.get().toInt()); assertEquals(45, sm.get().toInt());
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>(); AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
@ -308,5 +284,4 @@ public class AggregatorImplsTest extends BaseND4JTest {
sm.combine(reverse); sm.combine(reverse);
assertEquals(45, sm.get().toInt()); assertEquals(45, sm.get().toInt());
} }
} }
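Two Jupiter details surface in this file. First, the assertion message moves from the first argument (JUnit 4) to the last (JUnit 5), and it may also be given as a Supplier so the string is only built on failure. A sketch against the stddev check above:

// JUnit 4: assertTrue(message, condition)
// assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8787) < 0.0001);

// JUnit 5: assertTrue(condition, message), or defer the message with a supplier:
assertTrue(Math.abs(sd.get().toDouble() - 1.8787) < 0.0001,
        () -> "unexpected stddev: " + sd.get().toDouble());

Second, the @Rule ExpectedException field kept at the end of the class is a JUnit 4 construct that the Jupiter engine does not apply (it would require junit-jupiter-migrationsupport and @EnableRuleMigrationSupport); under JUnit 5 the idiomatic replacement is assertThrows, as in the record writer tests earlier in this commit.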

View File

@ -17,77 +17,65 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.ops; package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertTrue; @DisplayName("Dispatch Op Test")
class DispatchOpTest extends BaseND4JTest {
public class DispatchOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test @Test
public void testDispatchSimple() { @DisplayName("Test Dispatch Simple")
void testDispatchSimple() {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multiaf = AggregableMultiOp<Integer> multiaf = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af)); AggregableMultiOp<Integer> multias = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
AggregableMultiOp<Integer> multias = DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
DispatchOp<Integer, Writable> parallel =
new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
assertTrue(multiaf.getOperations().size() == 1); assertTrue(multiaf.getOperations().size() == 1);
assertTrue(multias.getOperations().size() == 1); assertTrue(multias.getOperations().size() == 1);
assertTrue(parallel.getOperations().size() == 2); assertTrue(parallel.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
parallel.accept(Arrays.asList(intList.get(i), intList.get(i))); parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
} }
List<Writable> res = parallel.get(); List<Writable> res = parallel.get();
assertTrue(res.get(1).toDouble() == 45D); assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1); assertTrue(res.get(0).toInt() == 1);
} }
@Test @Test
public void testDispatchFlatMap() { @DisplayName("Test Dispatch Flat Map")
void testDispatchFlatMap() {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as)); AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
AggregatorImpls.AggregableLast<Integer> al = new AggregatorImpls.AggregableLast<>(); AggregatorImpls.AggregableLast<Integer> al = new AggregatorImpls.AggregableLast<>();
AggregatorImpls.AggregableMax<Integer> amax = new AggregatorImpls.AggregableMax<>(); AggregatorImpls.AggregableMax<Integer> amax = new AggregatorImpls.AggregableMax<>();
AggregableMultiOp<Integer> otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax)); AggregableMultiOp<Integer> otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(
Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
assertTrue(multi.getOperations().size() == 2); assertTrue(multi.getOperations().size() == 2);
assertTrue(otherMulti.getOperations().size() == 2); assertTrue(otherMulti.getOperations().size() == 2);
assertTrue(parallel.getOperations().size() == 2); assertTrue(parallel.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
parallel.accept(Arrays.asList(intList.get(i), intList.get(i))); parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
} }
List<Writable> res = parallel.get(); List<Writable> res = parallel.get();
assertTrue(res.get(1).toDouble() == 45D); assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1); assertTrue(res.get(0).toInt() == 1);
assertTrue(res.get(3).toDouble() == 9); assertTrue(res.get(3).toDouble() == 9);
assertTrue(res.get(2).toInt() == 9); assertTrue(res.get(2).toInt() == 9);
} }
} }
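The test above composes per-column reductions: each AggregableMultiOp runs one or more ops over a single column, and DispatchOp routes element i of every accepted row to column op i, so get() returns the flattened per-column results. A minimal sketch of that flow (class name hypothetical):

    import org.datavec.api.transform.ops.AggregableMultiOp;
    import org.datavec.api.transform.ops.AggregatorImpls;
    import org.datavec.api.transform.ops.DispatchOp;
    import org.datavec.api.transform.ops.IAggregableReduceOp;
    import org.datavec.api.writable.Writable;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;

    class DispatchSketch {
        public static void main(String[] args) {
            // Column 0 accumulates a sum, column 1 a max
            AggregableMultiOp<Integer> sum = new AggregableMultiOp<>(
                    Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(
                            new AggregatorImpls.AggregableSum<Integer>()));
            AggregableMultiOp<Integer> max = new AggregableMultiOp<>(
                    Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(
                            new AggregatorImpls.AggregableMax<Integer>()));
            DispatchOp<Integer, Writable> dispatch = new DispatchOp<>(
                    Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(sum, max));
            for (int i = 1; i <= 9; i++) {
                dispatch.accept(Arrays.asList(i, i));   // one value per column
            }
            List<Writable> out = dispatch.get();
            System.out.println(out.get(0).toDouble());  // 45.0 (sum of column 0)
            System.out.println(out.get(1).toInt());     // 9 (max of column 1)
        }
    }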

View File

@ -17,29 +17,29 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.transform.transform.parse; package org.datavec.api.transform.transform.parse;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Parse Double Transform Test")
class ParseDoubleTransformTest extends BaseND4JTest {
public class ParseDoubleTransformTest extends BaseND4JTest {
@Test @Test
public void testDoubleTransform() { @DisplayName("Test Double Transform")
void testDoubleTransform() {
List<Writable> record = new ArrayList<>(); List<Writable> record = new ArrayList<>();
record.add(new Text("0.0")); record.add(new Text("0.0"));
List<Writable> transformed = Arrays.<Writable>asList(new DoubleWritable(0.0)); List<Writable> transformed = Arrays.<Writable>asList(new DoubleWritable(0.0));
assertEquals(transformed, new ParseDoubleTransform().map(record)); assertEquals(transformed, new ParseDoubleTransform().map(record));
} }
} }
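This file shows the mechanical core of the whole commit: org.junit.Test becomes org.junit.jupiter.api.Test, org.junit.Assert becomes org.junit.jupiter.api.Assertions, and the public modifiers are dropped because Jupiter allows package-private test classes and methods. A minimal post-migration shape (class name hypothetical):

    import org.junit.jupiter.api.DisplayName;
    import org.junit.jupiter.api.Test;
    import static org.junit.jupiter.api.Assertions.assertEquals;

    @DisplayName("Migrated Test Sketch")
    class MigratedTestSketch {
        @Test
        @DisplayName("Addition works")
        void addition() {
            // No public modifier required on Jupiter test classes or methods
            assertEquals(4, 2 + 2);
        }
    }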

View File

@ -17,30 +17,31 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.util; package org.datavec.api.util;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.File; import java.io.File;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.Assert.assertTrue;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.AnyOf.anyOf; import static org.hamcrest.core.AnyOf.anyOf;
import static org.hamcrest.core.IsEqual.equalTo; import static org.hamcrest.core.IsEqual.equalTo;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class ClassPathResourceTest extends BaseND4JTest { @DisplayName("Class Path Resource Test")
class ClassPathResourceTest extends BaseND4JTest {
private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows // File sizes are reported slightly different on Linux vs. Windows
private boolean isWindows = false;
@Before @BeforeEach
public void setUp() throws Exception { void setUp() throws Exception {
String osname = System.getProperty("os.name"); String osname = System.getProperty("os.name");
if (osname != null && osname.toLowerCase().contains("win")) { if (osname != null && osname.toLowerCase().contains("win")) {
isWindows = true; isWindows = true;
@ -48,9 +49,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
} }
@Test @Test
public void testGetFile1() throws Exception { @DisplayName("Test Get File 1")
void testGetFile1() throws Exception {
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile(); File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
assertTrue(intFile.exists()); assertTrue(intFile.exists());
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L))); assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
@ -60,9 +61,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
} }
@Test @Test
public void testGetFileSlash1() throws Exception { @DisplayName("Test Get File Slash 1")
void testGetFileSlash1() throws Exception {
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile(); File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
assertTrue(intFile.exists()); assertTrue(intFile.exists());
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L))); assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
@ -72,11 +73,10 @@ public class ClassPathResourceTest extends BaseND4JTest {
} }
@Test @Test
public void testGetFileWithSpace1() throws Exception { @DisplayName("Test Get File With Space 1")
void testGetFileWithSpace1() throws Exception {
File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile(); File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile();
assertTrue(intFile.exists()); assertTrue(intFile.exists());
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L))); assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else { } else {
@ -85,16 +85,15 @@ public class ClassPathResourceTest extends BaseND4JTest {
} }
@Test @Test
public void testInputStream() throws Exception { @DisplayName("Test Input Stream")
void testInputStream() throws Exception {
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt"); ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
File intFile = resource.getFile(); File intFile = resource.getFile();
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L))); assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else { } else {
assertEquals(60, intFile.length()); assertEquals(60, intFile.length());
} }
InputStream stream = resource.getInputStream(); InputStream stream = resource.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = ""; String line = "";
@ -102,21 +101,19 @@ public class ClassPathResourceTest extends BaseND4JTest {
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
cnt++; cnt++;
} }
assertEquals(5, cnt); assertEquals(5, cnt);
} }
@Test @Test
public void testInputStreamSlash() throws Exception { @DisplayName("Test Input Stream Slash")
void testInputStreamSlash() throws Exception {
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt"); ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
File intFile = resource.getFile(); File intFile = resource.getFile();
if (isWindows) { if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L))); assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else { } else {
assertEquals(60, intFile.length()); assertEquals(60, intFile.length());
} }
InputStream stream = resource.getInputStream(); InputStream stream = resource.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = ""; String line = "";
@ -124,7 +121,6 @@ public class ClassPathResourceTest extends BaseND4JTest {
while ((line = reader.readLine()) != null) { while ((line = reader.readLine()) != null) {
cnt++; cnt++;
} }
assertEquals(5, cnt); assertEquals(5, cnt);
} }
} }
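The setUp change above follows the standard JUnit 4 to 5 lifecycle mapping: @Before becomes @BeforeEach, and correspondingly @BeforeClass/@After/@AfterClass become @BeforeAll/@AfterEach/@AfterAll. A minimal sketch of the full mapping (class name hypothetical):

    import org.junit.jupiter.api.AfterAll;
    import org.junit.jupiter.api.AfterEach;
    import org.junit.jupiter.api.BeforeAll;
    import org.junit.jupiter.api.BeforeEach;
    import org.junit.jupiter.api.Test;

    class LifecycleSketch {
        @BeforeAll          // JUnit 4: @BeforeClass
        static void beforeAll() { }

        @BeforeEach         // JUnit 4: @Before
        void setUp() { }

        @Test
        void example() { }

        @AfterEach          // JUnit 4: @After
        void tearDown() { }

        @AfterAll           // JUnit 4: @AfterClass
        static void afterAll() { }
    }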

View File

@ -17,44 +17,41 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.util; package org.datavec.api.util;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertArrayEquals; @DisplayName("Time Series Utils Test")
class TimeSeriesUtilsTest extends BaseND4JTest {
public class TimeSeriesUtilsTest extends BaseND4JTest {
@Test @Test
public void testTimeSeriesCreation() { @DisplayName("Test Time Series Creation")
void testTimeSeriesCreation() {
List<List<List<Writable>>> test = new ArrayList<>(); List<List<List<Writable>>> test = new ArrayList<>();
List<List<Writable>> timeStep = new ArrayList<>(); List<List<Writable>> timeStep = new ArrayList<>();
for(int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
timeStep.add(getRecord(5)); timeStep.add(getRecord(5));
} }
test.add(timeStep); test.add(timeStep);
INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst(); INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst();
assertArrayEquals(new long[]{1,5,5},arr.shape()); assertArrayEquals(new long[] { 1, 5, 5 }, arr.shape());
} }
private List<Writable> getRecord(int length) { private List<Writable> getRecord(int length) {
List<Writable> ret = new ArrayList<>(); List<Writable> ret = new ArrayList<>();
for(int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
ret.add(new DoubleWritable(1.0)); ret.add(new DoubleWritable(1.0));
} }
return ret; return ret;
} }
} }
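The { 1, 5, 5 } assertion above suggests convertWritablesSequence lays sequences out as [numSequences, ..., ...]; with five steps of five columns each, the trailing dimensions are indistinguishable here. A sketch that would disambiguate the leading dimension, assuming two sequences yield shape [2, 5, 5] (class name hypothetical):

    import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
    import org.datavec.api.writable.DoubleWritable;
    import org.datavec.api.writable.Writable;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    class TimeSeriesShapeSketch {
        public static void main(String[] args) {
            List<List<List<Writable>>> sequences = new ArrayList<>();
            for (int s = 0; s < 2; s++) {          // two sequences this time
                List<List<Writable>> steps = new ArrayList<>();
                for (int t = 0; t < 5; t++) {      // five time steps each
                    List<Writable> record = new ArrayList<>();
                    for (int c = 0; c < 5; c++) {  // five columns per step
                        record.add(new DoubleWritable(1.0));
                    }
                    steps.add(record);
                }
                sequences.add(steps);
            }
            INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(sequences).getFirst();
            System.out.println(Arrays.toString(arr.shape()));  // expected: [2, 5, 5]
        }
    }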

View File

@ -17,52 +17,50 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.writable; package org.datavec.api.writable;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.shade.guava.collect.Lists; import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.util.ndarray.RecordConverter;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.TimeZone; import java.util.TimeZone;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Record Converter Test")
class RecordConverterTest extends BaseND4JTest {
public class RecordConverterTest extends BaseND4JTest {
@Test @Test
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() { @DisplayName("To Records _ Pass In Classification Data Set _ Expect ND Array And Int Writables")
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT); INDArray feature1 = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT); INDArray feature2 = Nd4j.create(new double[] { 11, .7, -1.3, 4 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT); INDArray label1 = Nd4j.create(new double[] { 0, 0, 1, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), INDArray label2 = Nd4j.create(new double[] { 0, 1, 0, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
Nd4j.vstack(Lists.newArrayList(label1, label2))); DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), Nd4j.vstack(Lists.newArrayList(label1, label2)));
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet); List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
assertEquals(2, writableList.size()); assertEquals(2, writableList.size());
testClassificationWritables(feature1, 2, writableList.get(0)); testClassificationWritables(feature1, 2, writableList.get(0));
testClassificationWritables(feature2, 1, writableList.get(1)); testClassificationWritables(feature2, 1, writableList.get(1));
} }
@Test @Test
public void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() { @DisplayName("To Records _ Pass In Regression Data Set _ Expect ND Array And Double Writables")
INDArray feature = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
INDArray label = Nd4j.create(new double[]{.5, 2, 3, .5}, new long[]{1, 4}, DataType.FLOAT); INDArray feature = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label = Nd4j.create(new double[] { .5, 2, 3, .5 }, new long[] { 1, 4 }, DataType.FLOAT);
DataSet dataSet = new DataSet(feature, label); DataSet dataSet = new DataSet(feature, label);
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet); List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
List<Writable> results = writableList.get(0); List<Writable> results = writableList.get(0);
NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0); NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0);
assertEquals(1, writableList.size()); assertEquals(1, writableList.size());
assertEquals(5, results.size()); assertEquals(5, results.size());
assertEquals(feature, ndArrayWritable.get()); assertEquals(feature, ndArrayWritable.get());
@ -72,62 +70,39 @@ public class RecordConverterTest extends BaseND4JTest {
} }
} }
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, List<Writable> writables) {
List<Writable> writables) {
NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0); NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0);
IntWritable intWritable = (IntWritable) writables.get(1); IntWritable intWritable = (IntWritable) writables.get(1);
assertEquals(2, writables.size()); assertEquals(2, writables.size());
assertEquals(expectedFeatureVector, ndArrayWritable.get()); assertEquals(expectedFeatureVector, ndArrayWritable.get());
assertEquals(expectLabelIndex, intWritable.get()); assertEquals(expectLabelIndex, intWritable.get());
} }
@Test @Test
public void testNDArrayWritableConcat() { @DisplayName("Test ND Array Writable Concat")
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), void testNDArrayWritableConcat() {
new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1, 3}, DataType.FLOAT)), new DoubleWritable(5), List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 }, new long[] { 1, 3 }, DataType.FLOAT)), new IntWritable(9), new IntWritable(1));
new NDArrayWritable(Nd4j.create(new double[]{6, 7, 8}, new long[]{1, 3}, DataType.FLOAT)), new IntWritable(9), INDArray exp = Nd4j.create(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1 }, new long[] { 1, 10 }, DataType.FLOAT);
new IntWritable(1));
INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT);
INDArray act = RecordConverter.toArray(DataType.FLOAT, l); INDArray act = RecordConverter.toArray(DataType.FLOAT, l);
assertEquals(exp, act); assertEquals(exp, act);
} }
@Test @Test
public void testNDArrayWritableConcatToMatrix(){ @DisplayName("Test ND Array Writable Concat To Matrix")
void testNDArrayWritableConcatToMatrix() {
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5)); List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5));
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[]{7, 8, 9}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(10)); List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(10));
INDArray exp = Nd4j.create(new double[][] { { 1, 2, 3, 4, 5 }, { 6, 7, 8, 9, 10 } }).castTo(DataType.FLOAT);
INDArray exp = Nd4j.create(new double[][]{ INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1, l2));
{1,2,3,4,5},
{6,7,8,9,10}}).castTo(DataType.FLOAT);
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2));
assertEquals(exp, act); assertEquals(exp, act);
} }
@Test @Test
public void testToRecordWithListOfObject(){ @DisplayName("Test To Record With List Of Object")
final List<Object> list = Arrays.asList((Object)3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L); void testToRecordWithListOfObject() {
final Schema schema = new Schema.Builder() final List<Object> list = Arrays.asList((Object) 3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
.addColumnInteger("a") final Schema schema = new Schema.Builder().addColumnInteger("a").addColumnFloat("b").addColumnString("c").addColumnCategorical("d", "Bar", "Baz").addColumnDouble("e").addColumnFloat("f").addColumnLong("g").addColumnInteger("h").addColumnTime("i", TimeZone.getDefault()).build();
.addColumnFloat("b")
.addColumnString("c")
.addColumnCategorical("d", "Bar", "Baz")
.addColumnDouble("e")
.addColumnFloat("f")
.addColumnLong("g")
.addColumnInteger("h")
.addColumnTime("i", TimeZone.getDefault())
.build();
final List<Writable> record = RecordConverter.toRecord(schema, list); final List<Writable> record = RecordConverter.toRecord(schema, list);
assertEquals(record.get(0).toInt(), 3); assertEquals(record.get(0).toInt(), 3);
assertEquals(record.get(1).toFloat(), 7f, 1e-6); assertEquals(record.get(1).toFloat(), 7f, 1e-6);
assertEquals(record.get(2).toString(), "Foo"); assertEquals(record.get(2).toString(), "Foo");
@ -137,7 +112,5 @@ public class RecordConverterTest extends BaseND4JTest {
assertEquals(record.get(6).toLong(), 3L); assertEquals(record.get(6).toLong(), 3L);
assertEquals(record.get(7).toInt(), 7); assertEquals(record.get(7).toInt(), 7);
assertEquals(record.get(8).toLong(), 0); assertEquals(record.get(8).toLong(), 0);
} }
} }
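One readability note on the assertions above: JUnit's contract is assertEquals(expected, actual), while several calls here pass the actual value first. The tests still pass, but failure messages read backwards. The conventional order:

    import static org.junit.jupiter.api.Assertions.assertEquals;

    class AssertOrderSketch {
        void check(int actualValue) {
            // Expected first, so a failure reports
            // "expected: <3> but was: <actualValue>" rather than the reverse.
            assertEquals(3, actualValue);
        }
    }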

View File

@ -17,38 +17,38 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.writable; package org.datavec.api.writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.nio.Buffer; import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import org.junit.jupiter.api.DisplayName;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class WritableTest extends BaseND4JTest { @DisplayName("Writable Test")
class WritableTest extends BaseND4JTest {
@Test @Test
public void testWritableEqualityReflexive() { @DisplayName("Test Writable Equality Reflexive")
void testWritableEqualityReflexive() {
assertEquals(new IntWritable(1), new IntWritable(1)); assertEquals(new IntWritable(1), new IntWritable(1));
assertEquals(new LongWritable(1), new LongWritable(1)); assertEquals(new LongWritable(1), new LongWritable(1));
assertEquals(new DoubleWritable(1), new DoubleWritable(1)); assertEquals(new DoubleWritable(1), new DoubleWritable(1));
assertEquals(new FloatWritable(1), new FloatWritable(1)); assertEquals(new FloatWritable(1), new FloatWritable(1));
assertEquals(new Text("Hello"), new Text("Hello")); assertEquals(new Text("Hello"), new Text("Hello"));
assertEquals(new BytesWritable("Hello".getBytes()),new BytesWritable("Hello".getBytes())); assertEquals(new BytesWritable("Hello".getBytes()), new BytesWritable("Hello".getBytes()));
INDArray ndArray = Nd4j.rand(new int[]{1, 100}); INDArray ndArray = Nd4j.rand(new int[] { 1, 100 });
assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray)); assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray));
assertEquals(new NullWritable(), new NullWritable()); assertEquals(new NullWritable(), new NullWritable());
assertEquals(new BooleanWritable(true), new BooleanWritable(true)); assertEquals(new BooleanWritable(true), new BooleanWritable(true));
@ -56,9 +56,9 @@ public class WritableTest extends BaseND4JTest {
assertEquals(new ByteWritable(b), new ByteWritable(b)); assertEquals(new ByteWritable(b), new ByteWritable(b));
} }
@Test @Test
public void testBytesWritableIndexing() { @DisplayName("Test Bytes Writable Indexing")
void testBytesWritableIndexing() {
byte[] doubleWrite = new byte[16]; byte[] doubleWrite = new byte[16];
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite); ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
Buffer buffer = (Buffer) wrapped; Buffer buffer = (Buffer) wrapped;
@ -66,53 +66,51 @@ public class WritableTest extends BaseND4JTest {
wrapped.putDouble(2.0); wrapped.putDouble(2.0);
buffer.rewind(); buffer.rewind();
BytesWritable byteWritable = new BytesWritable(doubleWrite); BytesWritable byteWritable = new BytesWritable(doubleWrite);
assertEquals(2,byteWritable.getDouble(1),1e-1); assertEquals(2, byteWritable.getDouble(1), 1e-1);
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2}); DataBuffer dataBuffer = Nd4j.createBuffer(new double[] { 1, 2 });
double[] d1 = dataBuffer.asDouble(); double[] d1 = dataBuffer.asDouble();
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE,8).asDouble(); double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE, 8).asDouble();
assertArrayEquals(d1, d2, 0.0); assertArrayEquals(d1, d2, 0.0);
} }
@Test @Test
public void testByteWritable() { @DisplayName("Test Byte Writable")
void testByteWritable() {
byte b = 0xfffffffe; byte b = 0xfffffffe;
assertEquals(new IntWritable(-2), new ByteWritable(b)); assertEquals(new IntWritable(-2), new ByteWritable(b));
assertEquals(new LongWritable(-2), new ByteWritable(b)); assertEquals(new LongWritable(-2), new ByteWritable(b));
assertEquals(new ByteWritable(b), new IntWritable(-2)); assertEquals(new ByteWritable(b), new IntWritable(-2));
assertEquals(new ByteWritable(b), new LongWritable(-2)); assertEquals(new ByteWritable(b), new LongWritable(-2));
// those would cast to the same Int // those would cast to the same Int
byte minus126 = 0xffffff82; byte minus126 = 0xffffff82;
assertNotEquals(new ByteWritable(minus126), new IntWritable(130)); assertNotEquals(new ByteWritable(minus126), new IntWritable(130));
} }
@Test @Test
public void testIntLongWritable() { @DisplayName("Test Int Long Writable")
void testIntLongWritable() {
assertEquals(new IntWritable(1), new LongWritable(1L)); assertEquals(new IntWritable(1), new LongWritable(1L));
assertEquals(new LongWritable(2L), new IntWritable(2)); assertEquals(new LongWritable(2L), new IntWritable(2));
long l = 1L << 34; long l = 1L << 34;
// those would cast to the same Int // those would cast to the same Int
assertNotEquals(new LongWritable(l), new IntWritable(4)); assertNotEquals(new LongWritable(l), new IntWritable(4));
} }
@Test @Test
public void testDoubleFloatWritable() { @DisplayName("Test Double Float Writable")
void testDoubleFloatWritable() {
assertEquals(new DoubleWritable(1d), new FloatWritable(1f)); assertEquals(new DoubleWritable(1d), new FloatWritable(1f));
assertEquals(new FloatWritable(2f), new DoubleWritable(2d)); assertEquals(new FloatWritable(2f), new DoubleWritable(2d));
// we defer to Java equality for Floats // we defer to Java equality for Floats
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f)); assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f));
// same idea as above // same idea as above
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float)1.1d)); assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float) 1.1d));
assertNotEquals(new DoubleWritable((double) Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
assertNotEquals(new DoubleWritable((double)Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
} }
@Test @Test
public void testFuzzies() { @DisplayName("Test Fuzzies")
void testFuzzies() {
assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d)); assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d));
assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d)); assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d));
byte b = 0xfffffffe; byte b = 0xfffffffe;
@ -122,62 +120,57 @@ public class WritableTest extends BaseND4JTest {
assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d)); assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d));
} }
@Test @Test
public void testNDArrayRecordBatch(){ @DisplayName("Test ND Array Record Batch")
void testNDArrayRecordBatch() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
// Outer list over writables/columns, inner list over examples
List<List<INDArray>> orig = new ArrayList<>(); //Outer list over writables/columns, inner list over examples List<List<INDArray>> orig = new ArrayList<>();
for( int i=0; i<3; i++ ){ for (int i = 0; i < 3; i++) {
orig.add(new ArrayList<INDArray>()); orig.add(new ArrayList<INDArray>());
} }
for (int i = 0; i < 5; i++) {
for( int i=0; i<5; i++ ){ orig.get(0).add(Nd4j.rand(1, 10));
orig.get(0).add(Nd4j.rand(1,10)); orig.get(1).add(Nd4j.rand(new int[] { 1, 5, 6 }));
orig.get(1).add(Nd4j.rand(new int[]{1,5,6})); orig.get(2).add(Nd4j.rand(new int[] { 1, 3, 4, 5 }));
orig.get(2).add(Nd4j.rand(new int[]{1,3,4,5}));
} }
// Outer list over examples, inner list over writables
List<List<INDArray>> origByExample = new ArrayList<>(); //Outer list over examples, inner list over writables List<List<INDArray>> origByExample = new ArrayList<>();
for( int i=0; i<5; i++ ){ for (int i = 0; i < 5; i++) {
origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i))); origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i)));
} }
List<INDArray> batched = new ArrayList<>(); List<INDArray> batched = new ArrayList<>();
for(List<INDArray> l : orig){ for (List<INDArray> l : orig) {
batched.add(Nd4j.concat(0, l.toArray(new INDArray[5]))); batched.add(Nd4j.concat(0, l.toArray(new INDArray[5])));
} }
NDArrayRecordBatch batch = new NDArrayRecordBatch(batched); NDArrayRecordBatch batch = new NDArrayRecordBatch(batched);
assertEquals(5, batch.size()); assertEquals(5, batch.size());
for( int i=0; i<5; i++ ){ for (int i = 0; i < 5; i++) {
List<Writable> act = batch.get(i); List<Writable> act = batch.get(i);
List<INDArray> unboxed = new ArrayList<>(); List<INDArray> unboxed = new ArrayList<>();
for(Writable w : act){ for (Writable w : act) {
unboxed.add(((NDArrayWritable)w).get()); unboxed.add(((NDArrayWritable) w).get());
} }
List<INDArray> exp = origByExample.get(i); List<INDArray> exp = origByExample.get(i);
assertEquals(exp.size(), unboxed.size()); assertEquals(exp.size(), unboxed.size());
for( int j=0; j<exp.size(); j++ ){ for (int j = 0; j < exp.size(); j++) {
assertEquals(exp.get(j), unboxed.get(j)); assertEquals(exp.get(j), unboxed.get(j));
} }
} }
Iterator<List<Writable>> iter = batch.iterator(); Iterator<List<Writable>> iter = batch.iterator();
int count = 0; int count = 0;
while(iter.hasNext()){ while (iter.hasNext()) {
List<Writable> next = iter.next(); List<Writable> next = iter.next();
List<INDArray> unboxed = new ArrayList<>(); List<INDArray> unboxed = new ArrayList<>();
for(Writable w : next){ for (Writable w : next) {
unboxed.add(((NDArrayWritable)w).get()); unboxed.add(((NDArrayWritable) w).get());
} }
List<INDArray> exp = origByExample.get(count++); List<INDArray> exp = origByExample.get(count++);
assertEquals(exp.size(), unboxed.size()); assertEquals(exp.size(), unboxed.size());
for( int j=0; j<exp.size(); j++ ){ for (int j = 0; j < exp.size(); j++) {
assertEquals(exp.get(j), unboxed.get(j)); assertEquals(exp.get(j), unboxed.get(j));
} }
} }
assertEquals(5, count); assertEquals(5, count);
} }
} }
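The equality tests above pin down DataVec's cross-type Writable semantics: integer-width Writables compare equal when the value survives the narrowing cast, and floating-point Writables need fuzzyEquals with a tolerance. A minimal sketch (class name hypothetical):

    import org.datavec.api.writable.DoubleWritable;
    import org.datavec.api.writable.FloatWritable;
    import org.datavec.api.writable.IntWritable;
    import org.datavec.api.writable.LongWritable;

    class WritableEqualitySketch {
        public static void main(String[] args) {
            // Equal across widths when the value is representable in both...
            System.out.println(new IntWritable(2).equals(new LongWritable(2L)));        // true
            // ...but not when narrowing would change the value.
            System.out.println(new LongWritable(1L << 34).equals(new IntWritable(4)));  // false
            // Floating point: use a tolerance instead of equals.
            System.out.println(new DoubleWritable(1.1).fuzzyEquals(new FloatWritable(1.1f), 1e-6));  // true
        }
    }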

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.arrow; package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -43,460 +42,398 @@ import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File; import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
import static java.nio.channels.Channels.newChannel; import static java.nio.channels.Channels.newChannel;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
public class ArrowConverterTest extends BaseND4JTest { @DisplayName("Arrow Converter Test")
class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@Rule @TempDir
public TemporaryFolder testDir = new TemporaryFolder(); public Path testDir;
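This is the one structural (rather than import-level) change in this file's migration: JUnit 4's @Rule TemporaryFolder becomes Jupiter's @TempDir, injected into a Path or File field or test parameter and cleaned up automatically. A minimal sketch (names hypothetical):

    import org.junit.jupiter.api.Test;
    import org.junit.jupiter.api.io.TempDir;
    import java.io.File;
    import java.nio.file.Path;

    class TempDirSketch {
        @TempDir
        Path sharedDir;  // injected before each test, deleted afterwards

        @Test
        void writesIntoTempDir(@TempDir Path perMethodDir) throws Exception {
            // Field and parameter injection are both supported in Jupiter
            File out = new File(sharedDir.toFile(), "scratch.bin");
            // ... write to out; no manual cleanup needed
        }
    }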
@Test @Test
public void testToArrayFromINDArray() { @DisplayName("Test To Array From IND Array")
void testToArrayFromINDArray() {
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
schemaBuilder.addColumnNDArray("outputArray",new long[]{1,4}); schemaBuilder.addColumnNDArray("outputArray", new long[] { 1, 4 });
Schema schema = schemaBuilder.build(); Schema schema = schemaBuilder.build();
int numRows = 4; int numRows = 4;
List<List<Writable>> ret = new ArrayList<>(numRows); List<List<Writable>> ret = new ArrayList<>(numRows);
for(int i = 0; i < numRows; i++) { for (int i = 0; i < numRows; i++) {
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1,4,4).reshape(1, 4)))); ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4))));
} }
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret);
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors,schema); ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors, schema);
INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch); INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch);
assertArrayEquals(new long[]{4,4},array.shape()); assertArrayEquals(new long[] { 4, 4 }, array.shape());
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1, 4, 4), 4).reshape(4, 4);
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1,4,4),4).reshape(4,4); assertEquals(assertion, array);
assertEquals(assertion,array);
} }
@Test @Test
public void testArrowColumnINDArray() { @DisplayName("Test Arrow Column IND Array")
void testArrowColumnINDArray() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>(); List<String> single = new ArrayList<>();
int numCols = 2; int numCols = 2;
INDArray arr = Nd4j.linspace(1,4,4); INDArray arr = Nd4j.linspace(1, 4, 4);
for(int i = 0; i < numCols; i++) { for (int i = 0; i < numCols; i++) {
schema.addColumnNDArray(String.valueOf(i),new long[]{1,4}); schema.addColumnNDArray(String.valueOf(i), new long[] { 1, 4 });
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
Schema buildSchema = schema.build(); Schema buildSchema = schema.build();
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>(); List<Writable> firstRow = new ArrayList<>();
for(int i = 0 ; i < numCols; i++) { for (int i = 0; i < numCols; i++) {
firstRow.add(new NDArrayWritable(arr)); firstRow.add(new NDArrayWritable(arr));
} }
list.add(firstRow); list.add(firstRow);
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list);
assertEquals(numCols,fieldVectors.size()); assertEquals(numCols, fieldVectors.size());
assertEquals(1,fieldVectors.get(0).getValueCount()); assertEquals(1, fieldVectors.get(0).getValueCount());
assertFalse(fieldVectors.get(0).isNull(0)); assertFalse(fieldVectors.get(0).isNull(0));
ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema); ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema);
assertEquals(1,arrowWritableRecordBatch.size()); assertEquals(1, arrowWritableRecordBatch.size());
Writable writable = arrowWritableRecordBatch.get(0).get(0); Writable writable = arrowWritableRecordBatch.get(0).get(0);
assertTrue(writable instanceof NDArrayWritable); assertTrue(writable instanceof NDArrayWritable);
NDArrayWritable ndArrayWritable = (NDArrayWritable) writable; NDArrayWritable ndArrayWritable = (NDArrayWritable) writable;
assertEquals(arr,ndArrayWritable.get()); assertEquals(arr, ndArrayWritable.get());
Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray); Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray);
NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1; NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1;
System.out.println(ndArrayWritablewritable1.get()); System.out.println(ndArrayWritablewritable1.get());
} }
@Test @Test
public void testArrowColumnString() { @DisplayName("Test Arrow Column String")
void testArrowColumnString() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>(); List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i)); schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single);
List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build()); List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
List<List<Writable>> assertion = new ArrayList<>(); List<List<Writable>> assertion = new ArrayList<>();
assertion.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1))); assertion.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)));
assertEquals(assertion,records); assertEquals(assertion, records);
List<List<String>> batch = new ArrayList<>(); List<List<String>> batch = new ArrayList<>();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
batch.add(Arrays.asList(String.valueOf(i),String.valueOf(i))); batch.add(Arrays.asList(String.valueOf(i), String.valueOf(i)));
} }
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch); List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build()); List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
List<List<Writable>> assertionBatch = new ArrayList<>(); List<List<Writable>> assertionBatch = new ArrayList<>();
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(0))); assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0)));
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1),new IntWritable(1))); assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1)));
assertEquals(assertionBatch,batchRecords); assertEquals(assertionBatch, batchRecords);
} }
@Test @Test
public void testArrowBatchSetTime() { @DisplayName("Test Arrow Batch Set Time")
void testArrowBatchSetTime() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>(); List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schema.addColumnTime(String.valueOf(i),TimeZone.getDefault()); schema.addColumnTime(String.valueOf(i), TimeZone.getDefault());
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
List<List<Writable>> input = Arrays.asList( List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
Arrays.<Writable>asList(new LongWritable(0),new LongWritable(1)), ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
Arrays.<Writable>asList(new LongWritable(2),new LongWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)); List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4),new LongWritable(5))); writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1); List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest); assertEquals(assertion, recordTest);
} }
@Test @Test
public void testArrowBatchSet() { @DisplayName("Test Arrow Batch Set")
void testArrowBatchSet() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>(); List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i)); schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
List<List<Writable>> input = Arrays.asList( List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)), ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
Arrays.<Writable>asList(new IntWritable(2),new IntWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)); List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5))); writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1); List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest); assertEquals(assertion, recordTest);
} }
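The two set/get tests above rely on the same round trip: writables go column-wise into Arrow FieldVectors, and ArrowWritableRecordBatch exposes those vectors back as an indexed list of rows. A minimal sketch of that flow (class name hypothetical):

    import org.apache.arrow.memory.BufferAllocator;
    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.FieldVector;
    import org.datavec.api.transform.schema.Schema;
    import org.datavec.api.writable.IntWritable;
    import org.datavec.api.writable.Writable;
    import org.datavec.arrow.ArrowConverter;
    import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
    import java.util.Arrays;
    import java.util.List;

    class ArrowRoundTripSketch {
        public static void main(String[] args) {
            BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
            Schema schema = new Schema.Builder().addColumnInteger("a").addColumnInteger("b").build();
            List<List<Writable>> rows = Arrays.asList(
                    Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)),
                    Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
            // One Arrow FieldVector per schema column...
            List<FieldVector> cols = ArrowConverter.toArrowColumns(allocator, schema, rows);
            // ...wrapped back into a row-indexed batch view
            ArrowWritableRecordBatch batch = new ArrowWritableRecordBatch(cols, schema);
            System.out.println(batch.get(0));  // expected: [0, 1]
        }
    }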
@Test @Test
public void testArrowColumnsStringTimeSeries() { @DisplayName("Test Arrow Columns String Time Series")
void testArrowColumnsStringTimeSeries() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>(); List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i)); schema.addColumnInteger(String.valueOf(i));
} }
for (int i = 0; i < 5; i++) {
for(int i = 0; i < 5; i++) {
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr); entries.add(arr);
} }
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
assertEquals(3,fieldVectors.size()); assertEquals(3, fieldVectors.size());
assertEquals(5,fieldVectors.get(0).getValueCount()); assertEquals(5, fieldVectors.get(0).getValueCount());
INDArray exp = Nd4j.create(5, 3); INDArray exp = Nd4j.create(5, 3);
for( int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
exp.getRow(i).assign(i); exp.getRow(i).assign(i);
} }
//Convert to ArrowWritableRecordBatch - note we can't do this in general with time series... // Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build()); ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
INDArray arr = ArrowConverter.toArray(wri); INDArray arr = ArrowConverter.toArray(wri);
assertArrayEquals(new long[] {5,3}, arr.shape()); assertArrayEquals(new long[] { 5, 3 }, arr.shape());
assertEquals(exp, arr); assertEquals(exp, arr);
} }
@Test @Test
public void testConvertVector() { @DisplayName("Test Convert Vector")
void testConvertVector() {
Schema.Builder schema = new Schema.Builder(); Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>(); List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i)); schema.addColumnInteger(String.valueOf(i));
} }
for (int i = 0; i < 5; i++) {
for(int i = 0; i < 5; i++) {
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr); entries.add(arr);
} }
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries); List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0),schema.build().getType(0)); INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0), schema.build().getType(0));
assertEquals(5,arr.length()); assertEquals(5, arr.length());
} }
@Test @Test
public void testCreateNDArray() throws Exception { @DisplayName("Test Create ND Array")
void testCreateNDArray() throws Exception {
val recordsToWrite = recordToWrite(); val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream); ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
File f = testDir.toFile();
File f = testDir.newFolder();
File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrow"); File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrow");
FileOutputStream outputStream = new FileOutputStream(tmpFile); FileOutputStream outputStream = new FileOutputStream(tmpFile);
tmpFile.deleteOnExit(); tmpFile.deleteOnExit();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),outputStream); ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), outputStream);
outputStream.flush(); outputStream.flush();
outputStream.close(); outputStream.close();
Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile); Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile);
assertEquals(recordsToWrite.getFirst(),schemaArrowWritableRecordBatchPair.getFirst()); assertEquals(recordsToWrite.getFirst(), schemaArrowWritableRecordBatchPair.getFirst());
assertEquals(recordsToWrite.getRight(),schemaArrowWritableRecordBatchPair.getRight().toArrayList()); assertEquals(recordsToWrite.getRight(), schemaArrowWritableRecordBatchPair.getRight().toArrayList());
byte[] arr = byteArrayOutputStream.toByteArray(); byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr); val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read); assertEquals(recordsToWrite, read);
// send file
//send file File tmp = tmpDataFile(recordsToWrite);
File tmp = tmpDataFile(recordsToWrite);
ArrowRecordReader recordReader = new ArrowRecordReader(); ArrowRecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp)); recordReader.initialize(new FileSplit(tmp));
recordReader.next(); recordReader.next();
ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch(); ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch();
INDArray arr2 = ArrowConverter.toArray(currentBatch); INDArray arr2 = ArrowConverter.toArray(currentBatch);
assertEquals(2,arr2.rows()); assertEquals(2, arr2.rows());
assertEquals(2,arr2.columns()); assertEquals(2, arr2.columns());
}
@Test
public void testConvertToArrowVectors() {
INDArray matrix = Nd4j.linspace(1,4,4).reshape(2,2);
val vectors = ArrowConverter.convertToArrowVector(matrix,Arrays.asList("test","test2"), ColumnType.Double,bufferAllocator);
assertEquals(matrix.rows(),vectors.size());
INDArray vector = Nd4j.linspace(1,4,4);
val vectors2 = ArrowConverter.convertToArrowVector(vector,Arrays.asList("test"), ColumnType.Double,bufferAllocator);
assertEquals(1,vectors2.size());
assertEquals(matrix.length(),vectors2.get(0).getValueCount());
} }
@Test @Test
public void testSchemaConversionBasic() { @DisplayName("Test Convert To Arrow Vectors")
void testConvertToArrowVectors() {
INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
val vectors = ArrowConverter.convertToArrowVector(matrix, Arrays.asList("test", "test2"), ColumnType.Double, bufferAllocator);
assertEquals(matrix.rows(), vectors.size());
INDArray vector = Nd4j.linspace(1, 4, 4);
val vectors2 = ArrowConverter.convertToArrowVector(vector, Arrays.asList("test"), ColumnType.Double, bufferAllocator);
assertEquals(1, vectors2.size());
assertEquals(matrix.length(), vectors2.get(0).getValueCount());
}
@Test
@DisplayName("Test Schema Conversion Basic")
void testSchemaConversionBasic() {
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schemaBuilder.addColumnDouble("test-" + i); schemaBuilder.addColumnDouble("test-" + i);
schemaBuilder.addColumnInteger("testi-" + i); schemaBuilder.addColumnInteger("testi-" + i);
schemaBuilder.addColumnLong("testl-" + i); schemaBuilder.addColumnLong("testl-" + i);
schemaBuilder.addColumnFloat("testf-" + i); schemaBuilder.addColumnFloat("testf-" + i);
} }
Schema schema = schemaBuilder.build(); Schema schema = schemaBuilder.build();
val schema2 = ArrowConverter.toArrowSchema(schema); val schema2 = ArrowConverter.toArrowSchema(schema);
assertEquals(8,schema2.getFields().size()); assertEquals(8, schema2.getFields().size());
val convertedSchema = ArrowConverter.toDatavecSchema(schema2); val convertedSchema = ArrowConverter.toDatavecSchema(schema2);
assertEquals(schema,convertedSchema); assertEquals(schema, convertedSchema);
} }
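The round-trip assertion above implies the DataVec-to-Arrow schema mapping is lossless for these primitive column types. A minimal sketch (class name hypothetical):

    import org.datavec.api.transform.schema.Schema;
    import org.datavec.arrow.ArrowConverter;

    class SchemaRoundTripSketch {
        public static void main(String[] args) {
            Schema datavec = new Schema.Builder()
                    .addColumnDouble("d")
                    .addColumnInteger("i")
                    .addColumnLong("l")
                    .addColumnFloat("f")
                    .build();
            // DataVec -> Arrow: one Arrow Field per DataVec column
            org.apache.arrow.vector.types.pojo.Schema arrow = ArrowConverter.toArrowSchema(datavec);
            // Arrow -> DataVec: reproduces the original for these types
            Schema roundTripped = ArrowConverter.toDatavecSchema(arrow);
            System.out.println(datavec.equals(roundTripped));  // expected: true
        }
    }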
@Test @Test
public void testReadSchemaAndRecordsFromByteArray() throws Exception { @DisplayName("Test Read Schema And Records From Byte Array")
void testReadSchemaAndRecordsFromByteArray() throws Exception {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
int valueCount = 3; int valueCount = 3;
List<Field> fields = new ArrayList<>(); List<Field> fields = new ArrayList<>();
fields.add(ArrowConverter.field("field1",new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))); fields.add(ArrowConverter.field("field1", new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
fields.add(ArrowConverter.intField("field2")); fields.add(ArrowConverter.intField("field2"));
List<FieldVector> fieldVectors = new ArrayList<>(); List<FieldVector> fieldVectors = new ArrayList<>();
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field1",new float[] {1,2,3})); fieldVectors.add(ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 }));
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field2",new int[] {1,2,3})); fieldVectors.add(ArrowConverter.vectorFor(allocator, "field2", new int[] { 1, 2, 3 }));
org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields); org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields);
VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount); VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount);
VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1); VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
vectorUnloader.getRecordBatch(); vectorUnloader.getRecordBatch();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) { try (ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1, null, newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch(); arrowFileWriter.writeBatch();
} catch (IOException e) { } catch (IOException e) {
log.error("",e); log.error("", e);
} }
byte[] arr = byteArrayOutputStream.toByteArray(); byte[] arr = byteArrayOutputStream.toByteArray();
val arr2 = ArrowConverter.readFromBytes(arr); val arr2 = ArrowConverter.readFromBytes(arr);
assertEquals(2,arr2.getFirst().numColumns()); assertEquals(2, arr2.getFirst().numColumns());
assertEquals(3,arr2.getRight().size()); assertEquals(3, arr2.getRight().size());
val arrowCols = ArrowConverter.toArrowColumns(allocator, arr2.getFirst(), arr2.getRight());
val arrowCols = ArrowConverter.toArrowColumns(allocator,arr2.getFirst(),arr2.getRight()); assertEquals(2, arrowCols.size());
assertEquals(2,arrowCols.size()); assertEquals(valueCount, arrowCols.get(0).getValueCount());
assertEquals(valueCount,arrowCols.get(0).getValueCount());
} }
@Test @Test
public void testVectorForEdgeCases() { @DisplayName("Test Vector For Edge Cases")
void testVectorForEdgeCases() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{Float.MIN_VALUE,Float.MAX_VALUE}); val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { Float.MIN_VALUE, Float.MAX_VALUE });
assertEquals(Float.MIN_VALUE,vector.get(0),1e-2); assertEquals(Float.MIN_VALUE, vector.get(0), 1e-2);
assertEquals(Float.MAX_VALUE,vector.get(1),1e-2); assertEquals(Float.MAX_VALUE, vector.get(1), 1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { Integer.MIN_VALUE, Integer.MAX_VALUE });
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{Integer.MIN_VALUE,Integer.MAX_VALUE}); assertEquals(Integer.MIN_VALUE, vectorInt.get(0), 1e-2);
assertEquals(Integer.MIN_VALUE,vectorInt.get(0),1e-2); assertEquals(Integer.MAX_VALUE, vectorInt.get(1), 1e-2);
assertEquals(Integer.MAX_VALUE,vectorInt.get(1),1e-2);
} }
@Test @Test
public void testVectorFor() { @DisplayName("Test Vector For")
void testVectorFor() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 });
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{1,2,3}); assertEquals(3, vector.getValueCount());
assertEquals(3,vector.getValueCount()); assertEquals(1, vector.get(0), 1e-2);
assertEquals(1,vector.get(0),1e-2); assertEquals(2, vector.get(1), 1e-2);
assertEquals(2,vector.get(1),1e-2); assertEquals(3, vector.get(2), 1e-2);
assertEquals(3,vector.get(2),1e-2); val vectorLong = ArrowConverter.vectorFor(allocator, "field1", new long[] { 1, 2, 3 });
assertEquals(3, vectorLong.getValueCount());
val vectorLong = ArrowConverter.vectorFor(allocator,"field1",new long[]{1,2,3}); assertEquals(1, vectorLong.get(0), 1e-2);
assertEquals(3,vectorLong.getValueCount()); assertEquals(2, vectorLong.get(1), 1e-2);
assertEquals(1,vectorLong.get(0),1e-2); assertEquals(3, vectorLong.get(2), 1e-2);
assertEquals(2,vectorLong.get(1),1e-2); val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { 1, 2, 3 });
assertEquals(3,vectorLong.get(2),1e-2); assertEquals(3, vectorInt.getValueCount());
assertEquals(1, vectorInt.get(0), 1e-2);
assertEquals(2, vectorInt.get(1), 1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{1,2,3}); assertEquals(3, vectorInt.get(2), 1e-2);
assertEquals(3,vectorInt.getValueCount()); val vectorDouble = ArrowConverter.vectorFor(allocator, "field1", new double[] { 1, 2, 3 });
assertEquals(1,vectorInt.get(0),1e-2); assertEquals(3, vectorDouble.getValueCount());
assertEquals(2,vectorInt.get(1),1e-2); assertEquals(1, vectorDouble.get(0), 1e-2);
assertEquals(3,vectorInt.get(2),1e-2); assertEquals(2, vectorDouble.get(1), 1e-2);
assertEquals(3, vectorDouble.get(2), 1e-2);
val vectorDouble = ArrowConverter.vectorFor(allocator,"field1",new double[]{1,2,3}); val vectorBool = ArrowConverter.vectorFor(allocator, "field1", new boolean[] { true, true, false });
assertEquals(3,vectorDouble.getValueCount()); assertEquals(3, vectorBool.getValueCount());
assertEquals(1,vectorDouble.get(0),1e-2); assertEquals(1, vectorBool.get(0), 1e-2);
assertEquals(2,vectorDouble.get(1),1e-2); assertEquals(1, vectorBool.get(1), 1e-2);
assertEquals(3,vectorDouble.get(2),1e-2); assertEquals(0, vectorBool.get(2), 1e-2);
val vectorBool = ArrowConverter.vectorFor(allocator,"field1",new boolean[]{true,true,false});
assertEquals(3,vectorBool.getValueCount());
assertEquals(1,vectorBool.get(0),1e-2);
assertEquals(1,vectorBool.get(1),1e-2);
assertEquals(0,vectorBool.get(2),1e-2);
} }
@Test @Test
public void testRecordReaderAndWriteFile() throws Exception { @DisplayName("Test Record Reader And Write File")
void testRecordReaderAndWriteFile() throws Exception {
val recordsToWrite = recordToWrite(); val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream); ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
byte[] arr = byteArrayOutputStream.toByteArray(); byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr); val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read); assertEquals(recordsToWrite, read);
// send file
//send file File tmp = tmpDataFile(recordsToWrite);
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader(); RecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp)); recordReader.initialize(new FileSplit(tmp));
List<Writable> record = recordReader.next(); List<Writable> record = recordReader.next();
assertEquals(2,record.size()); assertEquals(2, record.size());
} }
@Test @Test
public void testRecordReaderMetaDataList() throws Exception { @DisplayName("Test Record Reader Meta Data List")
void testRecordReaderMetaDataList() throws Exception {
val recordsToWrite = recordToWrite(); val recordsToWrite = recordToWrite();
//send file // send file
File tmp = tmpDataFile(recordsToWrite); File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader(); RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class); RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
recordReader.loadFromMetaData(Arrays.<RecordMetaData>asList(recordMetaDataIndex)); recordReader.loadFromMetaData(Arrays.<RecordMetaData>asList(recordMetaDataIndex));
Record record = recordReader.nextRecord(); Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size()); assertEquals(2, record.getRecord().size());
} }
@Test @Test
public void testDates() { @DisplayName("Test Dates")
void testDates() {
Date now = new Date(); Date now = new Date();
BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[]{now}); TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[] { now });
assertEquals(now.getTime(),timeStampMilliVector.get(0)); assertEquals(now.getTime(), timeStampMilliVector.get(0));
} }
@Test @Test
public void testRecordReaderMetaData() throws Exception { @DisplayName("Test Record Reader Meta Data")
void testRecordReaderMetaData() throws Exception {
val recordsToWrite = recordToWrite(); val recordsToWrite = recordToWrite();
//send file // send file
File tmp = tmpDataFile(recordsToWrite); File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader(); RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class); RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
recordReader.loadFromMetaData(recordMetaDataIndex); recordReader.loadFromMetaData(recordMetaDataIndex);
Record record = recordReader.nextRecord(); Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size()); assertEquals(2, record.getRecord().size());
} }
private File tmpDataFile(Pair<Schema,List<List<Writable>>> recordsToWrite) throws IOException { private File tmpDataFile(Pair<Schema, List<List<Writable>>> recordsToWrite) throws IOException {
File f = testDir.toFile();
File f = testDir.newFolder(); // send file
File tmp = new File(f, "tmp-file-" + UUID.randomUUID().toString());
//send file
File tmp = new File(f,"tmp-file-" + UUID.randomUUID().toString());
tmp.mkdirs(); tmp.mkdirs();
File tmpFile = new File(tmp,"data.arrow"); File tmpFile = new File(tmp, "data.arrow");
tmpFile.deleteOnExit(); tmpFile.deleteOnExit();
FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile); FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),bufferedOutputStream); ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), bufferedOutputStream);
bufferedOutputStream.flush(); bufferedOutputStream.flush();
bufferedOutputStream.close(); bufferedOutputStream.close();
return tmp; return tmp;
} }
private Pair<Schema,List<List<Writable>>> recordToWrite() { private Pair<Schema, List<List<Writable>>> recordToWrite() {
List<List<Writable>> records = new ArrayList<>(); List<List<Writable>> records = new ArrayList<>();
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
schemaBuilder.addColumnFloat("col-" + i); schemaBuilder.addColumnFloat("col-" + i);
} }
return Pair.of(schemaBuilder.build(), records);
return Pair.of(schemaBuilder.build(),records);
} }
} }
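The change pattern in ArrowConverterTest above repeats across every file in this commit: org.junit.Test is replaced by org.junit.jupiter.api.Test, the public modifiers on test classes and methods are dropped (JUnit 5 only requires package visibility), and a @DisplayName derived from the camel-case name is attached to each class and method. A minimal sketch of that shape, using a hypothetical ValueCountTest rather than any class from the commit:

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Value Count Test")
class ValueCountTest {

    // JUnit 4 form: public class with public void methods annotated
    // @org.junit.Test; Jupiter needs neither 'public'.
    @Test
    @DisplayName("Test Value Count")
    void testValueCount() {
        int[] values = { 1, 2, 3 };
        assertEquals(3, values.length);
    }
}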
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.arrow; package org.datavec.arrow;
import lombok.val; import lombok.val;
@ -34,132 +33,98 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowRecordWriter; import org.datavec.arrow.recordreader.ArrowRecordWriter;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.primitives.Triple; import org.nd4j.common.primitives.Triple;
import java.io.File; import java.io.File;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Record Mapper Test")
class RecordMapperTest extends BaseND4JTest {
public class RecordMapperTest extends BaseND4JTest {
@Test @Test
public void testMultiWrite() throws Exception { @DisplayName("Test Multi Write")
void testMultiWrite() throws Exception {
val recordsPair = records(); val recordsPair = records();
Path p = Files.createTempFile("arrowwritetest", ".arrow"); Path p = Files.createTempFile("arrowwritetest", ".arrow");
FileUtils.write(p.toFile(),recordsPair.getFirst()); FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
int numReaders = 2; int numReaders = 2;
RecordReader[] readers = new RecordReader[numReaders]; RecordReader[] readers = new RecordReader[numReaders];
InputSplit[] splits = new InputSplit[numReaders]; InputSplit[] splits = new InputSplit[numReaders];
for(int i = 0; i < readers.length; i++) { for (int i = 0; i < readers.length; i++) {
FileSplit split = new FileSplit(p.toFile()); FileSplit split = new FileSplit(p.toFile());
ArrowRecordReader arrowRecordReader = new ArrowRecordReader(); ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
readers[i] = arrowRecordReader; readers[i] = arrowRecordReader;
splits[i] = split; splits[i] = split;
} }
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle()); ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
FileSplit split = new FileSplit(p.toFile()); FileSplit split = new FileSplit(p.toFile());
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner()); arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
arrowRecordWriter.writeBatch(recordsPair.getRight()); arrowRecordWriter.writeBatch(recordsPair.getRight());
CSVRecordWriter csvRecordWriter = new CSVRecordWriter(); CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
Path p2 = Files.createTempFile("arrowwritetest", ".csv"); Path p2 = Files.createTempFile("arrowwritetest", ".csv");
FileUtils.write(p2.toFile(),recordsPair.getFirst()); FileUtils.write(p2.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
FileSplit outputCsv = new FileSplit(p2.toFile()); FileSplit outputCsv = new FileSplit(p2.toFile());
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers).splitPerReader(splits).recordWriter(csvRecordWriter).build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
.outputUrl(outputCsv)
.partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers)
.splitPerReader(splits)
.recordWriter(csvRecordWriter)
.build();
mapper.copy(); mapper.copy();
} }
@Test @Test
public void testCopyFromArrowToCsv() throws Exception { @DisplayName("Test Copy From Arrow To Csv")
void testCopyFromArrowToCsv() throws Exception {
val recordsPair = records(); val recordsPair = records();
Path p = Files.createTempFile("arrowwritetest", ".arrow"); Path p = Files.createTempFile("arrowwritetest", ".arrow");
FileUtils.write(p.toFile(),recordsPair.getFirst()); FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle()); ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
FileSplit split = new FileSplit(p.toFile()); FileSplit split = new FileSplit(p.toFile());
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner()); arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
arrowRecordWriter.writeBatch(recordsPair.getRight()); arrowRecordWriter.writeBatch(recordsPair.getRight());
ArrowRecordReader arrowRecordReader = new ArrowRecordReader(); ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
arrowRecordReader.initialize(split); arrowRecordReader.initialize(split);
CSVRecordWriter csvRecordWriter = new CSVRecordWriter(); CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
Path p2 = Files.createTempFile("arrowwritetest", ".csv"); Path p2 = Files.createTempFile("arrowwritetest", ".csv");
FileUtils.write(p2.toFile(),recordsPair.getFirst()); FileUtils.write(p2.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
FileSplit outputCsv = new FileSplit(p2.toFile()); FileSplit outputCsv = new FileSplit(p2.toFile());
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).recordReader(arrowRecordReader).recordWriter(csvRecordWriter).build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
.outputUrl(outputCsv)
.partitioner(new NumberOfRecordsPartitioner())
.recordReader(arrowRecordReader).recordWriter(csvRecordWriter)
.build();
mapper.copy(); mapper.copy();
CSVRecordReader recordReader = new CSVRecordReader(); CSVRecordReader recordReader = new CSVRecordReader();
recordReader.initialize(outputCsv); recordReader.initialize(outputCsv);
List<List<Writable>> loadedCSvRecords = recordReader.next(10); List<List<Writable>> loadedCSvRecords = recordReader.next(10);
assertEquals(10,loadedCSvRecords.size()); assertEquals(10, loadedCSvRecords.size());
} }
@Test @Test
public void testCopyFromCsvToArrow() throws Exception { @DisplayName("Test Copy From Csv To Arrow")
void testCopyFromCsvToArrow() throws Exception {
val recordsPair = records(); val recordsPair = records();
Path p = Files.createTempFile("csvwritetest", ".csv"); Path p = Files.createTempFile("csvwritetest", ".csv");
FileUtils.write(p.toFile(),recordsPair.getFirst()); FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit(); p.toFile().deleteOnExit();
CSVRecordReader recordReader = new CSVRecordReader(); CSVRecordReader recordReader = new CSVRecordReader();
FileSplit fileSplit = new FileSplit(p.toFile()); FileSplit fileSplit = new FileSplit(p.toFile());
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle()); ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
File outputFile = Files.createTempFile("outputarrow","arrow").toFile(); File outputFile = Files.createTempFile("outputarrow", "arrow").toFile();
FileSplit outputFileSplit = new FileSplit(outputFile); FileSplit outputFileSplit = new FileSplit(outputFile);
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit) RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit).outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner()).recordReader(recordReader).recordWriter(arrowRecordWriter).build();
.outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner())
.recordReader(recordReader).recordWriter(arrowRecordWriter)
.build();
mapper.copy(); mapper.copy();
ArrowRecordReader arrowRecordReader = new ArrowRecordReader(); ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
arrowRecordReader.initialize(outputFileSplit); arrowRecordReader.initialize(outputFileSplit);
List<List<Writable>> next = arrowRecordReader.next(10); List<List<Writable>> next = arrowRecordReader.next(10);
System.out.println(next); System.out.println(next);
assertEquals(10,next.size()); assertEquals(10, next.size());
} }
private Triple<String,Schema,List<List<Writable>>> records() { private Triple<String, Schema, List<List<Writable>>> records() {
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
int numColumns = 3; int numColumns = 3;
@ -176,15 +141,10 @@ public class RecordMapperTest extends BaseND4JTest {
} }
list.add(temp); list.add(temp);
} }
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < numColumns; i++) { for (int i = 0; i < numColumns; i++) {
schemaBuilder.addColumnInteger(String.valueOf(i)); schemaBuilder.addColumnInteger(String.valueOf(i));
} }
return Triple.of(sb.toString(), schemaBuilder.build(), list);
return Triple.of(sb.toString(),schemaBuilder.build(),list);
} }
} }
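RecordMapperTest also swaps the static assertion import from org.junit.Assert to org.junit.jupiter.api.Assertions. For the two-argument assertEquals(expected, actual) calls above that is a drop-in change, but any call that passes a failure message needs care: JUnit 4 took the message as the first argument, JUnit 5 takes it as the last. A short sketch with illustrative values:

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.Test;

class AssertionMessageOrderTest {

    @Test
    void messageMovesToTheEnd() {
        int loaded = 10;
        // JUnit 4: Assert.assertEquals("wrong record count", 10, loaded);
        // JUnit 5 keeps (expected, actual) but moves the message last:
        assertEquals(10, loaded, "wrong record count");
    }
}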
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.image; package org.datavec.image;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -25,33 +24,32 @@ import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Label Generator Test")
import static org.junit.Assert.assertTrue; class LabelGeneratorTest {
public class LabelGeneratorTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testParentPathLabelGenerator() throws Exception { @DisplayName("Test Parent Path Label Generator")
//https://github.com/deeplearning4j/DataVec/issues/273 void testParentPathLabelGenerator(@TempDir Path testDir) throws Exception {
File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile(); File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
for (String dirPrefix : new String[] { "m.", "m" }) {
for(String dirPrefix : new String[]{"m.", "m"}) { File f = testDir.toFile();
File f = testDir.newFolder();
int numDirs = 3; int numDirs = 3;
int filesPerDir = 4; int filesPerDir = 4;
for (int i = 0; i < numDirs; i++) { for (int i = 0; i < numDirs; i++) {
File currentLabelDir = new File(f, dirPrefix + i); File currentLabelDir = new File(f, dirPrefix + i);
currentLabelDir.mkdirs(); currentLabelDir.mkdirs();
@ -61,14 +59,11 @@ public class LabelGeneratorTest {
assertTrue(f3.exists()); assertTrue(f3.exists());
} }
} }
ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator()); ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
rr.initialize(new FileSplit(f)); rr.initialize(new FileSplit(f));
List<String> labelsAct = rr.getLabels(); List<String> labelsAct = rr.getLabels();
List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2"); List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2");
assertEquals(labelsExp, labelsAct); assertEquals(labelsExp, labelsAct);
int expCount = numDirs * filesPerDir; int expCount = numDirs * filesPerDir;
int actCount = 0; int actCount = 0;
while (rr.hasNext()) { while (rr.hasNext()) {
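LabelGeneratorTest trades the JUnit 4 @Rule TemporaryFolder for Jupiter's @TempDir, which injects a java.nio.file.Path (or File) either as an annotated field or as a test-method parameter and deletes the directory after each test, so manual cleanup becomes unnecessary. A minimal sketch of both injection styles (hypothetical test, not from the commit):

import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirInjectionTest {

    // Field injection: recreated for every test method and deleted
    // (recursively) afterwards by JUnit itself.
    @TempDir
    Path sharedDir;

    // Parameter injection, as used by LabelGeneratorTest above.
    @Test
    void writesIntoInjectedDirectory(@TempDir Path localDir) throws Exception {
        assertTrue(sharedDir.toFile().isDirectory());
        File out = new File(localDir.toFile(), "data.bin");
        assertTrue(out.createNewFile());
    }
}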
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.image.recordreader; package org.datavec.image.recordreader;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -29,60 +28,55 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.loader.NativeImageLoader;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.loader.FileBatch; import org.nd4j.common.loader.FileBatch;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*; @DisplayName("File Batch Record Reader Test")
class FileBatchRecordReaderTest {
public class FileBatchRecordReaderTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testCsv() throws Exception { @DisplayName("Test Csv")
File extractedSourceDir = testDir.newFolder(); void testCsv(@TempDir Path testDir, @TempDir Path baseDirPath) throws Exception {
File extractedSourceDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir); new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir);
File baseDir = testDir.newFolder(); File baseDir = baseDirPath.toFile();
List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true)); List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true));
assertEquals(6, c.size()); assertEquals(6, c.size());
Collections.sort(c, new Comparator<File>() { Collections.sort(c, new Comparator<File>() {
@Override @Override
public int compare(File o1, File o2) { public int compare(File o1, File o2) {
return o1.getPath().compareTo(o2.getPath()); return o1.getPath().compareTo(o2.getPath());
} }
}); });
FileBatch fb = FileBatch.forFiles(c); FileBatch fb = FileBatch.forFiles(c);
File saveFile = new File(baseDir, "saved.zip"); File saveFile = new File(baseDir, "saved.zip");
fb.writeAsZip(saveFile); fb.writeAsZip(saveFile);
fb = FileBatch.readFromZip(saveFile); fb = FileBatch.readFromZip(saveFile);
PathLabelGenerator labelMaker = new ParentPathLabelGenerator(); PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker); ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker);
rr.setLabels(Arrays.asList("class0", "class1")); rr.setLabels(Arrays.asList("class0", "class1"));
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb); FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
NativeImageLoader il = new NativeImageLoader(32, 32, 1); NativeImageLoader il = new NativeImageLoader(32, 32, 1);
for( int test=0; test<3; test++) { for (int test = 0; test < 3; test++) {
for (int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
assertTrue(fbrr.hasNext()); assertTrue(fbrr.hasNext());
List<Writable> next = fbrr.next(); List<Writable> next = fbrr.next();
assertEquals(2, next.size()); assertEquals(2, next.size());
INDArray exp; INDArray exp;
switch (i){ switch(i) {
case 0: case 0:
exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg")); exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg"));
break; break;
@ -105,8 +99,7 @@ public class FileBatchRecordReaderTest {
throw new RuntimeException(); throw new RuntimeException();
} }
Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1)); Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1));
assertEquals(((NDArrayWritable) next.get(0)).get(), exp);
assertEquals(((NDArrayWritable)next.get(0)).get(), exp);
assertEquals(expLabel, next.get(1)); assertEquals(expLabel, next.get(1));
} }
assertFalse(fbrr.hasNext()); assertFalse(fbrr.hasNext());
@ -114,5 +107,4 @@ public class FileBatchRecordReaderTest {
fbrr.reset(); fbrr.reset();
} }
} }
} }
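FileBatchRecordReaderTest's testCsv now takes two @TempDir parameters because the old version called testDir.newFolder() twice. Separate directories per declaration are a JUnit 5.8 behavior, which matches the 5.8.0-M1 version this commit moves to; on 5.4 through 5.7 all @TempDir declarations in a test resolved to the same directory. A sketch that makes that assumption explicit:

import static org.junit.jupiter.api.Assertions.assertNotEquals;

import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TwoTempDirsTest {

    @Test
    void eachDeclarationGetsItsOwnDirectory(@TempDir Path source, @TempDir Path target) {
        // Holds on JUnit 5.8+; on 5.4 through 5.7 both parameters
        // would resolve to the SAME directory and this would fail.
        assertNotEquals(source, target);
    }
}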
View File
@ -17,106 +17,70 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.image.transform; package org.datavec.image.transform;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Json Yaml Test")
import static org.junit.Assert.assertTrue; class JsonYamlTest {
public class JsonYamlTest {
@Test @Test
public void testJsonYamlImageTransformProcess() throws IOException { @DisplayName("Test Json Yaml Image Transform Process")
void testJsonYamlImageTransformProcess() throws IOException {
int seed = 12345; int seed = 12345;
Random random = new Random(seed); Random random = new Random(seed);
// from org.bytedeco.javacpp.opencv_imgproc
//from org.bytedeco.javacpp.opencv_imgproc
int COLOR_BGR2Luv = 50; int COLOR_BGR2Luv = 50;
int CV_BGR2GRAY = 6; int CV_BGR2GRAY = 6;
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv).cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0).resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3).warpImageTransform((float) 0.5).build();
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv)
.cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0)
.resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3)
.warpImageTransform((float) 0.5)
// Note: randomCropTransform uses random values, so the results
// of each case (json, yaml, ImageTransformProcess) can differ.
// Don't use the line below; if you uncomment it, the
// assertions below will fail.
// .randomCropTransform(seed, 50, 50)
// Note: you will get "java.lang.NoClassDefFoundError: Could not initialize class org.bytedeco.javacpp.avutil"
// unless you add the dependency below.
// <dependency>
// <groupId>org.bytedeco</groupId>
// <artifactId>ffmpeg-platform</artifactId>
// </dependency>
// FFmpeg has license issues, so be careful using it.
//.filterImageTransform("noise=alls=20:allf=t+u,format=rgba", 100, 100, 4)
.build();
String asJson = itp.toJson(); String asJson = itp.toJson();
String asYaml = itp.toYaml(); String asYaml = itp.toYaml();
// System.out.println(asJson);
// System.out.println(asJson); // System.out.println("\n\n\n");
// System.out.println("\n\n\n"); // System.out.println(asYaml);
// System.out.println(asYaml);
ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3); ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
ImageWritable imgJson = new ImageWritable(img.getFrame().clone()); ImageWritable imgJson = new ImageWritable(img.getFrame().clone());
ImageWritable imgYaml = new ImageWritable(img.getFrame().clone()); ImageWritable imgYaml = new ImageWritable(img.getFrame().clone());
ImageWritable imgAll = new ImageWritable(img.getFrame().clone()); ImageWritable imgAll = new ImageWritable(img.getFrame().clone());
ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson); ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson);
ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml); ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml);
List<ImageTransform> transformList = itp.getTransformList(); List<ImageTransform> transformList = itp.getTransformList();
List<ImageTransform> transformListJson = itpFromJson.getTransformList(); List<ImageTransform> transformListJson = itpFromJson.getTransformList();
List<ImageTransform> transformListYaml = itpFromYaml.getTransformList(); List<ImageTransform> transformListYaml = itpFromYaml.getTransformList();
for (int i = 0; i < transformList.size(); i++) { for (int i = 0; i < transformList.size(); i++) {
ImageTransform it = transformList.get(i); ImageTransform it = transformList.get(i);
ImageTransform itJson = transformListJson.get(i); ImageTransform itJson = transformListJson.get(i);
ImageTransform itYaml = transformListYaml.get(i); ImageTransform itYaml = transformListYaml.get(i);
System.out.println(i + "\t" + it); System.out.println(i + "\t" + it);
img = it.transform(img); img = it.transform(img);
imgJson = itJson.transform(imgJson); imgJson = itJson.transform(imgJson);
imgYaml = itYaml.transform(imgYaml); imgYaml = itYaml.transform(imgYaml);
if (it instanceof RandomCropTransform) { if (it instanceof RandomCropTransform) {
assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight); assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight);
assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth); assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth);
assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight); assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight);
assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth); assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth);
} else if (it instanceof FilterImageTransform) { } else if (it instanceof FilterImageTransform) {
assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight); assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight);
assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth); assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth);
assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels); assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels);
assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight); assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight);
assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth); assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth);
assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels); assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels);
} else { } else {
assertEquals(img, imgJson); assertEquals(img, imgJson);
assertEquals(img, imgYaml); assertEquals(img, imgYaml);
} }
} }
imgAll = itp.execute(imgAll); imgAll = itp.execute(imgAll);
assertEquals(imgAll, img); assertEquals(imgAll, img);
} }
} }
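JsonYamlTest, like the other files, carries an explicit @DisplayName that looks generated from the camel-case method name. JUnit 5 can instead derive display names at runtime with @DisplayNameGeneration; the commit does not use this, but it is the built-in alternative when per-method annotations feel noisy. A hypothetical example:

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;

// Reported by the runner as "round trip keeps transforms intact".
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class RoundTripNamingTest {

    @Test
    void round_trip_keeps_transforms_intact() {
        assertEquals("abc", new StringBuilder("cba").reverse().toString());
    }
}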
View File
@ -17,56 +17,50 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.image.transform; package org.datavec.image.transform;
import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Frame;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Resize Image Transform Test")
class ResizeImageTransformTest {
public class ResizeImageTransformTest {
@Before
public void setUp() throws Exception {
@BeforeEach
void setUp() throws Exception {
} }
@Test @Test
public void testResizeUpscale1() throws Exception { @DisplayName("Test Resize Upscale 1")
void testResizeUpscale1() throws Exception {
ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3); ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3);
ResizeImageTransform transform = new ResizeImageTransform(200, 200); ResizeImageTransform transform = new ResizeImageTransform(200, 200);
ImageWritable dstImg = transform.transform(srcImg); ImageWritable dstImg = transform.transform(srcImg);
Frame f = dstImg.getFrame(); Frame f = dstImg.getFrame();
assertEquals(f.imageWidth, 200); assertEquals(f.imageWidth, 200);
assertEquals(f.imageHeight, 200); assertEquals(f.imageHeight, 200);
float[] coordinates = { 100, 200 };
float[] coordinates = {100, 200};
float[] transformed = transform.query(coordinates); float[] transformed = transform.query(coordinates);
assertEquals(200f * 100 / 32, transformed[0], 0); assertEquals(200f * 100 / 32, transformed[0], 0);
assertEquals(200f * 200 / 32, transformed[1], 0); assertEquals(200f * 200 / 32, transformed[1], 0);
} }
@Test @Test
public void testResizeDownscale() throws Exception { @DisplayName("Test Resize Downscale")
void testResizeDownscale() throws Exception {
ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3); ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3);
ResizeImageTransform transform = new ResizeImageTransform(200, 200); ResizeImageTransform transform = new ResizeImageTransform(200, 200);
ImageWritable dstImg = transform.transform(srcImg); ImageWritable dstImg = transform.transform(srcImg);
Frame f = dstImg.getFrame(); Frame f = dstImg.getFrame();
assertEquals(f.imageWidth, 200); assertEquals(f.imageWidth, 200);
assertEquals(f.imageHeight, 200); assertEquals(f.imageHeight, 200);
float[] coordinates = { 300, 400 };
float[] coordinates = {300, 400};
float[] transformed = transform.query(coordinates); float[] transformed = transform.query(coordinates);
assertEquals(200f * 300 / 443, transformed[0], 0); assertEquals(200f * 300 / 443, transformed[0], 0);
assertEquals(200f * 400 / 571, transformed[1], 0); assertEquals(200f * 400 / 571, transformed[1], 0);
} }
} }
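ResizeImageTransformTest shows the lifecycle renames: @Before becomes @BeforeEach, and @After becomes @AfterEach in JDBCRecordReaderTest further down; @BeforeClass and @AfterClass map to @BeforeAll and @AfterAll in the same way, though no file here needs them. A compact sketch with a hypothetical fixture:

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

class LifecycleMappingTest {

    private StringBuilder fixture;

    @BeforeEach // JUnit 4: @Before
    void setUp() {
        fixture = new StringBuilder("ready");
    }

    @AfterEach // JUnit 4: @After
    void tearDown() {
        fixture = null;
    }

    @Test
    void fixtureIsInitializedPerTest() {
        assertEquals("ready", fixture.toString());
    }
}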
View File
@ -17,37 +17,34 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.poi.excel; package org.datavec.poi.excel;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Excel Record Reader Test")
import static org.junit.Assert.assertTrue; class ExcelRecordReaderTest {
public class ExcelRecordReaderTest {
@Test @Test
public void testSimple() throws Exception { @DisplayName("Test Simple")
void testSimple() throws Exception {
RecordReader excel = new ExcelRecordReader(); RecordReader excel = new ExcelRecordReader();
excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile())); excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile()));
assertTrue(excel.hasNext()); assertTrue(excel.hasNext());
List<Writable> next = excel.next(); List<Writable> next = excel.next();
assertEquals(3,next.size()); assertEquals(3, next.size());
RecordReader headerReader = new ExcelRecordReader(1); RecordReader headerReader = new ExcelRecordReader(1);
headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile())); headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile()));
assertTrue(excel.hasNext()); assertTrue(headerReader.hasNext());
List<Writable> next2 = excel.next(); List<Writable> next2 = headerReader.next();
assertEquals(3,next2.size()); assertEquals(3, next2.size());
} }
} }
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.poi.excel; package org.datavec.poi.excel;
import lombok.val; import lombok.val;
@ -27,43 +26,44 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.primitives.Triple; import org.nd4j.common.primitives.Triple;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Excel Record Writer Test")
class ExcelRecordWriterTest {
public class ExcelRecordWriterTest { @TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testWriter() throws Exception { @DisplayName("Test Writer")
void testWriter() throws Exception {
ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter(); ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter();
val records = records(); val records = records();
File tmpDir = testDir.newFolder(); File tmpDir = testDir.toFile();
File outputFile = new File(tmpDir,"testexcel.xlsx"); File outputFile = new File(tmpDir, "testexcel.xlsx");
outputFile.deleteOnExit(); outputFile.deleteOnExit();
FileSplit fileSplit = new FileSplit(outputFile); FileSplit fileSplit = new FileSplit(outputFile);
excelRecordWriter.initialize(fileSplit,new NumberOfRecordsPartitioner()); excelRecordWriter.initialize(fileSplit, new NumberOfRecordsPartitioner());
excelRecordWriter.writeBatch(records.getRight()); excelRecordWriter.writeBatch(records.getRight());
excelRecordWriter.close(); excelRecordWriter.close();
File parentFile = outputFile.getParentFile(); File parentFile = outputFile.getParentFile();
assertEquals(1,parentFile.list().length); assertEquals(1, parentFile.list().length);
ExcelRecordReader excelRecordReader = new ExcelRecordReader(); ExcelRecordReader excelRecordReader = new ExcelRecordReader();
excelRecordReader.initialize(fileSplit); excelRecordReader.initialize(fileSplit);
List<List<Writable>> next = excelRecordReader.next(10); List<List<Writable>> next = excelRecordReader.next(10);
assertEquals(10,next.size()); assertEquals(10, next.size());
} }
private Triple<String,Schema,List<List<Writable>>> records() { private Triple<String, Schema, List<List<Writable>>> records() {
List<List<Writable>> list = new ArrayList<>(); List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
int numColumns = 3; int numColumns = 3;
@ -80,13 +80,10 @@ public class ExcelRecordWriterTest {
} }
list.add(temp); list.add(temp);
} }
Schema.Builder schemaBuilder = new Schema.Builder(); Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < numColumns; i++) { for (int i = 0; i < numColumns; i++) {
schemaBuilder.addColumnInteger(String.valueOf(i)); schemaBuilder.addColumnInteger(String.valueOf(i));
} }
return Triple.of(sb.toString(), schemaBuilder.build(), list);
return Triple.of(sb.toString(),schemaBuilder.build(),list);
} }
} }
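Both Excel tests keep the deleteOnExit() calls inherited from the JUnit 4 version. Under @TempDir they are harmless but redundant, since Jupiter removes the injected directory tree itself once the test completes, as this small hypothetical test illustrates:

import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirCleanupTest {

    @Test
    void noManualCleanupNeeded(@TempDir Path dir) throws Exception {
        File out = new File(dir.toFile(), "testexcel.xlsx");
        // No deleteOnExit(): Jupiter deletes the whole directory tree
        // when the test finishes.
        assertTrue(out.createNewFile());
    }
}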
View File
@ -17,14 +17,12 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.datavec.api.records.reader.impl; package org.datavec.api.records.reader.impl;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.File; import java.io.File;
import java.net.URI; import java.net.URI;
import java.sql.Connection; import java.sql.Connection;
@ -49,53 +47,57 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.After; import org.junit.jupiter.api.AfterEach;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class JDBCRecordReaderTest { @DisplayName("Jdbc Record Reader Test")
class JDBCRecordReaderTest {
@Rule @TempDir
public TemporaryFolder testDir = new TemporaryFolder(); public Path testDir;
Connection conn; Connection conn;
EmbeddedDataSource dataSource; EmbeddedDataSource dataSource;
private final String dbName = "datavecTests"; private final String dbName = "datavecTests";
private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver"; private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver";
@Before @BeforeEach
public void setUp() throws Exception { void setUp() throws Exception {
File f = testDir.newFolder(); File f = testDir.toFile();
System.setProperty("derby.system.home", f.getAbsolutePath()); System.setProperty("derby.system.home", f.getAbsolutePath());
dataSource = new EmbeddedDataSource(); dataSource = new EmbeddedDataSource();
dataSource.setDatabaseName(dbName); dataSource.setDatabaseName(dbName);
dataSource.setCreateDatabase("create"); dataSource.setCreateDatabase("create");
conn = dataSource.getConnection(); conn = dataSource.getConnection();
TestDb.dropTables(conn); TestDb.dropTables(conn);
TestDb.buildCoffeeTable(conn); TestDb.buildCoffeeTable(conn);
} }
@After @AfterEach
public void tearDown() throws Exception { void tearDown() throws Exception {
DbUtils.closeQuietly(conn); DbUtils.closeQuietly(conn);
} }
@Test @Test
public void testSimpleIter() throws Exception { @DisplayName("Test Simple Iter")
void testSimpleIter() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
List<List<Writable>> records = new ArrayList<>(); List<List<Writable>> records = new ArrayList<>();
while (reader.hasNext()) { while (reader.hasNext()) {
List<Writable> values = reader.next(); List<Writable> values = reader.next();
records.add(values); records.add(values);
} }
assertFalse(records.isEmpty()); assertFalse(records.isEmpty());
List<Writable> first = records.get(0); List<Writable> first = records.get(0);
assertEquals(new Text("Bolivian Dark"), first.get(0)); assertEquals(new Text("Bolivian Dark"), first.get(0));
assertEquals(new Text("14-001"), first.get(1)); assertEquals(new Text("14-001"), first.get(1));
@ -104,39 +106,43 @@ public class JDBCRecordReaderTest {
} }
@Test @Test
public void testSimpleWithListener() throws Exception { @DisplayName("Test Simple With Listener")
void testSimpleWithListener() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordListener recordListener = new LogRecordListener(); RecordListener recordListener = new LogRecordListener();
reader.setListeners(recordListener); reader.setListeners(recordListener);
reader.next(); reader.next();
assertTrue(recordListener.invoked()); assertTrue(recordListener.invoked());
} }
} }
@Test @Test
public void testReset() throws Exception { @DisplayName("Test Reset")
void testReset() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
List<List<Writable>> records = new ArrayList<>(); List<List<Writable>> records = new ArrayList<>();
records.add(reader.next()); records.add(reader.next());
reader.reset(); reader.reset();
records.add(reader.next()); records.add(reader.next());
assertEquals(2, records.size()); assertEquals(2, records.size());
assertEquals(new Text("Bolivian Dark"), records.get(0).get(0)); assertEquals(new Text("Bolivian Dark"), records.get(0).get(0));
assertEquals(new Text("Bolivian Dark"), records.get(1).get(0)); assertEquals(new Text("Bolivian Dark"), records.get(1).get(0));
} }
} }
@Test(expected = IllegalStateException.class) @Test
public void testLackingDataSourceShouldFail() throws Exception { @DisplayName("Test Lacking Data Source Should Fail")
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { void testLackingDataSourceShouldFail() {
reader.initialize(null); assertThrows(IllegalStateException.class, () -> {
} try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
reader.initialize(null);
}
});
} }
@Test @Test
public void testConfigurationDataSourceInitialization() throws Exception { @DisplayName("Test Configuration Data Source Initialization")
void testConfigurationDataSourceInitialization() throws Exception {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
Configuration conf = new Configuration(); Configuration conf = new Configuration();
conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true"); conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true");
@ -146,28 +152,33 @@ public class JDBCRecordReaderTest {
} }
} }
@Test(expected = IllegalArgumentException.class) @Test
public void testInitConfigurationMissingParametersShouldFail() throws Exception { @DisplayName("Test Init Configuration Missing Parameters Should Fail")
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { void testInitConfigurationMissingParametersShouldFail() {
Configuration conf = new Configuration(); assertThrows(IllegalArgumentException.class, () -> {
conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway"); try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
reader.initialize(conf, null); Configuration conf = new Configuration();
} conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway");
} reader.initialize(conf, null);
}
@Test(expected = UnsupportedOperationException.class) });
public void testRecordDataInputStreamShouldFail() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
reader.record(null, null);
}
} }
@Test @Test
public void testLoadFromMetaData() throws Exception { @DisplayName("Test Record Data Input Stream Should Fail")
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { void testRecordDataInputStreamShouldFail() {
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), assertThrows(UnsupportedOperationException.class, () -> {
"SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass()); try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
reader.record(null, null);
}
});
}
@Test
@DisplayName("Test Load From Meta Data")
void testLoadFromMetaData() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), "SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass());
Record res = reader.loadFromMetaData(rmd); Record res = reader.loadFromMetaData(rmd);
assertNotNull(res); assertNotNull(res);
assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0)); assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0));
@ -177,7 +188,8 @@ public class JDBCRecordReaderTest {
} }
@Test @Test
public void testNextRecord() throws Exception { @DisplayName("Test Next Record")
void testNextRecord() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
Record r = reader.nextRecord(); Record r = reader.nextRecord();
List<Writable> fields = r.getRecord(); List<Writable> fields = r.getRecord();
@ -193,7 +205,8 @@ public class JDBCRecordReaderTest {
} }
@Test @Test
public void testNextRecordAndRecover() throws Exception { @DisplayName("Test Next Record And Recover")
void testNextRecordAndRecover() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
Record r = reader.nextRecord(); Record r = reader.nextRecord();
List<Writable> fields = r.getRecord(); List<Writable> fields = r.getRecord();
@ -208,69 +221,91 @@ public class JDBCRecordReaderTest {
} }
// Resetting the record reader when initialized as forward only should fail // Resetting the record reader when initialized as forward only should fail
@Test(expected = RuntimeException.class) @Test
public void testResetForwardOnlyShouldFail() throws Exception { @DisplayName("Test Reset Forward Only Should Fail")
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) { void testResetForwardOnlyShouldFail() {
Configuration conf = new Configuration(); assertThrows(RuntimeException.class, () -> {
conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY); try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) {
reader.initialize(conf, null); Configuration conf = new Configuration();
reader.next(); conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY);
reader.reset(); reader.initialize(conf, null);
} reader.next();
reader.reset();
}
});
} }
    @Test
    @DisplayName("Test Read All Types")
    void testReadAllTypes() throws Exception {
        TestDb.buildAllTypesTable(conn);
        try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM AllTypes", dataSource)) {
            reader.initialize(null);
            List<Writable> item = reader.next();
            assertEquals(item.size(), 15);
            assertEquals(BooleanWritable.class, item.get(0).getClass()); // boolean to boolean
            assertEquals(Text.class, item.get(1).getClass()); // date to text
            assertEquals(Text.class, item.get(2).getClass()); // time to text
            assertEquals(Text.class, item.get(3).getClass()); // timestamp to text
            assertEquals(Text.class, item.get(4).getClass()); // char to text
            assertEquals(Text.class, item.get(5).getClass()); // long varchar to text
            assertEquals(Text.class, item.get(6).getClass()); // varchar to text
            assertEquals(DoubleWritable.class, item.get(7).getClass()); // float to double (derby's float is an alias of double by default)
            assertEquals(FloatWritable.class, item.get(8).getClass()); // real to float
            assertEquals(DoubleWritable.class, item.get(9).getClass()); // decimal to double
            assertEquals(DoubleWritable.class, item.get(10).getClass()); // numeric to double
            assertEquals(DoubleWritable.class, item.get(11).getClass()); // double to double
            assertEquals(IntWritable.class, item.get(12).getClass()); // integer to integer
            assertEquals(IntWritable.class, item.get(13).getClass()); // small int to integer
            assertEquals(LongWritable.class, item.get(14).getClass()); // bigint to long
        }
    }
    @Test
    @DisplayName("Test Next No More Should Fail")
    void testNextNoMoreShouldFail() {
        assertThrows(RuntimeException.class, () -> {
            try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
                while (reader.hasNext()) {
                    reader.next();
                }
                reader.next();
            }
        });
    }
    @Test
    @DisplayName("Test Invalid Metadata Should Fail")
    void testInvalidMetadataShouldFail() {
        assertThrows(IllegalArgumentException.class, () -> {
            try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
                RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class);
                reader.loadFromMetaData(md);
            }
        });
    }
    private JDBCRecordReader getInitializedReader(String query) throws Exception {
        int[] indices = { 1 }; // ProdNum column
        JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?",
                indices);
        reader.setTrimStrings(true);
        reader.initialize(null);
        return reader;
    }
}
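The exception tests above all follow the same JUnit 4 to JUnit 5 migration pattern: the `@Test(expected = ...)` attribute is dropped and the expectation moves into the test body via `Assertions.assertThrows`, which scopes it to the exact statements that should throw and returns the thrown exception for further checks. A minimal standalone sketch of the pattern (the `FailingReader` class here is hypothetical, purely for illustration):

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Assert Throws Migration Sketch")
class AssertThrowsMigrationSketch {

    // Hypothetical stand-in for a reader that rejects reset(), like the
    // forward-only JDBCRecordReader case above.
    static class FailingReader {
        void reset() {
            throw new RuntimeException("reset not supported in forward-only mode");
        }
    }

    @Test
    @DisplayName("Reset should fail")
    void resetShouldFail() {
        FailingReader reader = new FailingReader();
        // JUnit 4: @Test(expected = RuntimeException.class) on the method.
        // JUnit 5: the expectation is scoped to the lambda, and the thrown
        // exception is returned so its message can be asserted as well.
        RuntimeException ex = assertThrows(RuntimeException.class, reader::reset);
        assertEquals("reset not supported in forward-only mode", ex.getMessage());
    }
}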
@@ -17,10 +17,8 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.datavec.local.transforms.transform;

import org.datavec.api.transform.MathFunction;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.ReduceOp;
@@ -32,107 +30,86 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.*;
import org.datavec.python.PythonTransform;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.util.*;

import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTimeout;

@DisplayName("Execution Test")
class ExecutionTest {
    @Test
    @DisplayName("Test Execution Ndarray")
    void testExecutionNdarray() {
        Schema schema = new Schema.Builder()
                .addColumnNDArray("first", new long[] { 1, 32577 })
                .addColumnNDArray("second", new long[] { 1, 32577 }).build();
        TransformProcess transformProcess = new TransformProcess.Builder(schema)
                .ndArrayMathFunctionTransform("first", MathFunction.SIN)
                .ndArrayMathFunctionTransform("second", MathFunction.COS)
                .build();
        List<List<Writable>> functions = new ArrayList<>();
        List<Writable> firstRow = new ArrayList<>();
        INDArray firstArr = Nd4j.linspace(1, 4, 4);
        INDArray secondArr = Nd4j.linspace(1, 4, 4);
        firstRow.add(new NDArrayWritable(firstArr));
        firstRow.add(new NDArrayWritable(secondArr));
        functions.add(firstRow);
        List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
        INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
        INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
        INDArray expected = Transforms.sin(firstArr);
        INDArray secondExpected = Transforms.cos(secondArr);
        assertEquals(expected, firstResult);
        assertEquals(secondExpected, secondResult);
    }
    @Test
    @DisplayName("Test Execution Simple")
    void testExecutionSimple() {
        Schema schema = new Schema.Builder().addColumnInteger("col0")
                .addColumnCategorical("col1", "state0", "state1", "state2")
                .addColumnDouble("col2").addColumnFloat("col3").build();
        TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
                .doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
        List<List<Writable>> inputData = new ArrayList<>();
        inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
        inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
        inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
        List<List<Writable>> rdd = inputData;
        List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
        Collections.sort(out, new Comparator<List<Writable>>() {
            @Override
            public int compare(List<Writable> o1, List<Writable> o2) {
                return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
            }
        });
        List<List<Writable>> expected = new ArrayList<>();
        expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
        expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
        expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f)));
        assertEquals(expected, out);
    }
    @Test
    @DisplayName("Test Filter")
    void testFilter() {
        Schema filterSchema = new Schema.Builder()
                .addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build();
        List<List<Writable>> inputData = new ArrayList<>();
        inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
        inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
        inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
        TransformProcess transformProcess = new TransformProcess.Builder(filterSchema)
                .filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build();
        List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
        assertEquals(2, execute.size());
    }
    @Test
    @DisplayName("Test Execution Sequence")
    void testExecutionSequence() {
        Schema schema = new SequenceSchema.Builder().addColumnInteger("col0")
                .addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
        TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
                .doubleMathOp("col2", MathOp.Add, 10.0).build();
        List<List<List<Writable>>> inputSequences = new ArrayList<>();
        List<List<Writable>> seq1 = new ArrayList<>();
        seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
@@ -141,21 +118,17 @@ public class ExecutionTest {
        List<List<Writable>> seq2 = new ArrayList<>();
        seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
        seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
        inputSequences.add(seq1);
        inputSequences.add(seq2);
        List<List<List<Writable>>> rdd = inputSequences;
        List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
        Collections.sort(out, new Comparator<List<List<Writable>>>() {
            @Override
            public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
                return -Integer.compare(o1.size(), o2.size());
            }
        });
        List<List<List<Writable>>> expectedSequence = new ArrayList<>();
        List<List<Writable>> seq1e = new ArrayList<>();
        seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
@@ -164,121 +137,66 @@ public class ExecutionTest {
        List<List<Writable>> seq2e = new ArrayList<>();
        seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
        seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
        expectedSequence.add(seq1e);
        expectedSequence.add(seq2e);
        assertEquals(expectedSequence, out);
    }
    @Test
    @DisplayName("Test Reduction Global")
    void testReductionGlobal() {
        List<List<Writable>> in = Arrays.asList(
                Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)),
                Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
        List<List<Writable>> inData = in;
        Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
        TransformProcess tp = new TransformProcess.Builder(s)
                .reduce(new Reducer.Builder(ReduceOp.TakeFirst)
                        .takeFirstColumns("textCol")
                        .meanColumns("doubleCol").build())
                .build();
        List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
        List<List<Writable>> out = outRdd;
        List<List<Writable>> expOut = Collections.singletonList(
                Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
        assertEquals(expOut, out);
    }
    @Test
    @DisplayName("Test Reduction By Key")
    void testReductionByKey() {
        List<List<Writable>> in = Arrays.asList(
                Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)),
                Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)),
                Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)),
                Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
        List<List<Writable>> inData = in;
        Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol")
                .addColumnDouble("doubleCol").build();
        TransformProcess tp = new TransformProcess.Builder(s)
                .reduce(new Reducer.Builder(ReduceOp.TakeFirst)
                        .keyColumns("intCol")
                        .takeFirstColumns("textCol")
                        .meanColumns("doubleCol").build())
                .build();
        List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
        List<List<Writable>> out = outRdd;
        List<List<Writable>> expOut = Arrays.asList(
                Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)),
                Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
        out = new ArrayList<>(out);
        Collections.sort(out, new Comparator<List<Writable>>() {
            @Override
            public int compare(List<Writable> o1, List<Writable> o2) {
                return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
            }
        });
        assertEquals(expOut, out);
    }
    @Test
    @Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
    @DisplayName("Test Python Execution Ndarray")
    void testPythonExecutionNdarray() {
        assertTimeout(ofMillis(60000), () -> {
            Schema schema = new Schema.Builder()
                    .addColumnNDArray("first", new long[] { 1, 32577 })
                    .addColumnNDArray("second", new long[] { 1, 32577 }).build();
            TransformProcess transformProcess = new TransformProcess.Builder(schema)
                    .transform(PythonTransform.builder()
                            .code("first = np.sin(first)\nsecond = np.cos(second)")
                            .outputSchema(schema).build())
                    .build();
            List<List<Writable>> functions = new ArrayList<>();
            List<Writable> firstRow = new ArrayList<>();
            INDArray firstArr = Nd4j.linspace(1, 4, 4);
            INDArray secondArr = Nd4j.linspace(1, 4, 4);
            firstRow.add(new NDArrayWritable(firstArr));
            firstRow.add(new NDArrayWritable(secondArr));
            functions.add(firstRow);
            List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
            INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
            INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
            INDArray expected = Transforms.sin(firstArr);
            INDArray secondExpected = Transforms.cos(secondArr);
            assertEquals(expected, firstResult);
            assertEquals(secondExpected, secondResult);
        });
    }
}
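The timeout changes above are the other recurring pattern in this commit: JUnit 4's `@Test(timeout = 60000L)` has no direct equivalent attribute on JUnit 5's `@Test`, so the budget moves into the test body via `assertTimeout`. One semantic difference is worth keeping in mind when reading the converted tests: `assertTimeout` runs the code in the calling thread and always lets it finish before checking the elapsed time, while `assertTimeoutPreemptively` runs it in a separate thread and abandons it once the budget is exceeded, which is closer to JUnit 4's behaviour. A minimal sketch of both variants:

import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;

import org.junit.jupiter.api.Test;

class TimeoutMigrationSketch {

    @Test
    void completesWithinBudget() {
        // Same-thread variant used throughout this commit: the lambda always
        // runs to completion, then the elapsed time is compared to the budget.
        assertTimeout(ofMillis(60000), () -> {
            // ... test body ...
        });
    }

    @Test
    void abandonedWhenTooSlow() {
        // Closer to JUnit 4's @Test(timeout = ...): the lambda runs in a
        // separate thread and is abandoned once the budget is exceeded.
        assertTimeoutPreemptively(ofMillis(60000), () -> {
            // ... test body ...
        });
    }
}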
@@ -17,36 +17,38 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.datavec.spark;

import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

import java.io.Serializable;

@Slf4j
@DisplayName("Base Spark Test")
public abstract class BaseSparkTest implements Serializable {

    protected static JavaSparkContext sc;

    @BeforeEach
    void before() {
        sc = getContext();
    }

    @AfterEach
    synchronized void after() {
        sc.close();
        // Wait until it's stopped, to avoid race conditions during tests
        for (int i = 0; i < 100; i++) {
            if (!sc.sc().stopped().get()) {
                try {
                    Thread.sleep(100L);
                } catch (InterruptedException e) {
                    log.error("", e);
                }
            } else {
                break;
@@ -55,29 +57,21 @@ public abstract class BaseSparkTest implements Serializable {
        if (!sc.sc().stopped().get()) {
            throw new RuntimeException("Spark context is not stopped after 10s");
        }
        sc = null;
    }

    public synchronized JavaSparkContext getContext() {
        if (sc != null)
            return sc;
        SparkConf sparkConf = new SparkConf().setMaster("local[*]").set("spark.driver.host", "localhost")
                .set("spark.driverEnv.SPARK_LOCAL_IP", "127.0.0.1")
                .set("spark.executorEnv.SPARK_LOCAL_IP", "127.0.0.1").setAppName("sparktest");
        if (useKryo()) {
            sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
        }
        sc = new JavaSparkContext(sparkConf);
        return sc;
    }

    public boolean useKryo() {
        return false;
    }
}
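The lifecycle changes in `BaseSparkTest` are the usual one-to-one renames: `@Before`/`@After` become `@BeforeEach`/`@AfterEach` (and, though not needed here, `@BeforeClass`/`@AfterClass` become `@BeforeAll`/`@AfterAll`). JUnit 5 also drops the JUnit 4 requirement that lifecycle methods be `public`, which is why the converted methods can be package-private. A minimal sketch of the mapping:

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

class LifecycleMigrationSketch {

    @BeforeAll // JUnit 4: @BeforeClass, had to be public static
    static void beforeAll() { /* one-time setup, e.g. a shared context */ }

    @BeforeEach // JUnit 4: @Before, had to be public
    void setUp() { /* per-test setup, e.g. sc = getContext() */ }

    @Test
    void someTest() { /* ... */ }

    @AfterEach // JUnit 4: @After
    void tearDown() { /* per-test cleanup, e.g. sc.close() */ }

    @AfterAll // JUnit 4: @AfterClass
    static void afterAll() { /* one-time cleanup */ }
}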
@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.datavec.spark.transform;

import org.apache.spark.api.java.JavaRDD;
@@ -35,59 +34,51 @@ import org.datavec.api.writable.Writable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.python.PythonTransform;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;

import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTimeout;
import static org.junit.jupiter.api.Assertions.assertTrue;

@DisplayName("Execution Test")
class ExecutionTest extends BaseSparkTest {

    @Test
    @DisplayName("Test Execution Simple")
    void testExecutionSimple() {
        Schema schema = new Schema.Builder().addColumnInteger("col0")
                .addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
        TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
                .doubleMathOp("col2", MathOp.Add, 10.0).build();
        List<List<Writable>> inputData = new ArrayList<>();
        inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
        inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
        inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
        JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
        List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
        Collections.sort(out, new Comparator<List<Writable>>() {
            @Override
            public int compare(List<Writable> o1, List<Writable> o2) {
                return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
            }
        });
        List<List<Writable>> expected = new ArrayList<>();
        expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
        expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
        expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
        assertEquals(expected, out);
    }
    @Test
    @DisplayName("Test Execution Sequence")
    void testExecutionSequence() {
        Schema schema = new SequenceSchema.Builder().addColumnInteger("col0")
                .addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
        TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
                .doubleMathOp("col2", MathOp.Add, 10.0).build();
        List<List<List<Writable>>> inputSequences = new ArrayList<>();
        List<List<Writable>> seq1 = new ArrayList<>();
        seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
@@ -96,22 +87,17 @@ public class ExecutionTest extends BaseSparkTest {
        List<List<Writable>> seq2 = new ArrayList<>();
        seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
        seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
        inputSequences.add(seq1);
        inputSequences.add(seq2);
        JavaRDD<List<List<Writable>>> rdd = sc.parallelize(inputSequences);
        List<List<List<Writable>>> out = new ArrayList<>(SparkTransformExecutor.executeSequenceToSequence(rdd, tp).collect());
        Collections.sort(out, new Comparator<List<List<Writable>>>() {
            @Override
            public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
                return -Integer.compare(o1.size(), o2.size());
            }
        });
        List<List<List<Writable>>> expectedSequence = new ArrayList<>();
        List<List<Writable>> seq1e = new ArrayList<>();
        seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
@@ -120,99 +106,49 @@ public class ExecutionTest extends BaseSparkTest {
        List<List<Writable>> seq2e = new ArrayList<>();
        seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
        seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
        expectedSequence.add(seq1e);
        expectedSequence.add(seq2e);
        assertEquals(expectedSequence, out);
    }
    @Test
    @DisplayName("Test Reduction Global")
    void testReductionGlobal() {
        List<List<Writable>> in = Arrays.asList(
                Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)),
                Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
        JavaRDD<List<Writable>> inData = sc.parallelize(in);
        Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
        TransformProcess tp = new TransformProcess.Builder(s)
                .reduce(new Reducer.Builder(ReduceOp.TakeFirst)
                        .takeFirstColumns("textCol")
                        .meanColumns("doubleCol").build())
                .build();
        JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
        List<List<Writable>> out = outRdd.collect();
        List<List<Writable>> expOut = Collections.singletonList(
                Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
        assertEquals(expOut, out);
    }
    @Test
    @DisplayName("Test Reduction By Key")
    void testReductionByKey() {
        List<List<Writable>> in = Arrays.asList(
                Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)),
                Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)),
                Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)),
                Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
        JavaRDD<List<Writable>> inData = sc.parallelize(in);
        Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol")
                .addColumnDouble("doubleCol").build();
        TransformProcess tp = new TransformProcess.Builder(s)
                .reduce(new Reducer.Builder(ReduceOp.TakeFirst)
                        .keyColumns("intCol")
                        .takeFirstColumns("textCol")
                        .meanColumns("doubleCol").build())
                .build();
        JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
        List<List<Writable>> out = outRdd.collect();
        List<List<Writable>> expOut = Arrays.asList(
                Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)),
                Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
        out = new ArrayList<>(out);
        Collections.sort(out, new Comparator<List<Writable>>() {
            @Override
            public int compare(List<Writable> o1, List<Writable> o2) {
                return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
            }
        });
        assertEquals(expOut, out);
    }
    @Test
    @DisplayName("Test Unique Multi Col")
    void testUniqueMultiCol() {
        Schema schema = new Schema.Builder().addColumnInteger("col0")
                .addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
        List<List<Writable>> inputData = new ArrayList<>();
        inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
        inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
@@ -223,149 +159,103 @@ public class ExecutionTest extends BaseSparkTest {
        inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
        inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
        inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
        JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
        Map<String, List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd);
        assertEquals(2, l.size());
        List<Writable> c0 = l.get("col0");
        assertEquals(3, c0.size());
        assertTrue(c0.contains(new IntWritable(0)) && c0.contains(new IntWritable(1)) && c0.contains(new IntWritable(2)));
        List<Writable> c1 = l.get("col1");
        assertEquals(3, c1.size());
        assertTrue(c1.contains(new Text("state0")) && c1.contains(new Text("state1")) && c1.contains(new Text("state2")));
    }
    @Test
    @Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
    @DisplayName("Test Python Execution")
    void testPythonExecution() {
        assertTimeout(ofMillis(60000), () -> {
            Schema schema = new Schema.Builder().addColumnInteger("col0")
                    .addColumnString("col1").addColumnDouble("col2").build();
            Schema finalSchema = new Schema.Builder().addColumnInteger("col0")
                    .addColumnInteger("col1").addColumnDouble("col2").build();
            String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
            TransformProcess tp = new TransformProcess.Builder(schema)
                    .transform(PythonTransform.builder()
                            .code("first = np.sin(first)\nsecond = np.cos(second)")
                            .outputSchema(finalSchema).build())
                    .build();
            List<List<Writable>> inputData = new ArrayList<>();
            inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
            inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
            inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
            JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
            List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
            Collections.sort(out, new Comparator<List<Writable>>() {
                @Override
                public int compare(List<Writable> o1, List<Writable> o2) {
                    return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
                }
            });
            List<List<Writable>> expected = new ArrayList<>();
            expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
            expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
            expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
            assertEquals(expected, out);
        });
    }
    @Test
    @Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
    @DisplayName("Test Python Execution With ND Arrays")
    void testPythonExecutionWithNDArrays() {
        assertTimeout(ofMillis(60000), () -> {
            long[] shape = new long[] { 3, 2 };
            Schema schema = new Schema.Builder().addColumnInteger("id")
                    .addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).build();
            Schema finalSchema = new Schema.Builder().addColumnInteger("id")
                    .addColumnNDArray("col1", shape).addColumnNDArray("col2", shape)
                    .addColumnNDArray("col3", shape).build();
            String pythonCode = "col3 = col1 + col2";
            TransformProcess tp = new TransformProcess.Builder(schema)
                    .transform(PythonTransform.builder()
                            .code("first = np.sin(first)\nsecond = np.cos(second)")
                            .outputSchema(schema).build())
                    .build();
            INDArray zeros = Nd4j.zeros(shape);
            INDArray ones = Nd4j.ones(shape);
            INDArray twos = ones.add(ones);
            List<List<Writable>> inputData = new ArrayList<>();
            inputData.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
            inputData.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones)));
            inputData.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones)));
            JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
            List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
            Collections.sort(out, new Comparator<List<Writable>>() {
                @Override
                public int compare(List<Writable> o1, List<Writable> o2) {
                    return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
                }
            });
            List<List<Writable>> expected = new ArrayList<>();
            expected.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
            expected.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones)));
            expected.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos)));
        });
    }
    @Test
    @DisplayName("Test First Digit Transform Benfords Law")
    void testFirstDigitTransformBenfordsLaw() {
        Schema s = new Schema.Builder().addColumnString("data").addColumnDouble("double")
                .addColumnString("stringNumber").build();
        List<List<Writable>> in = Arrays.asList(
                Arrays.<Writable>asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")),
                Arrays.<Writable>asList(new Text("a2"), new DoubleWritable(3.14159), new Text("7e-4")),
                Arrays.<Writable>asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")),
                Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")),
                Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.0")),
                Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.1")),
                Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.2")),
                Arrays.<Writable>asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical")));
        // Test Benfords law use case:
        TransformProcess tp = new TransformProcess.Builder(s)
                .firstDigitTransform("double", "fdDouble", FirstDigitTransform.Mode.EXCEPTION_ON_INVALID)
                .firstDigitTransform("stringNumber", "stringNumber", FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY)
                .removeAllColumnsExceptFor("stringNumber")
                .categoricalToOneHot("stringNumber")
                .reduce(new Reducer.Builder(ReduceOp.Sum).build())
                .build();
        JavaRDD<List<Writable>> rdd = sc.parallelize(in);
        List<List<Writable>> out = SparkTransformExecutor.execute(rdd, tp).collect();
        assertEquals(1, out.size());
        List<Writable> l = out.get(0);
        List<Writable> exp = Arrays.<Writable>asList(
                new IntWritable(0), // 0
                new IntWritable(0), // 1
                new IntWritable(3), // 2
                new IntWritable(0), // 3
                new IntWritable(0), // 4
                new IntWritable(0), // 5
                new IntWritable(1), // 6
                new IntWritable(2), // 7
                new IntWritable(1), // 8
                new IntWritable(0), // 9
                new IntWritable(1)); // Other
        assertEquals(exp, l);
    }
}
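For reading the expected vector in `testFirstDigitTransformBenfordsLaw`: the transform buckets each `stringNumber` value by its first character digit (0-9, plus an "Other" category for non-numeric input under `INCLUDE_OTHER_CATEGORY`), one-hot encodes the bucket, and the `Sum` reducer adds the one-hot rows, producing a first-digit histogram. A plain-Java sketch of that counting logic, independent of DataVec and purely illustrative:

import java.util.Arrays;
import java.util.List;

public class FirstDigitHistogramSketch {
    public static void main(String[] args) {
        // The stringNumber column from the test above.
        List<String> values = Arrays.asList("8e-4", "7e-4", "7e2", "6e8", "2.0", "2.1", "2.2", "non numerical");
        int[] counts = new int[11]; // digits 0-9, index 10 = "Other"
        for (String v : values) {
            char c = v.charAt(0);
            if (c >= '0' && c <= '9') {
                counts[c - '0']++;
            } else {
                counts[10]++; // INCLUDE_OTHER_CATEGORY mode
            }
        }
        // Prints [0, 0, 3, 0, 0, 0, 1, 2, 1, 0, 1] - matching the expected row.
        System.out.println(Arrays.toString(counts));
    }
}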
@@ -89,14 +89,22 @@
        <dependencies>
            <dependency>
                <groupId>org.junit.jupiter</groupId>
                <artifactId>junit-jupiter-api</artifactId>
            </dependency>
            <dependency>
                <groupId>org.junit.vintage</groupId>
                <artifactId>junit-vintage-engine</artifactId>
            </dependency>
            <dependency>
                <groupId>com.tngtech.archunit</groupId>
                <artifactId>archunit-junit5-engine</artifactId>
                <version>${archunit.version}</version>
                <scope>test</scope>
            </dependency>
            <dependency>
                <groupId>com.tngtech.archunit</groupId>
                <artifactId>archunit-junit5-api</artifactId>
                <version>${archunit.version}</version>
                <scope>test</scope>
            </dependency>
@@ -34,10 +34,18 @@
        <dependencies>
            <dependency>
                <groupId>org.junit.jupiter</groupId>
                <artifactId>junit-jupiter-api</artifactId>
                <scope>provided</scope>
            </dependency>
            <dependency>
                <groupId>org.junit.jupiter</groupId>
                <artifactId>junit-jupiter-engine</artifactId>
                <version>${junit.version}</version>
                <scope>provided</scope>
            </dependency>
            <dependency>
                <groupId>org.nd4j</groupId>
                <artifactId>nd4j-api</artifactId>
@@ -17,17 +17,13 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j;

import ch.qos.logback.classic.LoggerContext;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
@@ -37,23 +33,22 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory;
import org.slf4j.LoggerFactory;

import java.lang.management.ManagementFactory;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import static org.junit.jupiter.api.Assumptions.assumeTrue;

@Slf4j
@DisplayName("Base DL 4 J Test")
public abstract class BaseDL4JTest {

    protected long startTime;
    protected int threadCountBefore;

    private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
@@ -63,32 +58,32 @@ public abstract class BaseDL4JTest {
     * {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
     * @return Number of threads to use for C++ op execution
     */
    public int numThreads() {
        return DEFAULT_THREADS;
    }

    /**
     * Override this method to set the default timeout for methods in the test class
     */
    public long getTimeoutMilliseconds() {
        return 90_000;
    }

    /**
     * Override this to set the profiling mode for the tests defined in the child class
     */
    public OpExecutioner.ProfilingMode getProfilingMode() {
        return OpExecutioner.ProfilingMode.SCOPE_PANIC;
    }

    /**
     * Override this to set the datatype of the tests defined in the child class
     */
    public DataType getDataType() {
        return DataType.DOUBLE;
    }

    public DataType getDefaultFPDataType() {
        return getDataType();
    }
@@ -97,8 +92,8 @@ public abstract class BaseDL4JTest {
    /**
     * @return True if integration tests maven profile is enabled, false otherwise.
     */
    public static boolean isIntegrationTests() {
        if (integrationTest == null) {
            String prop = System.getenv("DL4J_INTEGRATION_TESTS");
            integrationTest = Boolean.parseBoolean(prop);
        }
@@ -110,14 +105,15 @@ public abstract class BaseDL4JTest {
     * This can be used to dynamically skip integration tests when the integration test profile is not enabled.
     * Note that the integration test profile is not enabled by default - "integration-tests" profile
     */
    public static void skipUnlessIntegrationTests() {
        assumeTrue(isIntegrationTests(), "Skipping integration test - integration profile is not enabled");
    }
@Before @BeforeEach
public void beforeTest(){ @Timeout(90000L)
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); void beforeTest(TestInfo testInfo) {
//Suppress ND4J initialization - don't need this logged for every test... log.info("{}.{}", getClass().getSimpleName(), testInfo.getTestMethod().get().getName());
// Suppress ND4J initialization - don't need this logged for every test...
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true"); System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
@ -128,83 +124,71 @@ public abstract class BaseDL4JTest {
Nd4j.getExecutioner().enableVerboseMode(false); Nd4j.getExecutioner().enableVerboseMode(false);
int numThreads = numThreads(); int numThreads = numThreads();
Preconditions.checkState(numThreads > 0, "Number of threads must be > 0"); Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) { if (numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
Nd4j.getEnvironment().setMaxMasterThreads(numThreads); Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
} }
startTime = System.currentTimeMillis(); startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
} }
@After @AfterEach
public void afterTest(){ void afterTest(TestInfo testInfo) {
//Attempt to keep workspaces isolated between tests // Attempt to keep workspaces isolated between tests
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null); Nd4j.getMemoryManager().setCurrentWorkspace(null);
if(currWS != null){ if (currWS != null) {
//Not really safe to continue testing under this situation... other tests will likely fail with obscure // Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this // errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS); System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
System.out.flush(); System.out.flush();
//Try to flush logs also: // Try to flush logs also:
try{ Thread.sleep(1000); } catch (InterruptedException e){ } try {
ILoggerFactory lf = LoggerFactory.getILoggerFactory(); Thread.sleep(1000);
if( lf instanceof LoggerContext){ } catch (InterruptedException e) {
((LoggerContext)lf).stop(); }
ILoggerFactory lf = LoggerFactory.getILoggerFactory();
if (lf instanceof LoggerContext) {
((LoggerContext) lf).stop();
}
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
} }
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
System.exit(1); System.exit(1);
} }
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
long maxPhys = Pointer.maxPhysicalBytes(); long maxPhys = Pointer.maxPhysicalBytes();
long maxBytes = Pointer.maxBytes(); long maxBytes = Pointer.maxBytes();
long currPhys = Pointer.physicalBytes(); long currPhys = Pointer.physicalBytes();
long currBytes = Pointer.totalBytes(); long currBytes = Pointer.totalBytes();
long jvmTotal = Runtime.getRuntime().totalMemory(); long jvmTotal = Runtime.getRuntime().totalMemory();
long jvmMax = Runtime.getRuntime().maxMemory(); long jvmMax = Runtime.getRuntime().maxMemory();
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
long duration = System.currentTimeMillis() - startTime; long duration = System.currentTimeMillis() - startTime;
sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) sb.append(getClass().getSimpleName()).append(".").append(testInfo.getTestMethod().get().getName()).append(": ").append(duration).append(" ms").append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")").append(", jvmTotal=").append(jvmTotal).append(", jvmMax=").append(jvmMax).append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes).append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
.append(": ").append(duration).append(" ms")
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
.append(", jvmTotal=").append(jvmTotal)
.append(", jvmMax=").append(jvmMax)
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
if(ws != null && ws.size() > 0){ if (ws != null && ws.size() > 0) {
long currSize = 0; long currSize = 0;
for(MemoryWorkspace w : ws){ for (MemoryWorkspace w : ws) {
currSize += w.getCurrentSize(); currSize += w.getCurrentSize();
} }
if(currSize > 0){ if (currSize > 0) {
sb.append(", threadWSSize=").append(currSize) sb.append(", threadWSSize=").append(currSize).append(" (").append(ws.size()).append(" WSs)");
.append(" (").append(ws.size()).append(" WSs)");
} }
} }
Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
Object o = p.get("cuda.devicesInformation"); Object o = p.get("cuda.devicesInformation");
if(o instanceof List){ if (o instanceof List) {
List<Map<String,Object>> l = (List<Map<String, Object>>) o; List<Map<String, Object>> l = (List<Map<String, Object>>) o;
if(l.size() > 0) { if (l.size() > 0) {
sb.append(" [").append(l.size()).append(" GPUs: ");
sb.append(" [").append(l.size())
.append(" GPUs: ");
for (int i = 0; i < l.size(); i++) { for (int i = 0; i < l.size(); i++) {
Map<String,Object> m = l.get(i); Map<String, Object> m = l.get(i);
if(i > 0) if (i > 0)
sb.append(","); sb.append(",");
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ").append(m.get("cuda.totalMemory")).append(" total)");
.append(m.get("cuda.totalMemory")).append(" total)");
} }
sb.append("]"); sb.append("]");
} }
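The BaseDL4JTest changes above condense to a handful of mechanical moves: the TestName @Rule becomes an injected TestInfo parameter, @Before/@After become @BeforeEach/@AfterEach, the Timeout @Rule becomes Jupiter's @Timeout annotation, and assumeTrue swaps its argument order (condition first, message last). One caveat worth noting: Jupiter's @Timeout defaults to seconds, so a millisecond value needs an explicit unit. A minimal sketch with illustrative names:

    import java.util.concurrent.TimeUnit;
    import org.junit.jupiter.api.*;
    import static org.junit.jupiter.api.Assumptions.assumeTrue;

    @DisplayName("Lifecycle Migration Sketch")
    class LifecycleSketchTest {

        @BeforeEach
        @Timeout(value = 90_000, unit = TimeUnit.MILLISECONDS) // default unit is SECONDS
        void before(TestInfo info) { // TestInfo injection replaces the TestName @Rule
            System.out.println(info.getTestMethod().get().getName());
        }

        @Test
        void guarded() {
            // JUnit 4: assumeTrue("message", condition); Jupiter reverses the order:
            assumeTrue(isIntegrationRun(), "Skipping - integration profile not enabled");
        }

        private boolean isIntegrationRun() { // illustrative stand-in
            return Boolean.parseBoolean(System.getenv("DL4J_INTEGRATION_TESTS"));
        }
    }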

View File

@ -41,8 +41,15 @@
</dependency> </dependency>
<dependency> <dependency>
<groupId>junit</groupId> <groupId>org.junit.jupiter</groupId>
<artifactId>junit</artifactId> <artifactId>junit-jupiter-api</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${junit.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
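Jupiter deliberately splits the API from the engine: junit-jupiter-api carries the annotations and assertions that tests compile against, while junit-jupiter-engine is only needed at run time so the JUnit Platform can discover and execute those tests, which is why both appear here with test scope. A trivial sketch (class name illustrative):

    import org.junit.jupiter.api.Test;
    import static org.junit.jupiter.api.Assertions.assertTrue;

    // Compiles against junit-jupiter-api; executed by junit-jupiter-engine.
    class EngineVsApiSketchTest {

        @Test
        void runsOnTheJupiterEngine() {
            assertTrue(2 + 2 == 4);
        }
    }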

View File

@ -17,70 +17,56 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.common.config; package org.deeplearning4j.common.config;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import org.deeplearning4j.common.config.dummies.TestAbstract; import org.deeplearning4j.common.config.dummies.TestAbstract;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Dl 4 J Class Loading Test")
class DL4JClassLoadingTest {
public class DL4JClassLoadingTest {
private static final String PACKAGE_PREFIX = "org.deeplearning4j.common.config.dummies."; private static final String PACKAGE_PREFIX = "org.deeplearning4j.common.config.dummies.";
@Test @Test
public void testCreateNewInstance_constructorWithoutArguments() { @DisplayName("Test Create New Instance _ constructor Without Arguments")
void testCreateNewInstance_constructorWithoutArguments() {
/* Given */ /* Given */
String className = PACKAGE_PREFIX + "TestDummy"; String className = PACKAGE_PREFIX + "TestDummy";
/* When */ /* When */
Object instance = DL4JClassLoading.createNewInstance(className); Object instance = DL4JClassLoading.createNewInstance(className);
/* Then */ /* Then */
assertNotNull(instance); assertNotNull(instance);
assertEquals(className, instance.getClass().getName()); assertEquals(className, instance.getClass().getName());
} }
@Test @Test
public void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() { @DisplayName("Test Create New Instance _ constructor With Argument _ implicit Argument Types")
void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() {
/* Given */ /* Given */
String className = PACKAGE_PREFIX + "TestColor"; String className = PACKAGE_PREFIX + "TestColor";
/* When */ /* When */
TestAbstract instance = DL4JClassLoading.createNewInstance(className, TestAbstract.class, "white"); TestAbstract instance = DL4JClassLoading.createNewInstance(className, TestAbstract.class, "white");
/* Then */ /* Then */
assertNotNull(instance); assertNotNull(instance);
assertEquals(className, instance.getClass().getName()); assertEquals(className, instance.getClass().getName());
} }
@Test @Test
public void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() { @DisplayName("Test Create New Instance _ constructor With Argument _ explicit Argument Types")
void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() {
/* Given */ /* Given */
String colorClassName = PACKAGE_PREFIX + "TestColor"; String colorClassName = PACKAGE_PREFIX + "TestColor";
String rectangleClassName = PACKAGE_PREFIX + "TestRectangle"; String rectangleClassName = PACKAGE_PREFIX + "TestRectangle";
/* When */ /* When */
TestAbstract color = DL4JClassLoading.createNewInstance( TestAbstract color = DL4JClassLoading.createNewInstance(colorClassName, Object.class, new Class<?>[] { int.class, int.class, int.class }, 45, 175, 200);
colorClassName, TestAbstract rectangle = DL4JClassLoading.createNewInstance(rectangleClassName, Object.class, new Class<?>[] { int.class, int.class, TestAbstract.class }, 10, 15, color);
Object.class,
new Class<?>[]{ int.class, int.class, int.class },
45, 175, 200);
TestAbstract rectangle = DL4JClassLoading.createNewInstance(
rectangleClassName,
Object.class,
new Class<?>[]{ int.class, int.class, TestAbstract.class },
10, 15, color);
/* Then */ /* Then */
assertNotNull(color); assertNotNull(color);
assertEquals(colorClassName, color.getClass().getName()); assertEquals(colorClassName, color.getClass().getName());
assertNotNull(rectangle); assertNotNull(rectangle);
assertEquals(rectangleClassName, rectangle.getClass().getName()); assertEquals(rectangleClassName, rectangle.getClass().getName());
} }
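As the DL4JClassLoadingTest diff above shows, Jupiter no longer requires test classes or test methods to be public; package-private visibility is enough, and @DisplayName supplies the human-readable name. The other recurring change is that assertion messages move from the first argument to the last. A sketch with illustrative names:

    import org.junit.jupiter.api.DisplayName;
    import org.junit.jupiter.api.Test;
    import static org.junit.jupiter.api.Assertions.assertEquals;

    @DisplayName("Visibility Sketch")
    class VisibilitySketchTest { // package-private is fine in JUnit 5

        @Test
        @DisplayName("message argument moves last in Jupiter assertions")
        void messageLast() {
            // JUnit 4: assertEquals("mismatch", expected, actual)
            assertEquals(4, 2 + 2, "mismatch");
        }
    }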

View File

@ -49,11 +49,6 @@
</dependencyManagement> </dependencyManagement>
<dependencies> <dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-tsne</artifactId>
<version>${project.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-datasets</artifactId> <artifactId>deeplearning4j-datasets</artifactId>
@ -99,8 +94,12 @@
<version>${commons-compress.version}</version> <version>${commons-compress.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>junit</groupId> <groupId>org.junit.jupiter</groupId>
<artifactId>junit</artifactId> <artifactId>junit-jupiter-api</artifactId>
</dependency>
<dependency>
<groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
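The junit-vintage-engine added above is what keeps any tests still written against the JUnit 4 API running on the JUnit Platform during the migration; no source change is needed for them. For example (hypothetical leftover test):

    import org.junit.Test; // JUnit 4 API, discovered via the vintage engine

    public class LegacyVintageSketchTest {

        @Test
        public void stillRunsOnTheJUnitPlatform() {
            // executes side by side with Jupiter tests in the same module
        }
    }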

View File

@ -17,15 +17,16 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets; package org.deeplearning4j.datasets;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.base.MnistFetcher; import org.deeplearning4j.datasets.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.*; import org.junit.jupiter.api.AfterAll;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.rules.Timeout; import org.junit.rules.Timeout;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
@ -33,69 +34,67 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
import java.io.File; import java.io.File;
import java.nio.file.Path;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Mnist Fetcher Test")
import static org.junit.Assert.assertFalse; class MnistFetcherTest extends BaseDL4JTest {
import static org.junit.Assert.assertTrue;
public class MnistFetcherTest extends BaseDL4JTest {
@ClassRule
public static TemporaryFolder testDir = new TemporaryFolder();
@Rule
public Timeout timeout = Timeout.seconds(300);
@BeforeClass @BeforeAll
public static void setup() throws Exception { static void setup(@TempDir Path tempPath) throws Exception {
DL4JResources.setBaseDirectory(testDir.newFolder()); DL4JResources.setBaseDirectory(tempPath.toFile());
} }
@AfterClass @AfterAll
public static void after() { static void after() {
DL4JResources.resetBaseDirectoryLocation(); DL4JResources.resetBaseDirectoryLocation();
} }
@Test @Test
public void testMnist() throws Exception { @DisplayName("Test Mnist")
void testMnist() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1); DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
int count = 0; int count = 0;
while(iter.hasNext()){ while (iter.hasNext()) {
DataSet ds = iter.next(); DataSet ds = iter.next();
INDArray arr = ds.getFeatures().sum(1); INDArray arr = ds.getFeatures().sum(1);
int countMatch = Nd4j.getExecutioner().execAndReturn(new MatchCondition(arr, Conditions.equals(0))).z().getInt(0); int countMatch = Nd4j.getExecutioner().execAndReturn(new MatchCondition(arr, Conditions.equals(0))).z().getInt(0);
assertEquals(0, countMatch); assertEquals(0, countMatch);
count++; count++;
} }
assertEquals(60000/32, count); assertEquals(60000 / 32, count);
count = 0; count = 0;
iter = new MnistDataSetIterator(32, false, 12345); iter = new MnistDataSetIterator(32, false, 12345);
while(iter.hasNext()){ while (iter.hasNext()) {
DataSet ds = iter.next(); DataSet ds = iter.next();
INDArray arr = ds.getFeatures().sum(1); INDArray arr = ds.getFeatures().sum(1);
int countMatch = Nd4j.getExecutioner().execAndReturn(new MatchCondition(arr, Conditions.equals(0))).z().getInt(0); int countMatch = Nd4j.getExecutioner().execAndReturn(new MatchCondition(arr, Conditions.equals(0))).z().getInt(0);
assertEquals(0, countMatch); assertEquals(0, countMatch);
count++; count++;
} }
assertEquals((int)Math.ceil(10000/32.0), count); assertEquals((int) Math.ceil(10000 / 32.0), count);
} }
@Test @Test
public void testMnistDataFetcher() throws Exception { @DisplayName("Test Mnist Data Fetcher")
void testMnistDataFetcher() throws Exception {
MnistFetcher mnistFetcher = new MnistFetcher(); MnistFetcher mnistFetcher = new MnistFetcher();
File mnistDir = mnistFetcher.downloadAndUntar(); File mnistDir = mnistFetcher.downloadAndUntar();
assertTrue(mnistDir.isDirectory()); assertTrue(mnistDir.isDirectory());
} }
// @Test // @Test
public void testMnistSubset() throws Exception { public void testMnistSubset() throws Exception {
final int numExamples = 100; final int numExamples = 100;
MnistDataSetIterator iter1 = new MnistDataSetIterator(10, numExamples, false, true, true, 123); MnistDataSetIterator iter1 = new MnistDataSetIterator(10, numExamples, false, true, true, 123);
int examples1 = 0; int examples1 = 0;
int itCount1 = 0; int itCount1 = 0;
@ -105,7 +104,6 @@ public class MnistFetcherTest extends BaseDL4JTest {
} }
assertEquals(10, itCount1); assertEquals(10, itCount1);
assertEquals(100, examples1); assertEquals(100, examples1);
MnistDataSetIterator iter2 = new MnistDataSetIterator(10, numExamples, false, true, true, 123); MnistDataSetIterator iter2 = new MnistDataSetIterator(10, numExamples, false, true, true, 123);
int examples2 = 0; int examples2 = 0;
int itCount2 = 0; int itCount2 = 0;
@ -116,7 +114,6 @@ public class MnistFetcherTest extends BaseDL4JTest {
assertFalse(iter2.hasNext()); assertFalse(iter2.hasNext());
assertEquals(10, itCount2); assertEquals(10, itCount2);
assertEquals(100, examples2); assertEquals(100, examples2);
MnistDataSetIterator iter3 = new MnistDataSetIterator(19, numExamples, false, true, true, 123); MnistDataSetIterator iter3 = new MnistDataSetIterator(19, numExamples, false, true, true, 123);
int examples3 = 0; int examples3 = 0;
int itCount3 = 0; int itCount3 = 0;
@ -125,51 +122,45 @@ public class MnistFetcherTest extends BaseDL4JTest {
examples3 += iter3.next().numExamples(); examples3 += iter3.next().numExamples();
} }
assertEquals(100, examples3); assertEquals(100, examples3);
assertEquals((int)Math.ceil(100/19.0), itCount3); assertEquals((int) Math.ceil(100 / 19.0), itCount3);
MnistDataSetIterator iter4 = new MnistDataSetIterator(32, true, 12345); MnistDataSetIterator iter4 = new MnistDataSetIterator(32, true, 12345);
int count4 = 0; int count4 = 0;
while(iter4.hasNext()){ while (iter4.hasNext()) {
count4 += iter4.next().numExamples(); count4 += iter4.next().numExamples();
} }
assertEquals(60000, count4); assertEquals(60000, count4);
} }
@Test @Test
public void testSubsetRepeatability() throws Exception { @DisplayName("Test Subset Repeatability")
void testSubsetRepeatability() throws Exception {
DataSetIterator it = new MnistDataSetIterator(1, 1, false, false, true, 0); DataSetIterator it = new MnistDataSetIterator(1, 1, false, false, true, 0);
DataSet d1 = it.next(); DataSet d1 = it.next();
for( int i=0; i<10; i++ ) { for (int i = 0; i < 10; i++) {
it.reset(); it.reset();
DataSet d2 = it.next(); DataSet d2 = it.next();
assertEquals(d1.get(0).getFeatures(), d2.get(0).getFeatures()); assertEquals(d1.get(0).getFeatures(), d2.get(0).getFeatures());
} }
// Check larger number:
//Check larger number:
it = new MnistDataSetIterator(8, 32, false, false, true, 12345); it = new MnistDataSetIterator(8, 32, false, false, true, 12345);
Set<String> featureLabelSet = new HashSet<>(); Set<String> featureLabelSet = new HashSet<>();
while(it.hasNext()){ while (it.hasNext()) {
DataSet ds = it.next(); DataSet ds = it.next();
INDArray f = ds.getFeatures(); INDArray f = ds.getFeatures();
INDArray l = ds.getLabels(); INDArray l = ds.getLabels();
for (int i = 0; i < f.size(0); i++) {
for( int i=0; i<f.size(0); i++ ){
featureLabelSet.add(f.getRow(i).toString() + "\t" + l.getRow(i).toString()); featureLabelSet.add(f.getRow(i).toString() + "\t" + l.getRow(i).toString());
} }
} }
assertEquals(32, featureLabelSet.size()); assertEquals(32, featureLabelSet.size());
for (int i = 0; i < 3; i++) {
for( int i=0; i<3; i++ ){
it.reset(); it.reset();
Set<String> flSet2 = new HashSet<>(); Set<String> flSet2 = new HashSet<>();
while(it.hasNext()){ while (it.hasNext()) {
DataSet ds = it.next(); DataSet ds = it.next();
INDArray f = ds.getFeatures(); INDArray f = ds.getFeatures();
INDArray l = ds.getLabels(); INDArray l = ds.getLabels();
for (int j = 0; j < f.size(0); j++) {
for( int j=0; j<f.size(0); j++ ){
flSet2.add(f.getRow(j).toString() + "\t" + l.getRow(j).toString()); flSet2.add(f.getRow(j).toString() + "\t" + l.getRow(j).toString());
} }
} }
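The @ClassRule TemporaryFolder in MnistFetcherTest above becomes a @TempDir Path injected into the static @BeforeAll method, so the whole class shares one managed directory that Jupiter cleans up afterwards. A minimal sketch of the pattern (names illustrative):

    import java.nio.file.Path;
    import org.junit.jupiter.api.BeforeAll;
    import org.junit.jupiter.api.Test;
    import org.junit.jupiter.api.io.TempDir;

    class TempDirSketchTest {

        static Path sharedDir;

        @BeforeAll
        static void setup(@TempDir Path tempPath) { // injected by Jupiter
            sharedDir = tempPath; // replaces testDir.newFolder() from JUnit 4
        }

        @Test
        void usesSharedDir() {
            // sharedDir.toFile() hands the same directory to legacy File APIs
        }
    }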

View File

@ -17,10 +17,8 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets.datavec; package org.deeplearning4j.datasets.datavec;
import org.junit.rules.Timeout; import org.junit.rules.Timeout;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -47,8 +45,8 @@ import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
@ -58,42 +56,40 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import java.io.*; import java.io.*;
import java.net.URI; import java.net.URI;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.junit.Assert.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point; import static org.nd4j.linalg.indexing.NDArrayIndex.point;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { @DisplayName("Record Reader Multi Data Set Iterator Test")
class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@Rule @TempDir
public TemporaryFolder temporaryFolder = new TemporaryFolder(); public Path temporaryFolder;
@Rule @Rule
public Timeout timeout = Timeout.seconds(300); public Timeout timeout = Timeout.seconds(300);
@Test @Test
public void testsBasic() throws Exception { @DisplayName("Tests Basic")
//Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator void testsBasic() throws Exception {
// Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator
RecordReader rr = new CSVRecordReader(0, ','); RecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3); RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
RecordReader rr2 = new CSVRecordReader(0, ','); RecordReader rr2 = new CSVRecordReader(0, ',');
rr2.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();
MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2)
.addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();
while (rrdsi.hasNext()) { while (rrdsi.hasNext()) {
DataSet ds = rrdsi.next(); DataSet ds = rrdsi.next();
INDArray fds = ds.getFeatures(); INDArray fds = ds.getFeatures();
INDArray lds = ds.getLabels(); INDArray lds = ds.getLabels();
MultiDataSet mds = rrmdsi.next(); MultiDataSet mds = rrmdsi.next();
assertEquals(1, mds.getFeatures().length); assertEquals(1, mds.getFeatures().length);
assertEquals(1, mds.getLabels().length); assertEquals(1, mds.getLabels().length);
@ -101,49 +97,36 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertNull(mds.getLabelsMaskArrays()); assertNull(mds.getLabelsMaskArrays());
INDArray fmds = mds.getFeatures(0); INDArray fmds = mds.getFeatures(0);
INDArray lmds = mds.getLabels(0); INDArray lmds = mds.getLabels(0);
assertNotNull(fmds); assertNotNull(fmds);
assertNotNull(lmds); assertNotNull(lmds);
assertEquals(fds, fmds); assertEquals(fds, fmds);
assertEquals(lds, lmds); assertEquals(lds, lmds);
} }
assertFalse(rrmdsi.hasNext()); assertFalse(rrmdsi.hasNext());
// need to manually extract
//need to manually extract File rootDir = temporaryFolder.toFile();
File rootDir = temporaryFolder.newFolder();
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
} }
// Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
//Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
SequenceRecordReaderDataSetIterator iter =
new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader2).addSequenceReader("out", labelReader2).addInput("in").addOutputOneHot("out", 0, 4).build();
MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("in", featureReader2).addSequenceReader("out", labelReader2).addInput("in")
.addOutputOneHot("out", 0, 4).build();
while (iter.hasNext()) { while (iter.hasNext()) {
DataSet ds = iter.next(); DataSet ds = iter.next();
INDArray fds = ds.getFeatures(); INDArray fds = ds.getFeatures();
INDArray lds = ds.getLabels(); INDArray lds = ds.getLabels();
MultiDataSet mds = srrmdsi.next(); MultiDataSet mds = srrmdsi.next();
assertEquals(1, mds.getFeatures().length); assertEquals(1, mds.getFeatures().length);
assertEquals(1, mds.getLabels().length); assertEquals(1, mds.getLabels().length);
@ -151,10 +134,8 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertNull(mds.getLabelsMaskArrays()); assertNull(mds.getLabelsMaskArrays());
INDArray fmds = mds.getFeatures(0); INDArray fmds = mds.getFeatures(0);
INDArray lmds = mds.getLabels(0); INDArray lmds = mds.getLabels(0);
assertNotNull(fmds); assertNotNull(fmds);
assertNotNull(lmds); assertNotNull(lmds);
assertEquals(fds, fmds); assertEquals(fds, fmds);
assertEquals(lds, lmds); assertEquals(lds, lmds);
} }
@ -162,16 +143,13 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testsBasicMeta() throws Exception { @DisplayName("Tests Basic Meta")
//As per testBasic - but also loading metadata void testsBasicMeta() throws Exception {
// As per testBasic - but also loading metadata
RecordReader rr2 = new CSVRecordReader(0, ','); RecordReader rr2 = new CSVRecordReader(0, ',');
rr2.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
.addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();
rrmdsi.setCollectMetaData(true); rrmdsi.setCollectMetaData(true);
int count = 0; int count = 0;
while (rrmdsi.hasNext()) { while (rrmdsi.hasNext()) {
MultiDataSet mds = rrmdsi.next(); MultiDataSet mds = rrmdsi.next();
@ -183,27 +161,22 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testSplittingCSV() throws Exception { @DisplayName("Test Splitting CSV")
//Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays void testSplittingCSV() throws Exception {
//Inputs: columns 0 and 1-2 // Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
//Outputs: columns 3, and 4->OneHot // Inputs: columns 0 and 1-2
//need to manually extract // Outputs: columns 3, and 4->OneHot
// need to manually extract
RecordReader rr = new CSVRecordReader(0, ','); RecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3); RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
RecordReader rr2 = new CSVRecordReader(0, ','); RecordReader rr2 = new CSVRecordReader(0, ',');
rr2.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2)
.addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3)
.addOutputOneHot("reader", 4, 3).build();
while (rrdsi.hasNext()) { while (rrdsi.hasNext()) {
DataSet ds = rrdsi.next(); DataSet ds = rrdsi.next();
INDArray fds = ds.getFeatures(); INDArray fds = ds.getFeatures();
INDArray lds = ds.getLabels(); INDArray lds = ds.getLabels();
MultiDataSet mds = rrmdsi.next(); MultiDataSet mds = rrmdsi.next();
assertEquals(2, mds.getFeatures().length); assertEquals(2, mds.getFeatures().length);
assertEquals(2, mds.getLabels().length); assertEquals(2, mds.getLabels().length);
@ -211,20 +184,15 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertNull(mds.getLabelsMaskArrays()); assertNull(mds.getLabelsMaskArrays());
INDArray[] fmds = mds.getFeatures(); INDArray[] fmds = mds.getFeatures();
INDArray[] lmds = mds.getLabels(); INDArray[] lmds = mds.getLabels();
assertNotNull(fmds); assertNotNull(fmds);
assertNotNull(lmds); assertNotNull(lmds);
for (int i = 0; i < fmds.length; i++) for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
assertNotNull(fmds[i]); for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
for (int i = 0; i < lmds.length; i++) // Get the subsets of the original iris data
assertNotNull(lmds[i]); INDArray expIn1 = fds.get(all(), interval(0, 0, true));
//Get the subsets of the original iris data
INDArray expIn1 = fds.get(all(), interval(0,0,true));
INDArray expIn2 = fds.get(all(), interval(1, 2, true)); INDArray expIn2 = fds.get(all(), interval(1, 2, true));
INDArray expOut1 = fds.get(all(), interval(3,3,true)); INDArray expOut1 = fds.get(all(), interval(3, 3, true));
INDArray expOut2 = lds; INDArray expOut2 = lds;
assertEquals(expIn1, fmds[0]); assertEquals(expIn1, fmds[0]);
assertEquals(expIn2, fmds[1]); assertEquals(expIn2, fmds[1]);
assertEquals(expOut1, lmds[0]); assertEquals(expOut1, lmds[0]);
@ -234,18 +202,15 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testSplittingCSVMeta() throws Exception { @DisplayName("Test Splitting CSV Meta")
//Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays void testSplittingCSVMeta() throws Exception {
//Inputs: columns 0 and 1-2 // Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
//Outputs: columns 3, and 4->OneHot // Inputs: columns 0 and 1-2
// Outputs: columns 3, and 4->OneHot
RecordReader rr2 = new CSVRecordReader(0, ','); RecordReader rr2 = new CSVRecordReader(0, ',');
rr2.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
.addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2)
.addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
rrmdsi.setCollectMetaData(true); rrmdsi.setCollectMetaData(true);
int count = 0; int count = 0;
while (rrmdsi.hasNext()) { while (rrmdsi.hasNext()) {
MultiDataSet mds = rrmdsi.next(); MultiDataSet mds = rrmdsi.next();
@ -257,42 +222,33 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testSplittingCSVSequence() throws Exception { @DisplayName("Test Splitting CSV Sequence")
//Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt" void testSplittingCSVSequence() throws Exception {
// Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt"
// as standard one-hot output // as standard one-hot output
//need to manually extract // need to manually extract
File rootDir = temporaryFolder.newFolder(); File rootDir = temporaryFolder.toFile();
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
} }
String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
SequenceRecordReaderDataSetIterator iter =
new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2).addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();
MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2)
.addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();
while (iter.hasNext()) { while (iter.hasNext()) {
DataSet ds = iter.next(); DataSet ds = iter.next();
INDArray fds = ds.getFeatures(); INDArray fds = ds.getFeatures();
INDArray lds = ds.getLabels(); INDArray lds = ds.getLabels();
MultiDataSet mds = srrmdsi.next(); MultiDataSet mds = srrmdsi.next();
assertEquals(2, mds.getFeatures().length); assertEquals(2, mds.getFeatures().length);
assertEquals(1, mds.getLabels().length); assertEquals(1, mds.getLabels().length);
@ -300,17 +256,12 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertNull(mds.getLabelsMaskArrays()); assertNull(mds.getLabelsMaskArrays());
INDArray[] fmds = mds.getFeatures(); INDArray[] fmds = mds.getFeatures();
INDArray[] lmds = mds.getLabels(); INDArray[] lmds = mds.getLabels();
assertNotNull(fmds); assertNotNull(fmds);
assertNotNull(lmds); assertNotNull(lmds);
for (int i = 0; i < fmds.length; i++) for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
assertNotNull(fmds[i]); for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
for (int i = 0; i < lmds.length; i++)
assertNotNull(lmds[i]);
INDArray expIn1 = fds.get(all(), NDArrayIndex.interval(0, 1, true), all()); INDArray expIn1 = fds.get(all(), NDArrayIndex.interval(0, 1, true), all());
INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(2, 2, true), all()); INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(2, 2, true), all());
assertEquals(expIn1, fmds[0]); assertEquals(expIn1, fmds[0]);
assertEquals(expIn2, fmds[1]); assertEquals(expIn2, fmds[1]);
assertEquals(lds, lmds[0]); assertEquals(lds, lmds[0]);
@ -319,36 +270,29 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testSplittingCSVSequenceMeta() throws Exception { @DisplayName("Test Splitting CSV Sequence Meta")
//Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt" void testSplittingCSVSequenceMeta() throws Exception {
// Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt"
// as standard one-hot output // as standard one-hot output
//need to manually extract // need to manually extract
File rootDir = temporaryFolder.newFolder(); File rootDir = temporaryFolder.toFile();
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
} }
String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2).addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();
RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2)
.addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();
srrmdsi.setCollectMetaData(true); srrmdsi.setCollectMetaData(true);
int count = 0; int count = 0;
while (srrmdsi.hasNext()) { while (srrmdsi.hasNext()) {
MultiDataSet mds = srrmdsi.next(); MultiDataSet mds = srrmdsi.next();
@ -359,34 +303,27 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertEquals(3, count); assertEquals(3, count);
} }
@Test @Test
public void testInputValidation() { @DisplayName("Test Input Validation")
void testInputValidation() {
//Test: no readers // Test: no readers
try { try {
MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addInput("something") MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addInput("something").addOutput("something").build();
.addOutput("something").build();
fail("Should have thrown exception"); fail("Should have thrown exception");
} catch (Exception e) { } catch (Exception e) {
} }
// Test: reference to reader that doesn't exist
//Test: reference to reader that doesn't exist
try { try {
RecordReader rr = new CSVRecordReader(0, ','); RecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).addInput("thisDoesntExist", 0, 3).addOutputOneHot("iris", 4, 3).build();
MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr)
.addInput("thisDoesntExist", 0, 3).addOutputOneHot("iris", 4, 3).build();
fail("Should have thrown exception"); fail("Should have thrown exception");
} catch (Exception e) { } catch (Exception e) {
} }
// Test: no inputs or outputs
//Test: no inputs or outputs
try { try {
RecordReader rr = new CSVRecordReader(0, ','); RecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).build(); MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).build();
fail("Should have thrown exception"); fail("Should have thrown exception");
} catch (Exception e) { } catch (Exception e) {
@ -394,81 +331,55 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testVariableLengthTS() throws Exception { @DisplayName("Test Variable Length TS")
//need to manually extract void testVariableLengthTS() throws Exception {
File rootDir = temporaryFolder.newFolder(); // need to manually extract
File rootDir = temporaryFolder.toFile();
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
} }
String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");
// Set up SequenceRecordReaderDataSetIterators for comparison
//Set up SequenceRecordReaderDataSetIterators for comparison
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);
        SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2, labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
        // Set up
        SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
        featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
        featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();
        RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();
        while (iterAlignStart.hasNext()) {
            DataSet dsStart = iterAlignStart.next();
            DataSet dsEnd = iterAlignEnd.next();
            MultiDataSet mdsStart = rrmdsiStart.next();
            MultiDataSet mdsEnd = rrmdsiEnd.next();
            assertEquals(1, mdsStart.getFeatures().length);
            assertEquals(1, mdsStart.getLabels().length);
            // assertEquals(1, mdsStart.getFeaturesMaskArrays().length); //Features data is always longer -> don't need mask arrays for it
            assertEquals(1, mdsStart.getLabelsMaskArrays().length);
            assertEquals(1, mdsEnd.getFeatures().length);
            assertEquals(1, mdsEnd.getLabels().length);
            // assertEquals(1, mdsEnd.getFeaturesMaskArrays().length);
            assertEquals(1, mdsEnd.getLabelsMaskArrays().length);
            assertEquals(dsStart.getFeatures(), mdsStart.getFeatures(0));
            assertEquals(dsStart.getLabels(), mdsStart.getLabels(0));
            assertEquals(dsStart.getLabelsMaskArray(), mdsStart.getLabelsMaskArray(0));
            assertEquals(dsEnd.getFeatures(), mdsEnd.getFeatures(0));
            assertEquals(dsEnd.getLabels(), mdsEnd.getLabels(0));
            assertEquals(dsEnd.getLabelsMaskArray(), mdsEnd.getLabelsMaskArray(0));
@@ -477,57 +388,40 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
        assertFalse(rrmdsiEnd.hasNext());
    }

    @Test
    @DisplayName("Test Variable Length TS Meta")
    void testVariableLengthTSMeta() throws Exception {
        // need to manually extract
        File rootDir = temporaryFolder.toFile();
        for (int i = 0; i < 3; i++) {
            new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
            new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
            new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
        }
        // Set up SequenceRecordReaderDataSetIterators for comparison
        String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
        String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");
        // Set up
        SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
        featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
        featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();
        RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();
        rrmdsiStart.setCollectMetaData(true);
        rrmdsiEnd.setCollectMetaData(true);
        int count = 0;
        while (rrmdsiStart.hasNext()) {
            MultiDataSet mdsStart = rrmdsiStart.next();
            MultiDataSet mdsEnd = rrmdsiEnd.next();
            MultiDataSet mdsStartFromMeta = rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class));
            MultiDataSet mdsEndFromMeta = rrmdsiEnd.loadFromMetaData(mdsEnd.getExampleMetaData(RecordMetaData.class));
            assertEquals(mdsStart, mdsStartFromMeta);
            assertEquals(mdsEnd, mdsEndFromMeta);
            count++;
        }
        assertFalse(rrmdsiStart.hasNext());
@@ -536,53 +430,37 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
    }
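A note on the temp-folder change above: `temporaryFolder.newFolder()` from JUnit 4's `TemporaryFolder` rule becomes `temporaryFolder.toFile()`, which assumes the field is now a JUnit 5 `@TempDir`-injected `Path`. A minimal sketch of that pattern, with illustrative names (the field declaration itself sits outside this hunk):

import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirSketch {

    // Jupiter injects a fresh temporary directory; replaces @Rule TemporaryFolder
    @TempDir
    Path temporaryFolder;

    @Test
    void usesTempDir() {
        // Rough equivalent of TemporaryFolder.newFolder(): the injected directory itself
        File rootDir = temporaryFolder.toFile();
        // ... create test files under rootDir ...
    }
}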
    @Test
    @DisplayName("Test Images RRDMSI")
    void testImagesRRDMSI() throws Exception {
        File parentDir = temporaryFolder.toFile();
        parentDir.deleteOnExit();
        String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
        String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
        File f1 = new File(str1);
        File f2 = new File(str2);
        f1.mkdirs();
        f2.mkdirs();
        TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
        TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
        int outputNum = 2;
        Random r = new Random(12345);
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
        ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
        rr1.initialize(new FileSplit(parentDir));
        rr1s.initialize(new FileSplit(parentDir));
        MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
        // Now, do the same thing with ImageRecordReader, and check we get the same results:
        ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
        ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
        rr1_b.initialize(new FileSplit(parentDir));
        rr1s_b.initialize(new FileSplit(parentDir));
        DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
        DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);
        for (int i = 0; i < 2; i++) {
            MultiDataSet mds = trainDataIterator.next();
            DataSet d1 = dsi1.next();
            DataSet d2 = dsi2.next();
            assertEquals(d1.getFeatures(), mds.getFeatures(0));
            assertEquals(d2.getFeatures(), mds.getFeatures(1));
            assertEquals(d1.getLabels(), mds.getLabels(0));
@@ -590,261 +468,180 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Test Images RRDMSI _ Batched")
    void testImagesRRDMSI_Batched() throws Exception {
        File parentDir = temporaryFolder.toFile();
        parentDir.deleteOnExit();
        String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
        String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
        File f1 = new File(str1);
        File f2 = new File(str2);
        f1.mkdirs();
        f2.mkdirs();
        TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
        TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
        int outputNum = 2;
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
        ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
        URI[] uris = new FileSplit(parentDir).locations();
        rr1.initialize(new CollectionInputSplit(uris));
        rr1s.initialize(new CollectionInputSplit(uris));
        MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
        // Now, do the same thing with ImageRecordReader, and check we get the same results:
        ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
        ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
        rr1_b.initialize(new FileSplit(parentDir));
        rr1s_b.initialize(new FileSplit(parentDir));
        DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 2, 1, 2);
        DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 2, 1, 2);
        MultiDataSet mds = trainDataIterator.next();
        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();
        assertEquals(d1.getFeatures(), mds.getFeatures(0));
        assertEquals(d2.getFeatures(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
        // Check label assignment:
        File currentFile = rr1_b.getCurrentFile();
        INDArray expLabels;
        if (currentFile.getAbsolutePath().contains("Zico")) {
            expLabels = Nd4j.create(new double[][] { { 0, 1 }, { 1, 0 } });
        } else {
            expLabels = Nd4j.create(new double[][] { { 1, 0 }, { 0, 1 } });
        }
        assertEquals(expLabels, d1.getLabels());
        assertEquals(expLabels, d2.getLabels());
    }

    @Test
    @DisplayName("Test Time Series Random Offset")
    void testTimeSeriesRandomOffset() {
        // 2 in, 2 out, 3 total sequences of length [1,3,5]
        List<List<Writable>> seq1 = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(2.0)));
        List<List<Writable>> seq2 = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(10.0), new DoubleWritable(11.0)), Arrays.<Writable>asList(new DoubleWritable(20.0), new DoubleWritable(21.0)), Arrays.<Writable>asList(new DoubleWritable(30.0), new DoubleWritable(31.0)));
        List<List<Writable>> seq3 = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(100.0), new DoubleWritable(101.0)), Arrays.<Writable>asList(new DoubleWritable(200.0), new DoubleWritable(201.0)), Arrays.<Writable>asList(new DoubleWritable(300.0), new DoubleWritable(301.0)), Arrays.<Writable>asList(new DoubleWritable(400.0), new DoubleWritable(401.0)), Arrays.<Writable>asList(new DoubleWritable(500.0), new DoubleWritable(501.0)));
        Collection<List<List<Writable>>> seqs = Arrays.asList(seq1, seq2, seq3);
        SequenceRecordReader rr = new CollectionSequenceRecordReader(seqs);
        RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(3).addSequenceReader("rr", rr).addInput("rr", 0, 0).addOutput("rr", 1, 1).timeSeriesRandomOffset(true, 1234L).build();
        // Provides seed for each minibatch
        Random r = new Random(1234);
        long seed = r.nextLong();
        // Use same RNG seed in new RNG for each minibatch
        Random r2 = new Random(seed);
        // 0 to 4 inclusive
        int expOffsetSeq1 = r2.nextInt(5 - 1 + 1);
        int expOffsetSeq2 = r2.nextInt(5 - 3 + 1);
        // Longest TS, always 0
        int expOffsetSeq3 = 0;
        // With current seed: 3, 1, 0
        // System.out.println(expOffsetSeq1 + "\t" + expOffsetSeq2 + "\t" + expOffsetSeq3);
        MultiDataSet mds = rrmdsi.next();
        INDArray expMask = Nd4j.create(new double[][] { { 0, 0, 0, 1, 0 }, { 0, 1, 1, 1, 0 }, { 1, 1, 1, 1, 1 } });
        assertEquals(expMask, mds.getFeaturesMaskArray(0));
        assertEquals(expMask, mds.getLabelsMaskArray(0));
        INDArray f = mds.getFeatures(0);
        INDArray l = mds.getLabels(0);
        INDArray expF1 = Nd4j.create(new double[] { 1.0 }, new int[] { 1, 1 });
        INDArray expL1 = Nd4j.create(new double[] { 2.0 }, new int[] { 1, 1 });
        INDArray expF2 = Nd4j.create(new double[] { 10, 20, 30 }, new int[] { 1, 3 });
        INDArray expL2 = Nd4j.create(new double[] { 11, 21, 31 }, new int[] { 1, 3 });
        INDArray expF3 = Nd4j.create(new double[] { 100, 200, 300, 400, 500 }, new int[] { 1, 5 });
        INDArray expL3 = Nd4j.create(new double[] { 101, 201, 301, 401, 501 }, new int[] { 1, 5 });
        assertEquals(expF1, f.get(point(0), all(), NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1)));
        assertEquals(expL1, l.get(point(0), all(), NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1)));
        assertEquals(expF2, f.get(point(1), all(), NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3)));
        assertEquals(expL2, l.get(point(1), all(), NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3)));
        assertEquals(expF3, f.get(point(2), all(), NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5)));
        assertEquals(expL3, l.get(point(2), all(), NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5)));
    }

    @Test
    @DisplayName("Test Seq RRDSI Masking")
    void testSeqRRDSIMasking() {
        // This also tests RecordReaderMultiDataSetIterator, by virtue of
        List<List<List<Writable>>> features = new ArrayList<>();
        List<List<List<Writable>>> labels = new ArrayList<>();
        features.add(Arrays.asList(l(new DoubleWritable(1)), l(new DoubleWritable(2)), l(new DoubleWritable(3))));
        features.add(Arrays.asList(l(new DoubleWritable(4)), l(new DoubleWritable(5))));
        labels.add(Arrays.asList(l(new IntWritable(0))));
        labels.add(Arrays.asList(l(new IntWritable(1))));
        CollectionSequenceRecordReader fR = new CollectionSequenceRecordReader(features);
        CollectionSequenceRecordReader lR = new CollectionSequenceRecordReader(labels);
        SequenceRecordReaderDataSetIterator seqRRDSI = new SequenceRecordReaderDataSetIterator(fR, lR, 2, 2, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
        DataSet ds = seqRRDSI.next();
        INDArray fMask = Nd4j.create(new double[][] { { 1, 1, 1 }, { 1, 1, 0 } });
        INDArray lMask = Nd4j.create(new double[][] { { 0, 0, 1 }, { 0, 1, 0 } });
        assertEquals(fMask, ds.getFeaturesMaskArray());
        assertEquals(lMask, ds.getLabelsMaskArray());
        INDArray f = Nd4j.create(new double[][] { { 1, 2, 3 }, { 4, 5, 0 } });
        INDArray l = Nd4j.create(2, 2, 3);
        l.putScalar(0, 0, 2, 1.0);
        l.putScalar(1, 1, 1, 1.0);
        assertEquals(f, ds.getFeatures().get(all(), point(0), all()));
        assertEquals(l, ds.getLabels());
    }

    private static List<Writable> l(Writable... in) {
        return Arrays.asList(in);
    }

    @Test
    @DisplayName("Test Exclude String Col CSV")
    void testExcludeStringColCSV() throws Exception {
        File csvFile = temporaryFolder.toFile();
        StringBuilder sb = new StringBuilder();
        for (int i = 1; i <= 10; i++) {
            if (i > 1) {
                sb.append("\n");
            }
            sb.append("skip_").append(i).append(",").append(i).append(",").append(i + 0.5);
        }
        FileUtils.writeStringToFile(csvFile, sb.toString());
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(csvFile));
        RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("rr", rr).addInput("rr", 1, 1).addOutput("rr", 2, 2).build();
        INDArray expFeatures = Nd4j.linspace(1, 10, 10).reshape(1, 10).transpose();
        INDArray expLabels = Nd4j.linspace(1, 10, 10).addi(0.5).reshape(1, 10).transpose();
        MultiDataSet mds = rrmdsi.next();
        assertFalse(rrmdsi.hasNext());
        assertEquals(expFeatures, mds.getFeatures(0).castTo(expFeatures.dataType()));
        assertEquals(expLabels, mds.getLabels(0).castTo(expLabels.dataType()));
    }
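One caveat with the conversion in testExcludeStringColCSV: JUnit 4's newFile() created and returned a fresh file, while toFile() on an injected @TempDir Path returns the temporary directory itself, so writing to it directly may fail. A hedged sketch of the usual workaround, creating a named file inside the directory (the file name here is invented for illustration):

import java.io.File;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.io.TempDir;

class TempFileSketch {

    @TempDir
    Path temporaryFolder;

    File newCsvFile(String content) throws Exception {
        // Mirror TemporaryFolder.newFile(): a concrete file under the temp directory
        File csvFile = new File(temporaryFolder.toFile(), "test.csv");
        FileUtils.writeStringToFile(csvFile, content, StandardCharsets.UTF_8);
        return csvFile;
    }
}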
    private static final int nX = 32;

    private static final int nY = 32;

    private static final int nZ = 28;

    @Test
    @DisplayName("Test RRMDSI 5 D")
    void testRRMDSI5D() {
        int batchSize = 5;
        CustomRecordReader recordReader = new CustomRecordReader();
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, /* Index of label in records */
                2 /* number of different labels */);
        int count = 0;
        while (dataIter.hasNext()) {
            DataSet ds = dataIter.next();
            int offset = 5 * count;
            for (int i = 0; i < 5; i++) {
                INDArray act = ds.getFeatures().get(interval(i, i, true), all(), all(), all(), all());
                INDArray exp = Nd4j.valueArrayOf(new int[] { 1, 1, nZ, nX, nY }, i + offset);
                assertEquals(exp, act);
            }
            count++;
        }
        assertEquals(2, count);
    }

    @DisplayName("Custom Record Reader")
    static class CustomRecordReader extends BaseRecordReader {

        int n = 0;

        CustomRecordReader() {
        }

        @Override
        public boolean batchesSupported() {
@@ -858,8 +655,8 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
        @Override
        public List<Writable> next() {
            INDArray nd = Nd4j.create(new float[nZ * nY * nX], new int[] { 1, 1, nZ, nY, nX }, 'c').assign(n);
            final List<Writable> res = RecordConverter.toRecord(nd);
            res.add(new IntWritable(0));
            n++;
            return res;
@@ -867,14 +664,16 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
        @Override
        public boolean hasNext() {
            return n < 10;
        }

        final static ArrayList<String> labels = new ArrayList<>(2);

        static {
            labels.add("lbl0");
            labels.add("lbl1");
        }

        @Override
        public List<String> getLabels() {
            return labels;
@@ -928,6 +727,7 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
        public void initialize(InputSplit split) {
            n = 0;
        }

        @Override
        public void initialize(Configuration conf, InputSplit split) {
            n = 0;

View File

@@ -17,38 +17,39 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.datasets.fetchers;

import org.deeplearning4j.BaseDL4JTest;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.rules.Timeout;
import java.io.File;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

/**
 * @author saudet
 */
@DisplayName("Svhn Data Fetcher Test")
class SvhnDataFetcherTest extends BaseDL4JTest {

    @Override
    public long getTimeoutMilliseconds() {
        // Shouldn't take this long but slow download or drive access on CI machines may need extra time.
        return 480_000_000L;
    }

    @Test
    @DisplayName("Test Svhn Data Fetcher")
    void testSvhnDataFetcher() throws Exception {
        // Ignore unless integration tests - CI can get caught up on slow disk access
        assumeTrue(isIntegrationTests());
        SvhnDataFetcher fetch = new SvhnDataFetcher();
        File path = fetch.getDataSetPath(DataSetType.TRAIN);
        File path2 = fetch.getDataSetPath(DataSetType.TEST);
        File path3 = fetch.getDataSetPath(DataSetType.VALIDATION);
        assertTrue(path.isDirectory());
        assertTrue(path2.isDirectory());
        assertTrue(path3.isDirectory());

View File

@@ -17,52 +17,50 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.datasets.iterator;

import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.common.primitives.Pair;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@DisplayName("Abstract Data Set Iterator Test")
class AbstractDataSetIteratorTest extends BaseDL4JTest {

    @Test
    @DisplayName("Next")
    void next() throws Exception {
        int numFeatures = 128;
        int batchSize = 10;
        int numRows = 1000;
        AtomicInteger cnt = new AtomicInteger(0);
        FloatsDataSetIterator iterator = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize);
        assertTrue(iterator.hasNext());
        while (iterator.hasNext()) {
            DataSet dataSet = iterator.next();
            INDArray features = dataSet.getFeatures();
            assertEquals(batchSize, features.rows());
            assertEquals(numFeatures, features.columns());
            cnt.incrementAndGet();
        }
        assertEquals(numRows / batchSize, cnt.get());
    }

    protected static Iterable<Pair<float[], float[]>> floatIterable(final int totalRows, final int numColumns) {
        return new Iterable<Pair<float[], float[]>>() {

            @Override
            public Iterator<Pair<float[], float[]>> iterator() {
                return new Iterator<Pair<float[], float[]>>() {

                    private AtomicInteger cnt = new AtomicInteger(0);

                    @Override
@@ -72,8 +70,8 @@ public class AbstractDataSetIteratorTest extends BaseDL4JTest {
                    @Override
                    public Pair<float[], float[]> next() {
                        float[] features = new float[numColumns];
                        float[] labels = new float[numColumns];
                        for (int i = 0; i < numColumns; i++) {
                            features[i] = (float) i;
                            labels[i] = RandomUtils.nextFloat(0, 5);

View File

@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.datasets.iterator;

import lombok.extern.slf4j.Slf4j;
@@ -25,117 +24,118 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback;
import org.deeplearning4j.datasets.iterator.tools.VariableTimeseriesGenerator;
import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
@Slf4j
@DisplayName("Async Data Set Iterator Test")
class AsyncDataSetIteratorTest extends BaseDL4JTest {

    private ExistingDataSetIterator backIterator;

    private static final int TEST_SIZE = 100;

    private static final int ITERATIONS = 10;

    // time spent in consumer thread, milliseconds
    private static final long EXECUTION_TIME = 5;

    private static final long EXECUTION_SMALL = 1;

    @BeforeEach
    void setUp() throws Exception {
        List<DataSet> iterable = new ArrayList<>();
        for (int i = 0; i < TEST_SIZE; i++) {
            iterable.add(new DataSet(Nd4j.create(new float[100]), Nd4j.create(new float[10])));
        }
        backIterator = new ExistingDataSetIterator(iterable);
    }

    @Test
    @DisplayName("Has Next 1")
    void hasNext1() throws Exception {
        for (int iter = 0; iter < ITERATIONS; iter++) {
            for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) {
                AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
                int cnt = 0;
                while (iterator.hasNext()) {
                    DataSet ds = iterator.next();
                    assertNotEquals(null, ds);
                    cnt++;
                }
                assertEquals(TEST_SIZE, cnt, "Failed on iteration: " + iter + ", prefetchSize: " + prefetchSize);
                iterator.shutdown();
            }
        }
    }
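The assertion in hasNext1 also shows the Jupiter argument order: the failure message moves from the first parameter (JUnit 4) to the last. A self-contained sketch of the before/after shape, with illustrative names:

import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;

class MessageOrderSketch {

    @Test
    void messageIsNowLast() {
        int expected = 100;
        int actual = 100;
        // JUnit 4 equivalent: assertEquals("count mismatch", expected, actual);
        assertEquals(expected, actual, "count mismatch");
    }
}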
    @Test
    @DisplayName("Has Next With Reset And Load")
    void hasNextWithResetAndLoad() throws Exception {
        int[] prefetchSizes;
        if (isIntegrationTests()) {
            prefetchSizes = new int[] { 2, 3, 4, 5, 6, 7, 8 };
        } else {
            prefetchSizes = new int[] { 2, 3, 8 };
        }
        for (int iter = 0; iter < ITERATIONS; iter++) {
            for (int prefetchSize : prefetchSizes) {
                AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
                TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL);
                int cnt = 0;
                while (iterator.hasNext()) {
                    DataSet ds = iterator.next();
                    consumer.consumeOnce(ds, false);
                    cnt++;
                    if (cnt == TEST_SIZE / 2)
                        iterator.reset();
                }
                assertEquals(TEST_SIZE + (TEST_SIZE / 2), cnt);
                iterator.shutdown();
            }
        }
    }

    @Test
    @DisplayName("Test With Load")
    void testWithLoad() {
        for (int iter = 0; iter < ITERATIONS; iter++) {
            AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, 8);
            TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, EXECUTION_TIME);
            consumer.consumeWhileHasNext(true);
            assertEquals(TEST_SIZE, consumer.getCount());
            iterator.shutdown();
        }
    }

    @Test
    @DisplayName("Test With Exception")
    void testWithException() {
        assertThrows(ArrayIndexOutOfBoundsException.class, () -> {
            ExistingDataSetIterator crashingIterator = new ExistingDataSetIterator(new IterableWithException(100));
            AsyncDataSetIterator iterator = new AsyncDataSetIterator(crashingIterator, 8);
            TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, EXECUTION_SMALL);
            consumer.consumeWhileHasNext(true);
            iterator.shutdown();
        });
    }
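testWithException above shows the standard replacement for JUnit 4's @Test(expected = ...): the failing code moves into an assertThrows lambda. A minimal sketch of the pattern:

import static org.junit.jupiter.api.Assertions.assertThrows;
import org.junit.jupiter.api.Test;

class ExpectedExceptionSketch {

    @Test
    void failingCodeMovesIntoALambda() {
        // JUnit 4: @Test(expected = ArrayIndexOutOfBoundsException.class) on the method
        // JUnit 5: assertThrows also returns the exception for further inspection
        ArrayIndexOutOfBoundsException e = assertThrows(ArrayIndexOutOfBoundsException.class, () -> {
            int[] data = new int[1];
            int ignored = data[2]; // out-of-bounds access throws here
        });
    }
}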
@DisplayName("Iterable With Exception")
private class IterableWithException implements Iterable<DataSet> { private class IterableWithException implements Iterable<DataSet> {
private final AtomicLong counter = new AtomicLong(0); private final AtomicLong counter = new AtomicLong(0);
private final int crashIteration; private final int crashIteration;
public IterableWithException(int iteration) { public IterableWithException(int iteration) {
@@ -146,6 +146,7 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
        public Iterator<DataSet> iterator() {
            counter.set(0);
            return new Iterator<DataSet>() {

                @Override
                public boolean hasNext() {
                    return true;
@@ -155,82 +156,59 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
                public DataSet next() {
                    if (counter.incrementAndGet() >= crashIteration)
                        throw new ArrayIndexOutOfBoundsException("Thrown as expected");
                    return new DataSet(Nd4j.create(10), Nd4j.create(10));
                }

                @Override
                public void remove() {
                }
            };
        }
    }

    @Test
    @DisplayName("Test Variable Time Series 1")
    void testVariableTimeSeries1() throws Exception {
        int numBatches = isIntegrationTests() ? 1000 : 100;
        int batchSize = isIntegrationTests() ? 32 : 8;
        int timeStepsMin = 10;
        int timeStepsMax = isIntegrationTests() ? 500 : 100;
        int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
        AsyncDataSetIterator adsi = new AsyncDataSetIterator(new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true);
        for (int e = 0; e < 10; e++) {
            int cnt = 0;
            while (adsi.hasNext()) {
                DataSet ds = adsi.next();
                // log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
                assertEquals((double) cnt, ds.getFeatures().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
                assertEquals((double) cnt + 0.25, ds.getLabels().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
                assertEquals((double) cnt + 0.5, ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
                assertEquals((double) cnt + 0.75, ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
                cnt++;
            }
            adsi.reset();
            // log.info("Epoch {} finished...", e);
        }
    }

    @Test
    @DisplayName("Test Variable Time Series 2")
    void testVariableTimeSeries2() throws Exception {
        AsyncDataSetIterator adsi = new AsyncDataSetIterator(new VariableTimeseriesGenerator(1192, 100, 32, 128, 100, 100, 100), 2, true, new InterleavedDataSetCallback(2 * 2));
        for (int e = 0; e < 5; e++) {
            int cnt = 0;
            while (adsi.hasNext()) {
                DataSet ds = adsi.next();
                ds.detach();
                // log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
                assertEquals((double) cnt, ds.getFeatures().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
                assertEquals((double) cnt + 0.25, ds.getLabels().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
                assertEquals((double) cnt + 0.5, ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
                assertEquals((double) cnt + 0.75, ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
                cnt++;
            }
            adsi.reset();
            // log.info("Epoch {} finished...", e);
        }
    }
}
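Taken together, the conversions in this file follow one mechanical recipe: org.junit.Test and @Before become org.junit.jupiter.api.Test and @BeforeEach, assertions come from org.junit.jupiter.api.Assertions, and classes and test methods drop the public modifier that Jupiter no longer requires. A minimal skeleton of the target shape, with illustrative names:

import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Migrated Test Skeleton")
class MigratedTestSkeleton {

    private int value;

    @BeforeEach
    void setUp() {
        // was @Before in JUnit 4; still runs before every test method
        value = 42;
    }

    @Test
    @DisplayName("Value Is Initialized")
    void valueIsInitialized() {
        assertEquals(42, value, "setUp should have run first");
    }
}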

View File

@@ -17,98 +17,19 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator; import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertEquals; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
/**
* THIS TEST SHOULD BE ALWAYS RUN WITH DOUBLE PRECISION, WITHOUT ANY EXCLUSIONS
*
* @throws Exception
*/
@Test
public void testVariableTimeSeries1() throws Exception {
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
iterator.reset();
iterator.hasNext();
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
for (int e = 0; e < 10; e++) {
int cnt = 0;
while (amdsi.hasNext()) {
MultiDataSet mds = amdsi.next();
//log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt,
mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.25,
mds.getLabels()[0].meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.5,
mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.75,
mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10);
cnt++;
}
amdsi.reset();
log.info("Epoch {} finished...", e);
}
}
@Test
public void testVariableTimeSeries2() throws Exception {
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
for (int e = 0; e < 10; e++) {
iterator.reset();
iterator.hasNext();
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
int cnt = 0;
while (amdsi.hasNext()) {
MultiDataSet mds = amdsi.next();
//log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt,
mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.25,
mds.getLabels()[0].meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.5,
mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.75,
mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10);
cnt++;
}
}
}
/* /*
@Test @Test
public void testResetBug() throws Exception { public void testResetBug() throws Exception {
@ -134,6 +55,120 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
trainData.reset(); trainData.reset();
SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
testFeatures.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/test/features" + "/%d.csv", 0, 149));
RecordReader testLabels = new CSVRecordReader();
testLabels.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/test/labels" + "/%d.csv", 0, 149));
MultiDataSetIterator testData = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize)
.addSequenceReader("features", testFeatures)
.addReader("labels", testLabels)
.addInput("features")
.addOutputOneHot("labels", 0, numLabelClasses)
.build();
System.out.println("-------------- HASH 1----------------");
testData.reset();
while(testData.hasNext()){
System.out.println(Arrays.hashCode(testData.next().getFeatures(0).data().asFloat()));
}
System.out.println("-------------- HASH 2 ----------------");
testData.reset();
testData.hasNext(); //***** Remove this (or move to after async creation), and we get expected results *****
val adsi = new AsyncMultiDataSetIterator(testData, 4, true); //OR remove this (keeping hasNext) and we get expected results
//val adsi = new AsyncShieldMultiDataSetIterator(testData);
while(adsi.hasNext()){
System.out.println(Arrays.hashCode(adsi.next().getFeatures(0).data().asFloat()));
}
}
*/
@DisplayName("Async Multi Data Set Iterator Test")
class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
/**
* THIS TEST SHOULD BE ALWAYS RUN WITH DOUBLE PRECISION, WITHOUT ANY EXCLUSIONS
*
* @throws Exception
*/
@Test
@DisplayName("Test Variable Time Series 1")
void testVariableTimeSeries1() throws Exception {
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
iterator.reset();
iterator.hasNext();
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
for (int e = 0; e < 10; e++) {
int cnt = 0;
while (amdsi.hasNext()) {
MultiDataSet mds = amdsi.next();
// log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals((double) cnt, mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.25, mds.getLabels()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.5, mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.75, mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
cnt++;
}
amdsi.reset();
log.info("Epoch {} finished...", e);
}
}
@Test
@DisplayName("Test Variable Time Series 2")
void testVariableTimeSeries2() throws Exception {
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
for (int e = 0; e < 10; e++) {
iterator.reset();
iterator.hasNext();
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
int cnt = 0;
while (amdsi.hasNext()) {
MultiDataSet mds = amdsi.next();
// log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals((double) cnt, mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.25, mds.getLabels()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.5, mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.75, mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
cnt++;
}
}
}
/*
@Test
public void testResetBug() throws Exception {
// /home/raver119/develop/dl4j-examples/src/main/resources/uci/train/features
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
trainFeatures.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/train/features" + "/%d.csv", 0, 449));
RecordReader trainLabels = new CSVRecordReader();
trainLabels.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/train/labels" + "/%d.csv", 0, 449));
int miniBatchSize = 10;
int numLabelClasses = 6;
MultiDataSetIterator trainData = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize)
.addSequenceReader("features", trainFeatures)
.addReader("labels", trainLabels)
.addInput("features")
.addOutputOneHot("labels", 0, numLabelClasses)
.build();
//Normalize the training data
MultiDataNormalization normalizer = new MultiNormalizerStandardize();
normalizer.fit(trainData); //Collect training data statistics
trainData.reset();
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(); SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
testFeatures.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/test/features" + "/%d.csv", 0, 149)); testFeatures.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/test/features" + "/%d.csv", 0, 149));
RecordReader testLabels = new CSVRecordReader(); RecordReader testLabels = new CSVRecordReader();
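The core mechanical change in the file above is the assertion signature: JUnit 4's Assert.assertEquals(message, expected, actual, delta) becomes Jupiter's Assertions.assertEquals(expected, actual, delta, message), with the failure message moving to the last position. A minimal sketch of the mapping (the check method and its arguments are hypothetical, for illustration only):

import static org.junit.jupiter.api.Assertions.assertEquals;

class AssertionOrderSketch {

    // JUnit 4 form was: assertEquals("Failed on iteration " + i, expected, actual, 1e-10);
    void check(double expected, double actual, int i) {
        // JUnit 5: the message (or a Supplier<String>) is the trailing argument
        assertEquals(expected, actual, 1e-10, "Failed on iteration " + i);
    }
}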


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@ -41,8 +40,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener; import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -50,26 +49,28 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*; @DisplayName("Data Set Iterator Test")
class DataSetIteratorTest extends BaseDL4JTest {
public class DataSetIteratorTest extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 360000; //Should run quickly; increased to large timeout due to occasional slow CI downloads // Should run quickly; increased to large timeout due to occasional slow CI downloads
return 360000;
} }
@Test @Test
public void testBatchSizeOfOneIris() throws Exception { @DisplayName("Test Batch Size Of One Iris")
//Test for (a) iterators returning correct number of examples, and void testBatchSizeOfOneIris() throws Exception {
//(b) Labels are a proper one-hot vector (i.e., sum is 1.0) // Test for (a) iterators returning correct number of examples, and
// (b) Labels are a proper one-hot vector (i.e., sum is 1.0)
//Iris: // Iris:
DataSetIterator iris = new IrisDataSetIterator(1, 5); DataSetIterator iris = new IrisDataSetIterator(1, 5);
int irisC = 0; int irisC = 0;
while (iris.hasNext()) { while (iris.hasNext()) {
@ -81,9 +82,9 @@ public class DataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testBatchSizeOfOneMnist() throws Exception { @DisplayName("Test Batch Size Of One Mnist")
void testBatchSizeOfOneMnist() throws Exception {
//MNIST: // MNIST:
DataSetIterator mnist = new MnistDataSetIterator(1, 5); DataSetIterator mnist = new MnistDataSetIterator(1, 5);
int mnistC = 0; int mnistC = 0;
while (mnist.hasNext()) { while (mnist.hasNext()) {
@ -95,25 +96,21 @@ public class DataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testMnist() throws Exception { @DisplayName("Test Mnist")
void testMnist() throws Exception {
ClassPathResource cpr = new ClassPathResource("mnist_first_200.txt"); ClassPathResource cpr = new ClassPathResource("mnist_first_200.txt");
CSVRecordReader rr = new CSVRecordReader(0, ','); CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(cpr.getTempFileFromArchive())); rr.initialize(new FileSplit(cpr.getTempFileFromArchive()));
RecordReaderDataSetIterator dsi = new RecordReaderDataSetIterator(rr, 10, 0, 10); RecordReaderDataSetIterator dsi = new RecordReaderDataSetIterator(rr, 10, 0, 10);
MnistDataSetIterator iter = new MnistDataSetIterator(10, 200, false, true, false, 0); MnistDataSetIterator iter = new MnistDataSetIterator(10, 200, false, true, false, 0);
while (dsi.hasNext()) { while (dsi.hasNext()) {
DataSet dsExp = dsi.next(); DataSet dsExp = dsi.next();
DataSet dsAct = iter.next(); DataSet dsAct = iter.next();
INDArray fExp = dsExp.getFeatures(); INDArray fExp = dsExp.getFeatures();
fExp.divi(255); fExp.divi(255);
INDArray lExp = dsExp.getLabels(); INDArray lExp = dsExp.getLabels();
INDArray fAct = dsAct.getFeatures(); INDArray fAct = dsAct.getFeatures();
INDArray lAct = dsAct.getLabels(); INDArray lAct = dsAct.getLabels();
assertEquals(fExp, fAct.castTo(fExp.dataType())); assertEquals(fExp, fAct.castTo(fExp.dataType()));
assertEquals(lExp, lAct.castTo(lExp.dataType())); assertEquals(lExp, lAct.castTo(lExp.dataType()));
} }
@ -121,12 +118,13 @@ public class DataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testLfwIterator() throws Exception { @DisplayName("Test Lfw Iterator")
void testLfwIterator() throws Exception {
int numExamples = 1; int numExamples = 1;
int row = 28; int row = 28;
int col = 28; int col = 28;
int channels = 1; int channels = 1;
LFWDataSetIterator iter = new LFWDataSetIterator(numExamples, new int[] {row, col, channels}, true); LFWDataSetIterator iter = new LFWDataSetIterator(numExamples, new int[] { row, col, channels }, true);
assertTrue(iter.hasNext()); assertTrue(iter.hasNext());
DataSet data = iter.next(); DataSet data = iter.next();
assertEquals(numExamples, data.getLabels().size(0)); assertEquals(numExamples, data.getLabels().size(0));
@ -134,7 +132,8 @@ public class DataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testTinyImageNetIterator() throws Exception { @DisplayName("Test Tiny Image Net Iterator")
void testTinyImageNetIterator() throws Exception {
int numClasses = 200; int numClasses = 200;
int row = 64; int row = 64;
int col = 64; int col = 64;
@ -143,24 +142,26 @@ public class DataSetIteratorTest extends BaseDL4JTest {
assertTrue(iter.hasNext()); assertTrue(iter.hasNext());
DataSet data = iter.next(); DataSet data = iter.next();
assertEquals(numClasses, data.getLabels().size(1)); assertEquals(numClasses, data.getLabels().size(1));
assertArrayEquals(new long[]{1, channels, row, col}, data.getFeatures().shape()); assertArrayEquals(new long[] { 1, channels, row, col }, data.getFeatures().shape());
} }
@Test @Test
public void testTinyImageNetIterator2() throws Exception { @DisplayName("Test Tiny Image Net Iterator 2")
void testTinyImageNetIterator2() throws Exception {
int numClasses = 200; int numClasses = 200;
int row = 224; int row = 224;
int col = 224; int col = 224;
int channels = 3; int channels = 3;
TinyImageNetDataSetIterator iter = new TinyImageNetDataSetIterator(1, new int[]{row, col}, DataSetType.TEST); TinyImageNetDataSetIterator iter = new TinyImageNetDataSetIterator(1, new int[] { row, col }, DataSetType.TEST);
assertTrue(iter.hasNext()); assertTrue(iter.hasNext());
DataSet data = iter.next(); DataSet data = iter.next();
assertEquals(numClasses, data.getLabels().size(1)); assertEquals(numClasses, data.getLabels().size(1));
assertArrayEquals(new long[]{1, channels, row, col}, data.getFeatures().shape()); assertArrayEquals(new long[] { 1, channels, row, col }, data.getFeatures().shape());
} }
@Test @Test
public void testLfwModel() throws Exception { @DisplayName("Test Lfw Model")
void testLfwModel() throws Exception {
final int numRows = 28; final int numRows = 28;
final int numColumns = 28; final int numColumns = 28;
int numChannels = 3; int numChannels = 3;
@ -169,39 +170,22 @@ public class DataSetIteratorTest extends BaseDL4JTest {
int batchSize = 2; int batchSize = 2;
int seed = 123; int seed = 123;
int listenerFreq = 1; int listenerFreq = 1;
LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples, new int[] { numRows, numColumns, numChannels }, outputNum, false, true, 1.0, new Random(seed));
LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples, MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list().layer(0, new ConvolutionLayer.Builder(5, 5).nIn(numChannels).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).stride(1, 1).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels));
new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed));
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new ConvolutionLayer.Builder(5, 5).nIn(numChannels).nOut(6)
.weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2})
.stride(1, 1).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels))
;
MultiLayerNetwork model = new MultiLayerNetwork(builder.build()); MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
model.init(); model.init();
model.setListeners(new ScoreIterationListener(listenerFreq)); model.setListeners(new ScoreIterationListener(listenerFreq));
model.fit(lfw.next()); model.fit(lfw.next());
DataSet dataTest = lfw.next(); DataSet dataTest = lfw.next();
INDArray output = model.output(dataTest.getFeatures()); INDArray output = model.output(dataTest.getFeatures());
Evaluation eval = new Evaluation(outputNum); Evaluation eval = new Evaluation(outputNum);
eval.eval(dataTest.getLabels(), output); eval.eval(dataTest.getLabels(), output);
// System.out.println(eval.stats()); // System.out.println(eval.stats());
} }
@Test @Test
public void testCifar10Iterator() throws Exception { @DisplayName("Test Cifar 10 Iterator")
void testCifar10Iterator() throws Exception {
int numExamples = 1; int numExamples = 1;
int row = 32; int row = 32;
int col = 32; int col = 32;
@ -213,12 +197,13 @@ public class DataSetIteratorTest extends BaseDL4JTest {
assertEquals(channels * row * col, data.getFeatures().ravel().length()); assertEquals(channels * row * col, data.getFeatures().ravel().length());
} }
// Ignored for now - CIFAR iterator needs work - https://github.com/eclipse/deeplearning4j/issues/4673
@Test @Ignore //Ignored for now - CIFAR iterator needs work - https://github.com/eclipse/deeplearning4j/issues/4673 @Test
public void testCifarModel() throws Exception { @Disabled
@DisplayName("Test Cifar Model")
void testCifarModel() throws Exception {
// Streaming // Streaming
runCifar(false); runCifar(false);
// Preprocess // Preprocess
runCifar(true); runCifar(true);
} }
@ -231,32 +216,14 @@ public class DataSetIteratorTest extends BaseDL4JTest {
int batchSize = 5; int batchSize = 5;
int seed = 123; int seed = 123;
int listenerFreq = 1; int listenerFreq = 1;
Cifar10DataSetIterator cifar = new Cifar10DataSetIterator(batchSize); Cifar10DataSetIterator cifar = new Cifar10DataSetIterator(batchSize);
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list().layer(0, new ConvolutionLayer.Builder(5, 5).nIn(channels).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(height, width, channels));
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new ConvolutionLayer.Builder(5, 5).nIn(channels).nOut(6).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2})
.build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(height, width, channels));
MultiLayerNetwork model = new MultiLayerNetwork(builder.build()); MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
model.init(); model.init();
// model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
//model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq); CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
model.setListeners(listener); model.setListeners(listener);
model.fit(cifar); model.fit(cifar);
cifar = new Cifar10DataSetIterator(batchSize); cifar = new Cifar10DataSetIterator(batchSize);
Evaluation eval = new Evaluation(cifar.getLabels()); Evaluation eval = new Evaluation(cifar.getLabels());
while (cifar.hasNext()) { while (cifar.hasNext()) {
@ -264,37 +231,31 @@ public class DataSetIteratorTest extends BaseDL4JTest {
INDArray output = model.output(testDS.getFeatures()); INDArray output = model.output(testDS.getFeatures());
eval.eval(testDS.getLabels(), output); eval.eval(testDS.getLabels(), output);
} }
// System.out.println(eval.stats(true)); // System.out.println(eval.stats(true));
listener.exportScores(System.out); listener.exportScores(System.out);
} }
@Test @Test
public void testIteratorDataSetIteratorCombining() { @DisplayName("Test Iterator Data Set Iterator Combining")
//Test combining of a bunch of small (size 1) data sets together void testIteratorDataSetIteratorCombining() {
// Test combining of a bunch of small (size 1) data sets together
int batchSize = 3; int batchSize = 3;
int numBatches = 4; int numBatches = 4;
int featureSize = 5; int featureSize = 5;
int labelSize = 6; int labelSize = 6;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<DataSet> orig = new ArrayList<>(); List<DataSet> orig = new ArrayList<>();
for (int i = 0; i < batchSize * numBatches; i++) { for (int i = 0; i < batchSize * numBatches; i++) {
INDArray features = Nd4j.rand(1, featureSize); INDArray features = Nd4j.rand(1, featureSize);
INDArray labels = Nd4j.rand(1, labelSize); INDArray labels = Nd4j.rand(1, labelSize);
orig.add(new DataSet(features, labels)); orig.add(new DataSet(features, labels));
} }
DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize); DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize);
int count = 0; int count = 0;
while (iter.hasNext()) { while (iter.hasNext()) {
DataSet ds = iter.next(); DataSet ds = iter.next();
assertArrayEquals(new long[] {batchSize, featureSize}, ds.getFeatures().shape()); assertArrayEquals(new long[] { batchSize, featureSize }, ds.getFeatures().shape());
assertArrayEquals(new long[] {batchSize, labelSize}, ds.getLabels().shape()); assertArrayEquals(new long[] { batchSize, labelSize }, ds.getLabels().shape());
List<INDArray> fList = new ArrayList<>(); List<INDArray> fList = new ArrayList<>();
List<INDArray> lList = new ArrayList<>(); List<INDArray> lList = new ArrayList<>();
for (int i = 0; i < batchSize; i++) { for (int i = 0; i < batchSize; i++) {
@ -302,66 +263,44 @@ public class DataSetIteratorTest extends BaseDL4JTest {
fList.add(dsOrig.getFeatures()); fList.add(dsOrig.getFeatures());
lList.add(dsOrig.getLabels()); lList.add(dsOrig.getLabels());
} }
INDArray fExp = Nd4j.vstack(fList); INDArray fExp = Nd4j.vstack(fList);
INDArray lExp = Nd4j.vstack(lList); INDArray lExp = Nd4j.vstack(lList);
assertEquals(fExp, ds.getFeatures()); assertEquals(fExp, ds.getFeatures());
assertEquals(lExp, ds.getLabels()); assertEquals(lExp, ds.getLabels());
count++; count++;
} }
assertEquals(count, numBatches); assertEquals(count, numBatches);
} }
@Test @Test
public void testIteratorDataSetIteratorSplitting() { @DisplayName("Test Iterator Data Set Iterator Splitting")
//Test splitting large data sets into smaller ones void testIteratorDataSetIteratorSplitting() {
// Test splitting large data sets into smaller ones
int origBatchSize = 4; int origBatchSize = 4;
int origNumDSs = 3; int origNumDSs = 3;
int batchSize = 3; int batchSize = 3;
int numBatches = 4; int numBatches = 4;
int featureSize = 5; int featureSize = 5;
int labelSize = 6; int labelSize = 6;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<DataSet> orig = new ArrayList<>(); List<DataSet> orig = new ArrayList<>();
for (int i = 0; i < origNumDSs; i++) { for (int i = 0; i < origNumDSs; i++) {
INDArray features = Nd4j.rand(origBatchSize, featureSize); INDArray features = Nd4j.rand(origBatchSize, featureSize);
INDArray labels = Nd4j.rand(origBatchSize, labelSize); INDArray labels = Nd4j.rand(origBatchSize, labelSize);
orig.add(new DataSet(features, labels)); orig.add(new DataSet(features, labels));
} }
List<DataSet> expected = new ArrayList<>(); List<DataSet> expected = new ArrayList<>();
expected.add(new DataSet(orig.get(0).getFeatures().getRows(0, 1, 2), expected.add(new DataSet(orig.get(0).getFeatures().getRows(0, 1, 2), orig.get(0).getLabels().getRows(0, 1, 2)));
orig.get(0).getLabels().getRows(0, 1, 2))); expected.add(new DataSet(Nd4j.vstack(orig.get(0).getFeatures().getRows(3), orig.get(1).getFeatures().getRows(0, 1)), Nd4j.vstack(orig.get(0).getLabels().getRows(3), orig.get(1).getLabels().getRows(0, 1))));
expected.add(new DataSet( expected.add(new DataSet(Nd4j.vstack(orig.get(1).getFeatures().getRows(2, 3), orig.get(2).getFeatures().getRows(0)), Nd4j.vstack(orig.get(1).getLabels().getRows(2, 3), orig.get(2).getLabels().getRows(0))));
Nd4j.vstack(orig.get(0).getFeatures().getRows(3), expected.add(new DataSet(orig.get(2).getFeatures().getRows(1, 2, 3), orig.get(2).getLabels().getRows(1, 2, 3)));
orig.get(1).getFeatures().getRows(0, 1)),
Nd4j.vstack(orig.get(0).getLabels().getRows(3), orig.get(1).getLabels().getRows(0, 1))));
expected.add(new DataSet(
Nd4j.vstack(orig.get(1).getFeatures().getRows(2, 3),
orig.get(2).getFeatures().getRows(0)),
Nd4j.vstack(orig.get(1).getLabels().getRows(2, 3), orig.get(2).getLabels().getRows(0))));
expected.add(new DataSet(orig.get(2).getFeatures().getRows(1, 2, 3),
orig.get(2).getLabels().getRows(1, 2, 3)));
DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize); DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize);
int count = 0; int count = 0;
while (iter.hasNext()) { while (iter.hasNext()) {
DataSet ds = iter.next(); DataSet ds = iter.next();
assertEquals(expected.get(count), ds); assertEquals(expected.get(count), ds);
count++; count++;
} }
assertEquals(count, numBatches); assertEquals(count, numBatches);
} }
} }
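One mapping above is easy to miss in the noise: JUnit 4's @Ignore becomes Jupiter's @Disabled, and the reason that used to live in a trailing comment can be passed to the annotation itself. A hypothetical sketch of the pattern:

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

class DisabledSketch {

    @Test
    // was: @Ignore //Ignored for now - CIFAR iterator needs work - https://github.com/eclipse/deeplearning4j/issues/4673
    @Disabled("CIFAR iterator needs work - https://github.com/eclipse/deeplearning4j/issues/4673")
    @DisplayName("Test Cifar Model")
    void testCifarModel() {
        // body unchanged by the migration
    }
}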


@ -17,13 +17,12 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -32,23 +31,27 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Early Termination Data Set Iterator Test")
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
int minibatchSize = 10; int minibatchSize = 10;
int numExamples = 105; int numExamples = 105;
@Rule @Rule
public final ExpectedException exception = ExpectedException.none(); public final ExpectedException exception = ExpectedException.none();
@Test @Test
public void testNextAndReset() throws Exception { @DisplayName("Test Next And Reset")
void testNextAndReset() throws Exception {
int terminateAfter = 2; int terminateAfter = 2;
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples); DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
assertTrue(earlyEndIter.hasNext()); assertTrue(earlyEndIter.hasNext());
int batchesSeen = 0; int batchesSeen = 0;
List<DataSet> seenData = new ArrayList<>(); List<DataSet> seenData = new ArrayList<>();
@ -59,8 +62,7 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
batchesSeen++; batchesSeen++;
} }
assertEquals(batchesSeen, terminateAfter); assertEquals(batchesSeen, terminateAfter);
// check data is repeated after reset
//check data is repeated after reset
earlyEndIter.reset(); earlyEndIter.reset();
batchesSeen = 0; batchesSeen = 0;
while (earlyEndIter.hasNext()) { while (earlyEndIter.hasNext()) {
@ -72,27 +74,23 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testNextNum() throws IOException { @DisplayName("Test Next Num")
void testNextNum() throws IOException {
int terminateAfter = 1; int terminateAfter = 1;
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples); DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10); earlyEndIter.next(10);
assertEquals(false, earlyEndIter.hasNext()); assertEquals(false, earlyEndIter.hasNext());
earlyEndIter.reset(); earlyEndIter.reset();
assertEquals(true, earlyEndIter.hasNext()); assertEquals(true, earlyEndIter.hasNext());
} }
@Test @Test
public void testCallstoNextNotAllowed() throws IOException { @DisplayName("Test Callsto Next Not Allowed")
void testCallstoNextNotAllowed() throws IOException {
int terminateAfter = 1; int terminateAfter = 1;
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples); DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10); earlyEndIter.next(10);
iter.reset(); iter.reset();
exception.expect(RuntimeException.class); exception.expect(RuntimeException.class);
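Note that this class still relies on JUnit 4's ExpectedException rule and org.junit.Rule, which the Jupiter engine does not evaluate; under a pure JUnit 5 run the expected-exception check silently disappears. A hedged sketch of the idiomatic Jupiter replacement, where the thrown RuntimeException stands in for earlyEndIter.next(10) failing after the underlying reset:

import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;

class AssertThrowsSketch {

    @Test
    void nextAfterUnderlyingResetFails() {
        // Replaces: exception.expect(RuntimeException.class); earlyEndIter.next(10);
        assertThrows(RuntimeException.class, () -> {
            throw new RuntimeException("stand-in for earlyEndIter.next(10)");
        });
    }
}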


@ -17,40 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Early Termination Multi Data Set Iterator Test")
import static org.junit.Assert.assertTrue; class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
int minibatchSize = 5; int minibatchSize = 5;
int numExamples = 105; int numExamples = 105;
@Rule @Rule
public final ExpectedException exception = ExpectedException.none(); public final ExpectedException exception = ExpectedException.none();
@Test @Test
public void testNextAndReset() throws Exception { @DisplayName("Test Next And Reset")
void testNextAndReset() throws Exception {
int terminateAfter = 2; int terminateAfter = 2;
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
MultiDataSetIterator iter =
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
int count = 0; int count = 0;
List<MultiDataSet> seenMDS = new ArrayList<>(); List<MultiDataSet> seenMDS = new ArrayList<>();
while (count < terminateAfter) { while (count < terminateAfter) {
@ -58,10 +57,7 @@ public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
count++; count++;
} }
iter.reset(); iter.reset();
EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
EarlyTerminationMultiDataSetIterator earlyEndIter =
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
assertTrue(earlyEndIter.hasNext()); assertTrue(earlyEndIter.hasNext());
count = 0; count = 0;
while (earlyEndIter.hasNext()) { while (earlyEndIter.hasNext()) {
@ -71,8 +67,7 @@ public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
count++; count++;
} }
assertEquals(count, terminateAfter); assertEquals(count, terminateAfter);
// check data is repeated
//check data is repeated
earlyEndIter.reset(); earlyEndIter.reset();
count = 0; count = 0;
while (earlyEndIter.hasNext()) { while (earlyEndIter.hasNext()) {
@ -84,34 +79,26 @@ public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testNextNum() throws IOException { @DisplayName("Test Next Num")
void testNextNum() throws IOException {
int terminateAfter = 1; int terminateAfter = 1;
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
MultiDataSetIterator iter = EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter =
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10); earlyEndIter.next(10);
assertEquals(false, earlyEndIter.hasNext()); assertEquals(false, earlyEndIter.hasNext());
earlyEndIter.reset(); earlyEndIter.reset();
assertEquals(true, earlyEndIter.hasNext()); assertEquals(true, earlyEndIter.hasNext());
} }
@Test @Test
public void testCallstoNextNotAllowed() throws IOException { @DisplayName("Test Callsto Next Not Allowed")
void testCallstoNextNotAllowed() throws IOException {
int terminateAfter = 1; int terminateAfter = 1;
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
MultiDataSetIterator iter = EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter =
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10); earlyEndIter.next(10);
iter.reset(); iter.reset();
exception.expect(RuntimeException.class); exception.expect(RuntimeException.class);
earlyEndIter.next(10); earlyEndIter.next(10);
} }
} }
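The same ExpectedException caveat applies to this class. Separately, a small idiom the migration leaves behind: assertEquals(false, earlyEndIter.hasNext()) and assertEquals(true, ...) read more naturally as Jupiter's dedicated boolean assertions. A sketch with a stand-in iterator:

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.Arrays;
import java.util.Iterator;

class BooleanAssertSketch {

    void demo() {
        Iterator<String> it = Arrays.asList("only-element").iterator();
        assertTrue(it.hasNext());   // instead of assertEquals(true, it.hasNext())
        it.next();
        assertFalse(it.hasNext());  // instead of assertEquals(false, it.hasNext())
    }
}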


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -25,90 +24,75 @@ import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.parallel.JointParallelDataSetIterator; import org.deeplearning4j.datasets.iterator.parallel.JointParallelDataSetIterator;
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator; import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.enums.InequalityHandling; import org.nd4j.linalg.dataset.api.iterator.enums.InequalityHandling;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertNotNull; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
public class JointParallelDataSetIteratorTest extends BaseDL4JTest { @DisplayName("Joint Parallel Data Set Iterator Test")
class JointParallelDataSetIteratorTest extends BaseDL4JTest {
/** /**
* Simple test, checking dataset alignment: all iterators should return the same data for the same cycle * Simple test, checking dataset alignment: all iterators should return the same data for the same cycle
* *
*
* @throws Exception * @throws Exception
*/ */
@Test @Test
public void testJointIterator1() throws Exception { @DisplayName("Test Joint Iterator 1")
void testJointIterator1() throws Exception {
DataSetIterator iteratorA = new SimpleVariableGenerator(119, 100, 32, 100, 10); DataSetIterator iteratorA = new SimpleVariableGenerator(119, 100, 32, 100, 10);
DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10); DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.STOP_EVERYONE).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.STOP_EVERYONE)
.addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
int cnt = 0; int cnt = 0;
int example = 0; int example = 0;
while (jpdsi.hasNext()) { while (jpdsi.hasNext()) {
DataSet ds = jpdsi.next(); DataSet ds = jpdsi.next();
assertNotNull("Failed on iteration " + cnt, ds); assertNotNull(ds,"Failed on iteration " + cnt);
// ds.detach();
// ds.detach(); // ds.migrate();
//ds.migrate(); assertEquals( (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
assertEquals("Failed on iteration " + cnt, (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001);
cnt++; cnt++;
if (cnt % 2 == 0) if (cnt % 2 == 0)
example++; example++;
} }
assertEquals(100, example); assertEquals(100, example);
assertEquals(200, cnt); assertEquals(200, cnt);
} }
/** /**
* This test checks for pass_null scenario, so in total we should have 300 real datasets + 100 nulls * This test checks for pass_null scenario, so in total we should have 300 real datasets + 100 nulls
* @throws Exception * @throws Exception
*/ */
@Test @Test
public void testJointIterator2() throws Exception { @DisplayName("Test Joint Iterator 2")
void testJointIterator2() throws Exception {
DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10); DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10);
DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10); DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.PASS_NULL).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.PASS_NULL)
.addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
int cnt = 0; int cnt = 0;
int example = 0; int example = 0;
int nulls = 0; int nulls = 0;
while (jpdsi.hasNext()) { while (jpdsi.hasNext()) {
DataSet ds = jpdsi.next(); DataSet ds = jpdsi.next();
if (cnt < 200) if (cnt < 200)
assertNotNull("Failed on iteration " + cnt, ds); assertNotNull(ds,"Failed on iteration " + cnt);
if (ds == null) if (ds == null)
nulls++; nulls++;
if (cnt % 2 == 2) { if (cnt % 2 == 2) {
assertEquals("Failed on iteration " + cnt, (double) example, assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
ds.getFeatures().meanNumber().doubleValue(), 0.001); assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5,
ds.getLabels().meanNumber().doubleValue(), 0.001);
} }
cnt++; cnt++;
if (cnt % 2 == 0) if (cnt % 2 == 0)
example++; example++;
} }
assertEquals(100, nulls); assertEquals(100, nulls);
assertEquals(200, example); assertEquals(200, example);
assertEquals(400, cnt); assertEquals(400, cnt);
@ -120,25 +104,18 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest {
* @throws Exception * @throws Exception
*/ */
@Test @Test
public void testJointIterator3() throws Exception { @DisplayName("Test Joint Iterator 3")
void testJointIterator3() throws Exception {
DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10); DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10);
DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10); DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RELOCATE).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RELOCATE)
.addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
int cnt = 0; int cnt = 0;
int example = 0; int example = 0;
while (jpdsi.hasNext()) { while (jpdsi.hasNext()) {
DataSet ds = jpdsi.next(); DataSet ds = jpdsi.next();
assertNotNull("Failed on iteration " + cnt, ds); assertNotNull(ds,"Failed on iteration " + cnt);
assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
assertEquals("Failed on iteration " + cnt, (double) example, ds.getFeatures().meanNumber().doubleValue(), assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
0.001);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5,
ds.getLabels().meanNumber().doubleValue(), 0.001);
cnt++; cnt++;
if (cnt < 200) { if (cnt < 200) {
if (cnt % 2 == 0) if (cnt % 2 == 0)
@ -146,8 +123,6 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest {
} else } else
example++; example++;
} }
assertEquals(300, cnt); assertEquals(300, cnt);
assertEquals(200, example); assertEquals(200, example);
} }
@ -158,52 +133,38 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest {
* @throws Exception * @throws Exception
*/ */
@Test @Test
public void testJointIterator4() throws Exception { @DisplayName("Test Joint Iterator 4")
void testJointIterator4() throws Exception {
DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10); DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10);
DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10); DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RESET).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RESET)
.addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
int cnt = 0; int cnt = 0;
int cnt_sec = 0; int cnt_sec = 0;
int example_sec = 0; int example_sec = 0;
int example = 0; int example = 0;
while (jpdsi.hasNext()) { while (jpdsi.hasNext()) {
DataSet ds = jpdsi.next(); DataSet ds = jpdsi.next();
assertNotNull("Failed on iteration " + cnt, ds); assertNotNull(ds,"Failed on iteration " + cnt);
if (cnt % 2 == 0) { if (cnt % 2 == 0) {
assertEquals("Failed on iteration " + cnt, (double) example, assertEquals( (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
ds.getFeatures().meanNumber().doubleValue(), 0.001); assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5,
ds.getLabels().meanNumber().doubleValue(), 0.001);
} else { } else {
if (cnt <= 200) { if (cnt <= 200) {
assertEquals("Failed on iteration " + cnt, (double) example, assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
ds.getFeatures().meanNumber().doubleValue(), 0.001); assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5,
ds.getLabels().meanNumber().doubleValue(), 0.001);
} else { } else {
assertEquals("Failed on iteration " + cnt + ", second iteration " + cnt_sec, (double) example_sec, assertEquals((double) example_sec, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt + ", second iteration " + cnt_sec);
ds.getFeatures().meanNumber().doubleValue(), 0.001); assertEquals((double) example_sec + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt + ", second iteration " + cnt_sec);
assertEquals("Failed on iteration " + cnt + ", second iteration " + cnt_sec,
(double) example_sec + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001);
} }
} }
cnt++; cnt++;
if (cnt % 2 == 0) if (cnt % 2 == 0)
example++; example++;
if (cnt > 201 && cnt % 2 == 1) { if (cnt > 201 && cnt % 2 == 1) {
cnt_sec++; cnt_sec++;
example_sec++; example_sec++;
} }
} }
assertEquals(400, cnt); assertEquals(400, cnt);
assertEquals(200, example); assertEquals(200, example);
} }
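Every assertion in this class now builds its failure message eagerly, concatenating strings on each loop pass even when nothing fails. Jupiter's Supplier<String> overloads defer that formatting to failure time; a hypothetical sketch (names invented):

import static org.junit.jupiter.api.Assertions.assertEquals;

class LazyMessageSketch {

    void check(double expected, double actual, int cnt) {
        // The lambda runs only when the assertion fails, so the happy path builds no string.
        assertEquals(expected, actual, 0.001, () -> "Failed on iteration " + cnt);
    }
}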


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
@ -27,34 +26,33 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.util.TestDataSetConsumer; import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.Timeout; import org.junit.rules.Timeout;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import java.util.Iterator; import java.util.Iterator;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*; @DisplayName("Multiple Epochs Iterator Test")
class MultipleEpochsIteratorTest extends BaseDL4JTest {
public class MultipleEpochsIteratorTest extends BaseDL4JTest {
@Rule @Rule
public Timeout timeout = Timeout.seconds(300); public Timeout timeout = Timeout.seconds(300);
@Test @Test
public void testNextAndReset() throws Exception { @DisplayName("Test Next And Reset")
void testNextAndReset() throws Exception {
int epochs = 3; int epochs = 3;
RecordReader rr = new CSVRecordReader(); RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150); DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter); MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);
assertTrue(multiIter.hasNext()); assertTrue(multiIter.hasNext());
while (multiIter.hasNext()) { while (multiIter.hasNext()) {
DataSet path = multiIter.next(); DataSet path = multiIter.next();
@ -64,18 +62,15 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testLoadFullDataSet() throws Exception { @DisplayName("Test Load Full Data Set")
void testLoadFullDataSet() throws Exception {
int epochs = 3; int epochs = 3;
RecordReader rr = new CSVRecordReader(); RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150); DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
DataSet ds = iter.next(50); DataSet ds = iter.next(50);
assertEquals(50, ds.getFeatures().size(0)); assertEquals(50, ds.getFeatures().size(0));
MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds); MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
assertTrue(multiIter.hasNext()); assertTrue(multiIter.hasNext());
int count = 0; int count = 0;
while (multiIter.hasNext()) { while (multiIter.hasNext()) {
@ -89,28 +84,26 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testLoadBatchDataSet() throws Exception { @DisplayName("Test Load Batch Data Set")
void testLoadBatchDataSet() throws Exception {
int epochs = 2; int epochs = 2;
RecordReader rr = new CSVRecordReader(); RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile())); rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3); DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3);
DataSet ds = iter.next(20); DataSet ds = iter.next(20);
assertEquals(20, ds.getFeatures().size(0)); assertEquals(20, ds.getFeatures().size(0));
MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds); MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
while (multiIter.hasNext()) { while (multiIter.hasNext()) {
DataSet path = multiIter.next(10); DataSet path = multiIter.next(10);
assertNotNull(path); assertNotNull(path);
assertEquals(10, path.numExamples(), 0.0); assertEquals(10, path.numExamples(), 0.0);
} }
assertEquals(epochs, multiIter.epochs); assertEquals(epochs, multiIter.epochs);
} }
@Test @Test
public void testMEDIWithLoad1() throws Exception { @DisplayName("Test MEDI With Load 1")
void testMEDIWithLoad1() throws Exception {
ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100)); ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100));
MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24); MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 1); TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 1);
@ -119,38 +112,39 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testMEDIWithLoad2() throws Exception { @DisplayName("Test MEDI With Load 2")
void testMEDIWithLoad2() throws Exception {
ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100)); ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100));
MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24); MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2); TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2);
long num1 = 0; long num1 = 0;
for (; num1 < 150; num1++) { for (; num1 < 150; num1++) {
consumer.consumeOnce(iterator.next(), true); consumer.consumeOnce(iterator.next(), true);
} }
iterator.reset(); iterator.reset();
long num2 = consumer.consumeWhileHasNext(true); long num2 = consumer.consumeWhileHasNext(true);
assertEquals((10 * 100) + 150, num1 + num2); assertEquals((10 * 100) + 150, num1 + num2);
} }
@Test @Test
public void testMEDIWithLoad3() throws Exception { @DisplayName("Test MEDI With Load 3")
void testMEDIWithLoad3() throws Exception {
ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(10000)); ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(10000));
MultipleEpochsIterator iterator = new MultipleEpochsIterator(iter, 24, 136); MultipleEpochsIterator iterator = new MultipleEpochsIterator(iter, 24, 136);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2); TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2);
long num1 = 0; long num1 = 0;
while (iterator.hasNext()) { while (iterator.hasNext()) {
consumer.consumeOnce(iterator.next(), true); consumer.consumeOnce(iterator.next(), true);
num1++; num1++;
} }
assertEquals(136, num1); assertEquals(136, num1);
} }
@DisplayName("Iterable Without Exception")
private class IterableWithoutException implements Iterable<DataSet> { private class IterableWithoutException implements Iterable<DataSet> {
private final AtomicLong counter = new AtomicLong(0); private final AtomicLong counter = new AtomicLong(0);
private final int datasets; private final int datasets;
public IterableWithoutException(int datasets) { public IterableWithoutException(int datasets) {
@ -161,6 +155,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
public Iterator<DataSet> iterator() { public Iterator<DataSet> iterator() {
counter.set(0); counter.set(0);
return new Iterator<DataSet>() { return new Iterator<DataSet>() {
@Override @Override
public boolean hasNext() { public boolean hasNext() {
return counter.get() < datasets; return counter.get() < datasets;
@ -174,7 +169,6 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
@Override @Override
public void remove() { public void remove() {
} }
}; };
} }
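This class keeps JUnit 4's Timeout rule (@Rule public Timeout timeout = Timeout.seconds(300);), which Jupiter ignores just like ExpectedException. The annotation-based equivalent, available since JUnit 5.5 and therefore in the 5.8.0-M1 target of this commit, is sketched below (class name hypothetical):

import java.util.concurrent.TimeUnit;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

@Timeout(value = 300, unit = TimeUnit.SECONDS) // class-level: applies to every test method
class TimeoutSketch {

    @Test
    void finishesUnderFiveMinutes() {
        // test body; fails automatically if it runs past the timeout
    }
}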


@ -17,36 +17,34 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertArrayEquals; @DisplayName("Random Data Set Iterator Test")
import static org.junit.Assert.assertEquals; class RandomDataSetIteratorTest extends BaseDL4JTest {
import static org.junit.Assert.assertTrue;
public class RandomDataSetIteratorTest extends BaseDL4JTest {
@Test @Test
public void testDSI(){ @DisplayName("Test DSI")
DataSetIterator iter = new RandomDataSetIterator(5, new long[]{3,4}, new long[]{3,5}, RandomDataSetIterator.Values.RANDOM_UNIFORM, void testDSI() {
RandomDataSetIterator.Values.ONE_HOT); DataSetIterator iter = new RandomDataSetIterator(5, new long[] { 3, 4 }, new long[] { 3, 5 }, RandomDataSetIterator.Values.RANDOM_UNIFORM, RandomDataSetIterator.Values.ONE_HOT);
int count = 0; int count = 0;
while(iter.hasNext()){ while (iter.hasNext()) {
count++; count++;
DataSet ds = iter.next(); DataSet ds = iter.next();
assertArrayEquals(new long[] { 3, 4 }, ds.getFeatures().shape());
assertArrayEquals(new long[]{3,4}, ds.getFeatures().shape()); assertArrayEquals(new long[] { 3, 5 }, ds.getLabels().shape());
assertArrayEquals(new long[]{3,5}, ds.getLabels().shape());
assertTrue(ds.getFeatures().minNumber().doubleValue() >= 0.0 && ds.getFeatures().maxNumber().doubleValue() <= 1.0); assertTrue(ds.getFeatures().minNumber().doubleValue() >= 0.0 && ds.getFeatures().maxNumber().doubleValue() <= 1.0);
assertEquals(Nd4j.ones(3), ds.getLabels().sum(1)); assertEquals(Nd4j.ones(3), ds.getLabels().sum(1));
} }
@ -54,31 +52,23 @@ public class RandomDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
public void testMDSI(){ @DisplayName("Test MDSI")
void testMDSI() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5) MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5).addFeatures(new long[] { 3, 4 }, RandomMultiDataSetIterator.Values.INTEGER_0_100).addFeatures(new long[] { 3, 5 }, RandomMultiDataSetIterator.Values.BINARY).addLabels(new long[] { 3, 6 }, RandomMultiDataSetIterator.Values.ZEROS).build();
.addFeatures(new long[]{3,4}, RandomMultiDataSetIterator.Values.INTEGER_0_100)
.addFeatures(new long[]{3,5}, RandomMultiDataSetIterator.Values.BINARY)
.addLabels(new long[]{3,6}, RandomMultiDataSetIterator.Values.ZEROS)
.build();
int count = 0; int count = 0;
while(iter.hasNext()){ while (iter.hasNext()) {
count++; count++;
MultiDataSet mds = iter.next(); MultiDataSet mds = iter.next();
assertEquals(2, mds.numFeatureArrays()); assertEquals(2, mds.numFeatureArrays());
assertEquals(1, mds.numLabelsArrays()); assertEquals(1, mds.numLabelsArrays());
assertArrayEquals(new long[]{3,4}, mds.getFeatures(0).shape()); assertArrayEquals(new long[] { 3, 4 }, mds.getFeatures(0).shape());
assertArrayEquals(new long[]{3,5}, mds.getFeatures(1).shape()); assertArrayEquals(new long[] { 3, 5 }, mds.getFeatures(1).shape());
assertArrayEquals(new long[]{3,6}, mds.getLabels(0).shape()); assertArrayEquals(new long[] { 3, 6 }, mds.getLabels(0).shape());
assertTrue(mds.getFeatures(0).minNumber().doubleValue() >= 0 && mds.getFeatures(0).maxNumber().doubleValue() <= 100.0 && mds.getFeatures(0).maxNumber().doubleValue() > 2.0);
assertTrue(mds.getFeatures(0).minNumber().doubleValue() >= 0 && mds.getFeatures(0).maxNumber().doubleValue() <= 100.0
&& mds.getFeatures(0).maxNumber().doubleValue() > 2.0);
assertTrue(mds.getFeatures(1).minNumber().doubleValue() == 0.0 && mds.getFeatures(1).maxNumber().doubleValue() == 1.0); assertTrue(mds.getFeatures(1).minNumber().doubleValue() == 0.0 && mds.getFeatures(1).maxNumber().doubleValue() == 1.0);
assertEquals(0.0, mds.getLabels(0).sumNumber().doubleValue(), 0.0); assertEquals(0.0, mds.getLabels(0).sumNumber().doubleValue(), 0.0);
} }
assertEquals(5, count); assertEquals(5, count);
} }
} }
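
Every file in this commit applies the same mechanical JUnit 4 to JUnit 5 recipe visible above: org.junit.Test becomes org.junit.jupiter.api.Test, static imports move from org.junit.Assert to org.junit.jupiter.api.Assertions, test classes and methods drop their public modifier, and @DisplayName annotations are added. Distilled into a standalone sketch (FooTest is an illustrative name, not a class from this repository):

    import org.junit.jupiter.api.DisplayName;
    import org.junit.jupiter.api.Test;
    import static org.junit.jupiter.api.Assertions.assertEquals;

    @DisplayName("Foo Test")
    class FooTest {

        @Test
        @DisplayName("Test Foo")
        void testFoo() {
            // Jupiter keeps the (expected, actual) argument order; the optional
            // failure message moves from the first parameter (JUnit 4) to the last.
            assertEquals(4, 2 + 2);
        }
    }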


@ -17,27 +17,28 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertEquals; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
public class SamplingTest extends BaseDL4JTest { @DisplayName("Sampling Test")
class SamplingTest extends BaseDL4JTest {
@Test @Test
public void testSample() throws Exception { @DisplayName("Test Sample")
void testSample() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(10, 10); DataSetIterator iter = new MnistDataSetIterator(10, 10);
//batch size and total // batch size and total
DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10); DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
assertEquals(10, sampling.next().numExamples()); assertEquals(10, sampling.next().numExamples());
} }
} }
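
The terse "batch size and total" comment maps to the constructor's trailing parameters; spelled out (argument roles inferred from this test, not from separate documentation):

    // SamplingDataSetIterator(sampleFrom, batchSize, totalNumberOfSamples)
    DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);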


@ -17,50 +17,46 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.eval; package org.deeplearning4j.eval;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.evaluation.curves.Histogram; import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static junit.framework.TestCase.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Eval Json Test")
public class EvalJsonTest extends BaseDL4JTest { class EvalJsonTest extends BaseDL4JTest {
@Test @Test
public void testSerdeEmpty() { @DisplayName("Test Serde Empty")
void testSerdeEmpty() {
boolean print = false; boolean print = false;
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] { new Evaluation(), new EvaluationBinary(), new ROCBinary(10), new ROCMultiClass(10), new RegressionEvaluation(3), new RegressionEvaluation(), new EvaluationCalibration() };
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10),
new ROCMultiClass(10), new RegressionEvaluation(3), new RegressionEvaluation(),
new EvaluationCalibration()};
for (org.nd4j.evaluation.IEvaluation e : arr) { for (org.nd4j.evaluation.IEvaluation e : arr) {
String json = e.toJson(); String json = e.toJson();
String stats = e.stats(); String stats = e.stats();
if (print) { if (print) {
System.out.println(e.getClass() + "\n" + json + "\n\n"); System.out.println(e.getClass() + "\n" + json + "\n\n");
} }
IEvaluation fromJson = (IEvaluation) org.nd4j.evaluation.BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class); IEvaluation fromJson = (IEvaluation) org.nd4j.evaluation.BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class);
assertEquals(e.toJson(), fromJson.toJson()); assertEquals(e.toJson(), fromJson.toJson());
} }
} }
@Test @Test
public void testSerde() { @DisplayName("Test Serde")
void testSerde() {
boolean print = false; boolean print = false;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
Evaluation evaluation = new Evaluation(); Evaluation evaluation = new Evaluation();
EvaluationBinary evaluationBinary = new EvaluationBinary(); EvaluationBinary evaluationBinary = new EvaluationBinary();
ROC roc = new ROC(2); ROC roc = new ROC(2);
@ -68,56 +64,43 @@ public class EvalJsonTest extends BaseDL4JTest {
ROCMultiClass roc3 = new ROCMultiClass(2); ROCMultiClass roc3 = new ROCMultiClass(2);
RegressionEvaluation regressionEvaluation = new RegressionEvaluation(); RegressionEvaluation regressionEvaluation = new RegressionEvaluation();
EvaluationCalibration ec = new EvaluationCalibration(); EvaluationCalibration ec = new EvaluationCalibration();
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] { evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec };
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec};
INDArray evalLabel = Nd4j.create(10, 3); INDArray evalLabel = Nd4j.create(10, 3);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
evalLabel.putScalar(i, i % 3, 1.0); evalLabel.putScalar(i, i % 3, 1.0);
} }
INDArray evalProb = Nd4j.rand(10, 3); INDArray evalProb = Nd4j.rand(10, 3);
evalProb.diviColumnVector(evalProb.sum(true,1)); evalProb.diviColumnVector(evalProb.sum(true, 1));
evaluation.eval(evalLabel, evalProb); evaluation.eval(evalLabel, evalProb);
roc3.eval(evalLabel, evalProb); roc3.eval(evalLabel, evalProb);
ec.eval(evalLabel, evalProb); ec.eval(evalLabel, evalProb);
evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5)); evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5));
evalProb = Nd4j.rand(10, 3); evalProb = Nd4j.rand(10, 3);
evaluationBinary.eval(evalLabel, evalProb); evaluationBinary.eval(evalLabel, evalProb);
roc2.eval(evalLabel, evalProb); roc2.eval(evalLabel, evalProb);
evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5)); evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5));
evalProb = Nd4j.rand(10, 1); evalProb = Nd4j.rand(10, 1);
roc.eval(evalLabel, evalProb); roc.eval(evalLabel, evalProb);
regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3)); regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3));
for (org.nd4j.evaluation.IEvaluation e : arr) { for (org.nd4j.evaluation.IEvaluation e : arr) {
String json = e.toJson(); String json = e.toJson();
if (print) { if (print) {
System.out.println(e.getClass() + "\n" + json + "\n\n"); System.out.println(e.getClass() + "\n" + json + "\n\n");
} }
IEvaluation fromJson = (IEvaluation) BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class); IEvaluation fromJson = (IEvaluation) BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class);
assertEquals(e.toJson(), fromJson.toJson()); assertEquals(e.toJson(), fromJson.toJson());
} }
} }
@Test @Test
public void testSerdeExactRoc() { @DisplayName("Test Serde Exact Roc")
void testSerdeExactRoc() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
boolean print = false; boolean print = false;
ROC roc = new ROC(0); ROC roc = new ROC(0);
ROCBinary roc2 = new ROCBinary(0); ROCBinary roc2 = new ROCBinary(0);
ROCMultiClass roc3 = new ROCMultiClass(0); ROCMultiClass roc3 = new ROCMultiClass(0);
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] { roc, roc2, roc3 };
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {roc, roc2, roc3};
INDArray evalLabel = Nd4j.create(100, 3); INDArray evalLabel = Nd4j.create(100, 3);
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
evalLabel.putScalar(i, i % 3, 1.0); evalLabel.putScalar(i, i % 3, 1.0);
@ -125,15 +108,12 @@ public class EvalJsonTest extends BaseDL4JTest {
INDArray evalProb = Nd4j.rand(100, 3); INDArray evalProb = Nd4j.rand(100, 3);
evalProb.diviColumnVector(evalProb.sum(1)); evalProb.diviColumnVector(evalProb.sum(1));
roc3.eval(evalLabel, evalProb); roc3.eval(evalLabel, evalProb);
evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 3), 0.5)); evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 3), 0.5));
evalProb = Nd4j.rand(100, 3); evalProb = Nd4j.rand(100, 3);
roc2.eval(evalLabel, evalProb); roc2.eval(evalLabel, evalProb);
evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5)); evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
evalProb = Nd4j.rand(100, 1); evalProb = Nd4j.rand(100, 1);
roc.eval(evalLabel, evalProb); roc.eval(evalLabel, evalProb);
for (org.nd4j.evaluation.IEvaluation e : arr) { for (org.nd4j.evaluation.IEvaluation e : arr) {
System.out.println(e.getClass()); System.out.println(e.getClass());
String json = e.toJson(); String json = e.toJson();
@ -143,37 +123,34 @@ public class EvalJsonTest extends BaseDL4JTest {
} }
org.nd4j.evaluation.IEvaluation fromJson = BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class); org.nd4j.evaluation.IEvaluation fromJson = BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class);
assertEquals(e, fromJson); assertEquals(e, fromJson);
if (fromJson instanceof ROC) { if (fromJson instanceof ROC) {
//Shouldn't have probAndLabel, but should have stored AUC and AUPRC // Shouldn't have probAndLabel, but should have stored AUC and AUPRC
assertNull(((ROC) fromJson).getProbAndLabel()); assertNull(((ROC) fromJson).getProbAndLabel());
assertTrue(((ROC) fromJson).calculateAUC() > 0.0); assertTrue(((ROC) fromJson).calculateAUC() > 0.0);
assertTrue(((ROC) fromJson).calculateAUCPR() > 0.0); assertTrue(((ROC) fromJson).calculateAUCPR() > 0.0);
assertEquals(((ROC) e).getRocCurve(), ((ROC) fromJson).getRocCurve()); assertEquals(((ROC) e).getRocCurve(), ((ROC) fromJson).getRocCurve());
assertEquals(((ROC) e).getPrecisionRecallCurve(), ((ROC) fromJson).getPrecisionRecallCurve()); assertEquals(((ROC) e).getPrecisionRecallCurve(), ((ROC) fromJson).getPrecisionRecallCurve());
} else if (e instanceof ROCBinary) { } else if (e instanceof ROCBinary) {
org.nd4j.evaluation.classification.ROC[] rocs = ((ROCBinary) fromJson).getUnderlying(); org.nd4j.evaluation.classification.ROC[] rocs = ((ROCBinary) fromJson).getUnderlying();
org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCBinary) e).getUnderlying(); org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCBinary) e).getUnderlying();
// for(ROC r : rocs ){ // for(ROC r : rocs ){
for (int i = 0; i < origRocs.length; i++) { for (int i = 0; i < origRocs.length; i++) {
org.nd4j.evaluation.classification.ROC r = rocs[i]; org.nd4j.evaluation.classification.ROC r = rocs[i];
org.nd4j.evaluation.classification.ROC origR = origRocs[i]; org.nd4j.evaluation.classification.ROC origR = origRocs[i];
//Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves // Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
assertNull(r.getProbAndLabel()); assertNull(r.getProbAndLabel());
assertEquals(origR.calculateAUC(), r.calculateAUC(), 1e-6); assertEquals(origR.calculateAUC(), r.calculateAUC(), 1e-6);
assertEquals(origR.calculateAUCPR(), r.calculateAUCPR(), 1e-6); assertEquals(origR.calculateAUCPR(), r.calculateAUCPR(), 1e-6);
assertEquals(origR.getRocCurve(), r.getRocCurve()); assertEquals(origR.getRocCurve(), r.getRocCurve());
assertEquals(origR.getPrecisionRecallCurve(), r.getPrecisionRecallCurve()); assertEquals(origR.getPrecisionRecallCurve(), r.getPrecisionRecallCurve());
} }
} else if (e instanceof ROCMultiClass) { } else if (e instanceof ROCMultiClass) {
org.nd4j.evaluation.classification.ROC[] rocs = ((ROCMultiClass) fromJson).getUnderlying(); org.nd4j.evaluation.classification.ROC[] rocs = ((ROCMultiClass) fromJson).getUnderlying();
org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCMultiClass) e).getUnderlying(); org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCMultiClass) e).getUnderlying();
for (int i = 0; i < origRocs.length; i++) { for (int i = 0; i < origRocs.length; i++) {
org.nd4j.evaluation.classification.ROC r = rocs[i]; org.nd4j.evaluation.classification.ROC r = rocs[i];
org.nd4j.evaluation.classification.ROC origR = origRocs[i]; org.nd4j.evaluation.classification.ROC origR = origRocs[i];
//Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves // Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
assertNull(r.getProbAndLabel()); assertNull(r.getProbAndLabel());
assertEquals(origR.calculateAUC(), r.calculateAUC(), 1e-6); assertEquals(origR.calculateAUC(), r.calculateAUC(), 1e-6);
assertEquals(origR.calculateAUCPR(), r.calculateAUCPR(), 1e-6); assertEquals(origR.calculateAUCPR(), r.calculateAUCPR(), 1e-6);
@ -185,32 +162,23 @@ public class EvalJsonTest extends BaseDL4JTest {
} }
@Test @Test
public void testJsonYamlCurves() { @DisplayName("Test Json Yaml Curves")
void testJsonYamlCurves() {
ROC roc = new ROC(0); ROC roc = new ROC(0);
INDArray evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
INDArray evalLabel =
Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
INDArray evalProb = Nd4j.rand(100, 1); INDArray evalProb = Nd4j.rand(100, 1);
roc.eval(evalLabel, evalProb); roc.eval(evalLabel, evalProb);
RocCurve c = roc.getRocCurve(); RocCurve c = roc.getRocCurve();
PrecisionRecallCurve prc = roc.getPrecisionRecallCurve(); PrecisionRecallCurve prc = roc.getPrecisionRecallCurve();
String json1 = c.toJson(); String json1 = c.toJson();
String json2 = prc.toJson(); String json2 = prc.toJson();
RocCurve c2 = RocCurve.fromJson(json1); RocCurve c2 = RocCurve.fromJson(json1);
PrecisionRecallCurve prc2 = PrecisionRecallCurve.fromJson(json2); PrecisionRecallCurve prc2 = PrecisionRecallCurve.fromJson(json2);
assertEquals(c, c2); assertEquals(c, c2);
assertEquals(prc, prc2); assertEquals(prc, prc2);
// System.out.println(json1);
// System.out.println(json1); // Also test: histograms
//Also test: histograms
EvaluationCalibration ec = new EvaluationCalibration(); EvaluationCalibration ec = new EvaluationCalibration();
evalLabel = Nd4j.create(10, 3); evalLabel = Nd4j.create(10, 3);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
evalLabel.putScalar(i, i % 3, 1.0); evalLabel.putScalar(i, i % 3, 1.0);
@ -218,67 +186,45 @@ public class EvalJsonTest extends BaseDL4JTest {
evalProb = Nd4j.rand(10, 3); evalProb = Nd4j.rand(10, 3);
evalProb.diviColumnVector(evalProb.sum(1)); evalProb.diviColumnVector(evalProb.sum(1));
ec.eval(evalLabel, evalProb); ec.eval(evalLabel, evalProb);
Histogram[] histograms = new Histogram[] { ec.getResidualPlotAllClasses(), ec.getResidualPlot(0), ec.getResidualPlot(1), ec.getProbabilityHistogramAllClasses(), ec.getProbabilityHistogram(0), ec.getProbabilityHistogram(1) };
Histogram[] histograms = new Histogram[] {ec.getResidualPlotAllClasses(), ec.getResidualPlot(0),
ec.getResidualPlot(1), ec.getProbabilityHistogramAllClasses(), ec.getProbabilityHistogram(0),
ec.getProbabilityHistogram(1)};
for (Histogram h : histograms) { for (Histogram h : histograms) {
String json = h.toJson(); String json = h.toJson();
String yaml = h.toYaml(); String yaml = h.toYaml();
Histogram h2 = Histogram.fromJson(json); Histogram h2 = Histogram.fromJson(json);
Histogram h3 = Histogram.fromYaml(yaml); Histogram h3 = Histogram.fromYaml(yaml);
assertEquals(h, h2); assertEquals(h, h2);
assertEquals(h2, h3); assertEquals(h2, h3);
} }
} }
@Test @Test
public void testJsonWithCustomThreshold() { @DisplayName("Test Json With Custom Threshold")
void testJsonWithCustomThreshold() {
//Evaluation - binary threshold // Evaluation - binary threshold
Evaluation e = new Evaluation(0.25); Evaluation e = new Evaluation(0.25);
String json = e.toJson(); String json = e.toJson();
String yaml = e.toYaml(); String yaml = e.toYaml();
Evaluation eFromJson = Evaluation.fromJson(json); Evaluation eFromJson = Evaluation.fromJson(json);
Evaluation eFromYaml = Evaluation.fromYaml(yaml); Evaluation eFromYaml = Evaluation.fromYaml(yaml);
assertEquals(0.25, eFromJson.getBinaryDecisionThreshold(), 1e-6); assertEquals(0.25, eFromJson.getBinaryDecisionThreshold(), 1e-6);
assertEquals(0.25, eFromYaml.getBinaryDecisionThreshold(), 1e-6); assertEquals(0.25, eFromYaml.getBinaryDecisionThreshold(), 1e-6);
// Evaluation: custom cost array
INDArray costArray = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
//Evaluation: custom cost array
INDArray costArray = Nd4j.create(new double[] {1.0, 2.0, 3.0});
Evaluation e2 = new Evaluation(costArray); Evaluation e2 = new Evaluation(costArray);
json = e2.toJson(); json = e2.toJson();
yaml = e2.toYaml(); yaml = e2.toYaml();
eFromJson = Evaluation.fromJson(json); eFromJson = Evaluation.fromJson(json);
eFromYaml = Evaluation.fromYaml(yaml); eFromYaml = Evaluation.fromYaml(yaml);
assertEquals(e2.getCostArray(), eFromJson.getCostArray()); assertEquals(e2.getCostArray(), eFromJson.getCostArray());
assertEquals(e2.getCostArray(), eFromYaml.getCostArray()); assertEquals(e2.getCostArray(), eFromYaml.getCostArray());
// EvaluationBinary - per-output binary threshold
INDArray threshold = Nd4j.create(new double[] { 1.0, 0.5, 0.25 });
//EvaluationBinary - per-output binary threshold
INDArray threshold = Nd4j.create(new double[] {1.0, 0.5, 0.25});
EvaluationBinary eb = new EvaluationBinary(threshold); EvaluationBinary eb = new EvaluationBinary(threshold);
json = eb.toJson(); json = eb.toJson();
yaml = eb.toYaml(); yaml = eb.toYaml();
EvaluationBinary ebFromJson = EvaluationBinary.fromJson(json); EvaluationBinary ebFromJson = EvaluationBinary.fromJson(json);
EvaluationBinary ebFromYaml = EvaluationBinary.fromYaml(yaml); EvaluationBinary ebFromYaml = EvaluationBinary.fromYaml(yaml);
assertEquals(threshold, ebFromJson.getDecisionThreshold()); assertEquals(threshold, ebFromJson.getDecisionThreshold());
assertEquals(threshold, ebFromYaml.getDecisionThreshold()); assertEquals(threshold, ebFromYaml.getDecisionThreshold());
} }
} }
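
The serde round trip these tests exercise is compact enough to show in isolation; a minimal sketch using only calls that appear above, assuming the same org.nd4j.evaluation imports (labels and networkOutput stand in for real arrays):

    Evaluation e = new Evaluation();
    e.eval(labels, networkOutput); // populate as in the tests
    String json = e.toJson();
    String yaml = e.toYaml();
    // Deserialization goes through the base class; the concrete type is recorded in the JSON itself
    IEvaluation fromJson = BaseEvaluation.fromJson(json, BaseEvaluation.class);
    Evaluation fromYaml = Evaluation.fromYaml(yaml);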


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.eval; package org.deeplearning4j.eval;
import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaData;
@ -45,7 +44,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.EvaluativeListener; import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
@ -58,78 +57,60 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*; @DisplayName("Eval Test")
class EvalTest extends BaseDL4JTest {
public class EvalTest extends BaseDL4JTest {
@Test @Test
public void testIris() { @DisplayName("Test Iris")
void testIris() {
// Network config // Network config
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42).updater(new Sgd(1e-6)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).build();
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42)
.updater(new Sgd(1e-6)).list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX).build())
.build();
// Instantiate model // Instantiate model
MultiLayerNetwork model = new MultiLayerNetwork(conf); MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init(); model.init();
model.addListeners(new ScoreIterationListener(1)); model.addListeners(new ScoreIterationListener(1));
// Train-test split // Train-test split
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
DataSet next = iter.next(); DataSet next = iter.next();
next.shuffle(); next.shuffle();
SplitTestAndTrain trainTest = next.splitTestAndTrain(5, new Random(42)); SplitTestAndTrain trainTest = next.splitTestAndTrain(5, new Random(42));
// Train // Train
DataSet train = trainTest.getTrain(); DataSet train = trainTest.getTrain();
train.normalizeZeroMeanZeroUnitVariance(); train.normalizeZeroMeanZeroUnitVariance();
// Test // Test
DataSet test = trainTest.getTest(); DataSet test = trainTest.getTest();
test.normalizeZeroMeanZeroUnitVariance(); test.normalizeZeroMeanZeroUnitVariance();
INDArray testFeature = test.getFeatures(); INDArray testFeature = test.getFeatures();
INDArray testLabel = test.getLabels(); INDArray testLabel = test.getLabels();
// Fitting model // Fitting model
model.fit(train); model.fit(train);
// Get predictions from test feature // Get predictions from test feature
INDArray testPredictedLabel = model.output(testFeature); INDArray testPredictedLabel = model.output(testFeature);
// Eval with class number // Eval with class number
org.nd4j.evaluation.classification.Evaluation eval = new org.nd4j.evaluation.classification.Evaluation(3); //// Specify class num here // // Specify class num here
org.nd4j.evaluation.classification.Evaluation eval = new org.nd4j.evaluation.classification.Evaluation(3);
eval.eval(testLabel, testPredictedLabel); eval.eval(testLabel, testPredictedLabel);
double eval1F1 = eval.f1(); double eval1F1 = eval.f1();
double eval1Acc = eval.accuracy(); double eval1Acc = eval.accuracy();
// Eval without class number // Eval without class number
org.nd4j.evaluation.classification.Evaluation eval2 = new org.nd4j.evaluation.classification.Evaluation(); //// No class num // // No class num
org.nd4j.evaluation.classification.Evaluation eval2 = new org.nd4j.evaluation.classification.Evaluation();
eval2.eval(testLabel, testPredictedLabel); eval2.eval(testLabel, testPredictedLabel);
double eval2F1 = eval2.f1(); double eval2F1 = eval2.f1();
double eval2Acc = eval2.accuracy(); double eval2Acc = eval2.accuracy();
// Assert the two implementations give same f1 and accuracy (since one batch)
//Assert the two implementations give same f1 and accuracy (since one batch)
assertTrue(eval1F1 == eval2F1 && eval1Acc == eval2Acc); assertTrue(eval1F1 == eval2F1 && eval1Acc == eval2Acc);
org.nd4j.evaluation.classification.Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test))); org.nd4j.evaluation.classification.Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test)));
checkEvaluationEquality(eval, evalViaMethod); checkEvaluationEquality(eval, evalViaMethod);
// System.out.println(eval.getConfusionMatrix().toString());
// System.out.println(eval.getConfusionMatrix().toString()); // System.out.println(eval.getConfusionMatrix().toCSV());
// System.out.println(eval.getConfusionMatrix().toCSV()); // System.out.println(eval.getConfusionMatrix().toHTML());
// System.out.println(eval.getConfusionMatrix().toHTML()); // System.out.println(eval.confusionToString());
// System.out.println(eval.confusionToString());
eval.getConfusionMatrix().toString(); eval.getConfusionMatrix().toString();
eval.getConfusionMatrix().toCSV(); eval.getConfusionMatrix().toCSV();
eval.getConfusionMatrix().toHTML(); eval.getConfusionMatrix().toHTML();
@ -160,99 +141,79 @@ public class EvalTest extends BaseDL4JTest {
} }
@Test @Test
public void testEvaluationWithMetaData() throws Exception { @DisplayName("Test Evaluation With Meta Data")
void testEvaluationWithMetaData() throws Exception {
RecordReader csv = new CSVRecordReader(); RecordReader csv = new CSVRecordReader();
csv.initialize(new FileSplit(Resources.asFile("iris.txt"))); csv.initialize(new FileSplit(Resources.asFile("iris.txt")));
int batchSize = 10; int batchSize = 10;
int labelIdx = 4; int labelIdx = 4;
int numClasses = 3; int numClasses = 3;
RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses); RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses);
NormalizerStandardize ns = new NormalizerStandardize(); NormalizerStandardize ns = new NormalizerStandardize();
ns.fit(rrdsi); ns.fit(rrdsi);
rrdsi.setPreProcessor(ns); rrdsi.setPreProcessor(ns);
rrdsi.reset(); rrdsi.reset();
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).list().layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(3).build()).build();
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1))
.list()
.layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
net.fit(rrdsi); net.fit(rrdsi);
rrdsi.reset(); rrdsi.reset();
} }
org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation(); org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation();
rrdsi.setCollectMetaData(true); //*** New: Enable collection of metadata (stored in the DataSets) *** // *** New: Enable collection of metadata (stored in the DataSets) ***
rrdsi.setCollectMetaData(true);
while (rrdsi.hasNext()) { while (rrdsi.hasNext()) {
DataSet ds = rrdsi.next(); DataSet ds = rrdsi.next();
List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class); //*** New - cross dependencies here make types difficult, using Object internally in DataSet for this*** // *** New - cross dependencies here make types difficult, using Object internally in DataSet for this***
List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class);
INDArray out = net.output(ds.getFeatures()); INDArray out = net.output(ds.getFeatures());
e.eval(ds.getLabels(), out, meta); //*** New - evaluate and also store metadata *** // *** New - evaluate and also store metadata ***
e.eval(ds.getLabels(), out, meta);
} }
// System.out.println(e.stats());
// System.out.println(e.stats());
e.stats(); e.stats();
// System.out.println("\n\n*** Prediction Errors: ***");
// System.out.println("\n\n*** Prediction Errors: ***"); // *** New - get list of prediction errors from evaluation ***
List<org.nd4j.evaluation.meta.Prediction> errors = e.getPredictionErrors();
List<org.nd4j.evaluation.meta.Prediction> errors = e.getPredictionErrors(); //*** New - get list of prediction errors from evaluation ***
List<RecordMetaData> metaForErrors = new ArrayList<>(); List<RecordMetaData> metaForErrors = new ArrayList<>();
for (org.nd4j.evaluation.meta.Prediction p : errors) { for (org.nd4j.evaluation.meta.Prediction p : errors) {
metaForErrors.add((RecordMetaData) p.getRecordMetaData()); metaForErrors.add((RecordMetaData) p.getRecordMetaData());
} }
DataSet ds = rrdsi.loadFromMetaData(metaForErrors); //*** New - dynamically load a subset of the data, just for prediction errors *** // *** New - dynamically load a subset of the data, just for prediction errors ***
DataSet ds = rrdsi.loadFromMetaData(metaForErrors);
INDArray output = net.output(ds.getFeatures()); INDArray output = net.output(ds.getFeatures());
int count = 0; int count = 0;
for (org.nd4j.evaluation.meta.Prediction t : errors) { for (org.nd4j.evaluation.meta.Prediction t : errors) {
String s = t + "\t\tRaw Data: " String s = t + "\t\tRaw Data: " + // *** New - load subset of data from MetaData object (usually batched for efficiency) ***
+ csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() //*** New - load subset of data from MetaData object (usually batched for efficiency) *** csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() + "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: " + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count);
+ "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: " // System.out.println(s);
+ ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count);
// System.out.println(s);
count++; count++;
} }
int errorCount = errors.size(); int errorCount = errors.size();
double expAcc = 1.0 - errorCount / 150.0; double expAcc = 1.0 - errorCount / 150.0;
assertEquals(expAcc, e.accuracy(), 1e-5); assertEquals(expAcc, e.accuracy(), 1e-5);
org.nd4j.evaluation.classification.ConfusionMatrix<Integer> confusion = e.getConfusionMatrix(); org.nd4j.evaluation.classification.ConfusionMatrix<Integer> confusion = e.getConfusionMatrix();
int[] actualCounts = new int[3]; int[] actualCounts = new int[3];
int[] predictedCounts = new int[3]; int[] predictedCounts = new int[3];
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) { for (int j = 0; j < 3; j++) {
int entry = confusion.getCount(i, j); //(actual,predicted) // (actual,predicted)
int entry = confusion.getCount(i, j);
List<org.nd4j.evaluation.meta.Prediction> list = e.getPredictions(i, j); List<org.nd4j.evaluation.meta.Prediction> list = e.getPredictions(i, j);
assertEquals(entry, list.size()); assertEquals(entry, list.size());
actualCounts[i] += entry; actualCounts[i] += entry;
predictedCounts[j] += entry; predictedCounts[j] += entry;
} }
} }
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
List<org.nd4j.evaluation.meta.Prediction> actualClassI = e.getPredictionsByActualClass(i); List<org.nd4j.evaluation.meta.Prediction> actualClassI = e.getPredictionsByActualClass(i);
List<org.nd4j.evaluation.meta.Prediction> predictedClassI = e.getPredictionByPredictedClass(i); List<org.nd4j.evaluation.meta.Prediction> predictedClassI = e.getPredictionByPredictedClass(i);
assertEquals(actualCounts[i], actualClassI.size()); assertEquals(actualCounts[i], actualClassI.size());
assertEquals(predictedCounts[i], predictedClassI.size()); assertEquals(predictedCounts[i], predictedClassI.size());
} }
// Finally: test doEvaluation methods
//Finally: test doEvaluation methods
rrdsi.reset(); rrdsi.reset();
org.nd4j.evaluation.classification.Evaluation e2 = new org.nd4j.evaluation.classification.Evaluation(); org.nd4j.evaluation.classification.Evaluation e2 = new org.nd4j.evaluation.classification.Evaluation();
net.doEvaluation(rrdsi, e2); net.doEvaluation(rrdsi, e2);
@ -262,7 +223,6 @@ public class EvalTest extends BaseDL4JTest {
assertEquals(actualCounts[i], actualClassI.size()); assertEquals(actualCounts[i], actualClassI.size());
assertEquals(predictedCounts[i], predictedClassI.size()); assertEquals(predictedCounts[i], predictedClassI.size());
} }
ComputationGraph cg = net.toComputationGraph(); ComputationGraph cg = net.toComputationGraph();
rrdsi.reset(); rrdsi.reset();
e2 = new org.nd4j.evaluation.classification.Evaluation(); e2 = new org.nd4j.evaluation.classification.Evaluation();
@ -273,7 +233,6 @@ public class EvalTest extends BaseDL4JTest {
assertEquals(actualCounts[i], actualClassI.size()); assertEquals(actualCounts[i], actualClassI.size());
assertEquals(predictedCounts[i], predictedClassI.size()); assertEquals(predictedCounts[i], predictedClassI.size());
} }
} }
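
The metadata-driven error analysis above compresses to a short recipe; a sketch reusing the method's own variables (rrdsi, net) and only calls that appear in the test:

    rrdsi.setCollectMetaData(true); // 1. store RecordMetaData alongside each example
    org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation();
    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        // 2. evaluate while keeping each example's metadata with its prediction
        e.eval(ds.getLabels(), net.output(ds.getFeatures()), ds.getExampleMetaData(RecordMetaData.class));
    }
    // 3. pull out the misclassified examples, then reload just those records via loadFromMetaData(...)
    List<org.nd4j.evaluation.meta.Prediction> errors = e.getPredictionErrors();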
private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) { private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) {
@ -283,138 +242,28 @@ public class EvalTest extends BaseDL4JTest {
} }
@Test @Test
public void testEvalSplitting(){ @DisplayName("Test Eval Splitting")
//Test for "tbptt-like" functionality void testEvalSplitting() {
// Test for "tbptt-like" functionality
for(WorkspaceMode ws : WorkspaceMode.values()) { for (WorkspaceMode ws : WorkspaceMode.values()) {
System.out.println("Starting test for workspace mode: " + ws); System.out.println("Starting test for workspace mode: " + ws);
int nIn = 4; int nIn = 4;
int layerSize = 5; int layerSize = 5;
int nOut = 6; int nOut = 6;
int tbpttLength = 10; int tbpttLength = 10;
int tsLength = 5 * tbpttLength + tbpttLength / 2; int tsLength = 5 * tbpttLength + tbpttLength / 2;
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build()).layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build()).build();
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build()).layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build()).tBPTTLength(10).backpropType(BackpropType.TruncatedBPTT).build();
.seed(12345)
.trainingWorkspaceMode(ws)
.inferenceWorkspaceMode(ws)
.list()
.layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build())
.layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
.activation(Activation.SOFTMAX)
.build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.seed(12345)
.trainingWorkspaceMode(ws)
.inferenceWorkspaceMode(ws)
.list()
.layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build())
.layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
.activation(Activation.SOFTMAX).build())
.tBPTTLength(10)
.backpropType(BackpropType.TruncatedBPTT)
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init(); net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
net2.setParams(net1.params()); net2.setParams(net1.params());
for (boolean useMask : new boolean[] { false, true }) {
for(boolean useMask : new boolean[]{false, true}) { INDArray in1 = Nd4j.rand(new int[] { 3, nIn, tsLength });
INDArray in1 = Nd4j.rand(new int[]{3, nIn, tsLength});
INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength); INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength);
INDArray in2 = Nd4j.rand(new int[] { 5, nIn, tsLength });
INDArray in2 = Nd4j.rand(new int[]{5, nIn, tsLength});
INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength); INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength);
INDArray lMask1 = null;
INDArray lMask2 = null;
if(useMask){
lMask1 = Nd4j.create(3, tsLength);
lMask2 = Nd4j.create(5, tsLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5));
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5));
}
List<DataSet> l = Arrays.asList(new DataSet(in1, out1, null, lMask1), new DataSet(in2, out2, null, lMask2));
DataSetIterator iter = new ExistingDataSetIterator(l);
// System.out.println("Net 1 eval");
org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
// System.out.println("Net 2 eval");
org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
assertEquals(e1[0], e2[0]);
assertEquals(e1[1], e2[1]);
assertEquals(e1[2], e2[2]);
}
}
}
@Test
public void testEvalSplittingCompGraph(){
//Test for "tbptt-like" functionality
for(WorkspaceMode ws : WorkspaceMode.values()) {
System.out.println("Starting test for workspace mode: " + ws);
int nIn = 4;
int layerSize = 5;
int nOut = 6;
int tbpttLength = 10;
int tsLength = 5 * tbpttLength + tbpttLength / 2;
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
.seed(12345)
.trainingWorkspaceMode(ws)
.inferenceWorkspaceMode(ws)
.graphBuilder()
.addInputs("in")
.addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in")
.addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
.activation(Activation.SOFTMAX)
.build(), "0")
.setOutputs("1")
.build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder()
.seed(12345)
.trainingWorkspaceMode(ws)
.inferenceWorkspaceMode(ws)
.graphBuilder()
.addInputs("in")
.addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in")
.addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
.activation(Activation.SOFTMAX)
.build(), "0")
.setOutputs("1")
.tBPTTLength(10)
.backpropType(BackpropType.TruncatedBPTT)
.build();
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
net2.setParams(net1.params());
for (boolean useMask : new boolean[]{false, true}) {
INDArray in1 = Nd4j.rand(new int[]{3, nIn, tsLength});
INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength);
INDArray in2 = Nd4j.rand(new int[]{5, nIn, tsLength});
INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength);
INDArray lMask1 = null; INDArray lMask1 = null;
INDArray lMask2 = null; INDArray lMask2 = null;
if (useMask) { if (useMask) {
@ -423,15 +272,12 @@ public class EvalTest extends BaseDL4JTest {
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5)); Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5));
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5)); Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5));
} }
List<DataSet> l = Arrays.asList(new DataSet(in1, out1, null, lMask1), new DataSet(in2, out2, null, lMask2));
List<DataSet> l = Arrays.asList(new DataSet(in1, out1), new DataSet(in2, out2));
DataSetIterator iter = new ExistingDataSetIterator(l); DataSetIterator iter = new ExistingDataSetIterator(l);
// System.out.println("Net 1 eval");
// System.out.println("Eval net 1");
org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
// System.out.println("Eval net 2"); // System.out.println("Net 2 eval");
org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
assertEquals(e1[0], e2[0]); assertEquals(e1[0], e2[0]);
assertEquals(e1[1], e2[1]); assertEquals(e1[1], e2[1]);
assertEquals(e1[2], e2[2]); assertEquals(e1[2], e2[2]);
@ -440,192 +286,170 @@ public class EvalTest extends BaseDL4JTest {
} }
@Test @Test
public void testEvalSplitting2(){ @DisplayName("Test Eval Splitting Comp Graph")
void testEvalSplittingCompGraph() {
// Test for "tbptt-like" functionality
for (WorkspaceMode ws : WorkspaceMode.values()) {
System.out.println("Starting test for workspace mode: " + ws);
int nIn = 4;
int layerSize = 5;
int nOut = 6;
int tbpttLength = 10;
int tsLength = 5 * tbpttLength + tbpttLength / 2;
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in").addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build(), "0").setOutputs("1").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in").addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build(), "0").setOutputs("1").tBPTTLength(10).backpropType(BackpropType.TruncatedBPTT).build();
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
net2.setParams(net1.params());
for (boolean useMask : new boolean[] { false, true }) {
INDArray in1 = Nd4j.rand(new int[] { 3, nIn, tsLength });
INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength);
INDArray in2 = Nd4j.rand(new int[] { 5, nIn, tsLength });
INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength);
INDArray lMask1 = null;
INDArray lMask2 = null;
if (useMask) {
lMask1 = Nd4j.create(3, tsLength);
lMask2 = Nd4j.create(5, tsLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5));
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5));
}
List<DataSet> l = Arrays.asList(new DataSet(in1, out1), new DataSet(in2, out2));
DataSetIterator iter = new ExistingDataSetIterator(l);
// System.out.println("Eval net 1");
org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
// System.out.println("Eval net 2");
org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
assertEquals(e1[0], e2[0]);
assertEquals(e1[1], e2[1]);
assertEquals(e1[2], e2[2]);
}
}
}
@Test
@DisplayName("Test Eval Splitting 2")
void testEvalSplitting2() {
List<List<Writable>> seqFeatures = new ArrayList<>(); List<List<Writable>> seqFeatures = new ArrayList<>();
List<Writable> step = Arrays.<Writable>asList(new FloatWritable(0), new FloatWritable(0), new FloatWritable(0)); List<Writable> step = Arrays.<Writable>asList(new FloatWritable(0), new FloatWritable(0), new FloatWritable(0));
for( int i=0; i<30; i++ ){ for (int i = 0; i < 30; i++) {
seqFeatures.add(step); seqFeatures.add(step);
} }
List<List<Writable>> seqLabels = Collections.singletonList(Collections.<Writable>singletonList(new FloatWritable(0))); List<List<Writable>> seqLabels = Collections.singletonList(Collections.<Writable>singletonList(new FloatWritable(0)));
SequenceRecordReader fsr = new CollectionSequenceRecordReader(Collections.singletonList(seqFeatures)); SequenceRecordReader fsr = new CollectionSequenceRecordReader(Collections.singletonList(seqFeatures));
SequenceRecordReader lsr = new CollectionSequenceRecordReader(Collections.singletonList(seqLabels)); SequenceRecordReader lsr = new CollectionSequenceRecordReader(Collections.singletonList(seqLabels));
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(fsr, lsr, 1, -1, true, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()).layer(1, new RnnOutputLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.XENT).nIn(3).nOut(1).build()).backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(10).tBPTTBackwardLength(10).build();
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(fsr, lsr, 1, -1, true,
SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123)
.list()
.layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(3).nOut(3).build())
.layer(1, new RnnOutputLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.XENT)
.nIn(3).nOut(1).build())
.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(10).tBPTTBackwardLength(10)
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
net.evaluate(testData); net.evaluate(testData);
} }
@Test @Test
public void testEvaluativeListenerSimple(){ @DisplayName("Test Evaluative Listener Simple")
//Sanity check: https://github.com/eclipse/deeplearning4j/issues/5351 void testEvaluativeListenerSimple() {
// Sanity check: https://github.com/eclipse/deeplearning4j/issues/5351
// Network config // Network config
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42).updater(new Sgd(1e-6)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).build();
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42)
.updater(new Sgd(1e-6)).list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX).build())
.build();
// Instantiate model // Instantiate model
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
// Train-test split // Train-test split
DataSetIterator iter = new IrisDataSetIterator(30, 150); DataSetIterator iter = new IrisDataSetIterator(30, 150);
DataSetIterator iterTest = new IrisDataSetIterator(30, 150); DataSetIterator iterTest = new IrisDataSetIterator(30, 150);
net.setListeners(new EvaluativeListener(iterTest, 3)); net.setListeners(new EvaluativeListener(iterTest, 3));
for (int i = 0; i < 3; i++) {
for( int i=0; i<3; i++ ){
net.fit(iter); net.fit(iter);
} }
} }
@Test @Test
public void testMultiOutputEvalSimple(){ @DisplayName("Test Multi Output Eval Simple")
void testMultiOutputEvalSimple() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("out1", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in").addLayer("out2", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in").setOutputs("out1", "out2").build();
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.graphBuilder()
.addInputs("in")
.addLayer("out1", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in")
.addLayer("out2", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in")
.setOutputs("out1", "out2")
.build();
ComputationGraph cg = new ComputationGraph(conf); ComputationGraph cg = new ComputationGraph(conf);
cg.init(); cg.init();
List<MultiDataSet> list = new ArrayList<>(); List<MultiDataSet> list = new ArrayList<>();
DataSetIterator iter = new IrisDataSetIterator(30, 150); DataSetIterator iter = new IrisDataSetIterator(30, 150);
while(iter.hasNext()){ while (iter.hasNext()) {
DataSet ds = iter.next(); DataSet ds = iter.next();
list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), ds.getLabels()})); list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { ds.getFeatures() }, new INDArray[] { ds.getLabels(), ds.getLabels() }));
} }
org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation(); org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation();
org.nd4j.evaluation.regression.RegressionEvaluation e2 = new org.nd4j.evaluation.regression.RegressionEvaluation(); org.nd4j.evaluation.regression.RegressionEvaluation e2 = new org.nd4j.evaluation.regression.RegressionEvaluation();
Map<Integer,org.nd4j.evaluation.IEvaluation[]> evals = new HashMap<>(); Map<Integer, org.nd4j.evaluation.IEvaluation[]> evals = new HashMap<>();
evals.put(0, new org.nd4j.evaluation.IEvaluation[]{e}); evals.put(0, new org.nd4j.evaluation.IEvaluation[] { e });
evals.put(1, new org.nd4j.evaluation.IEvaluation[]{e2}); evals.put(1, new org.nd4j.evaluation.IEvaluation[] { e2 });
cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals); cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals);
assertEquals(150, e.getNumRowCounter()); assertEquals(150, e.getNumRowCounter());
assertEquals(150, e2.getExampleCountPerColumn().getInt(0)); assertEquals(150, e2.getExampleCountPerColumn().getInt(0));
} }
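
For multi-output graphs the evaluations are keyed by output index, so different metric types can be attached to different outputs. The pattern from the test in isolation (assumes the same imports; mdsIter stands in for any MultiDataSetIterator):

    Map<Integer, IEvaluation[]> evals = new HashMap<>();
    evals.put(0, new IEvaluation[] { new Evaluation() });           // classification metrics on output 0
    evals.put(1, new IEvaluation[] { new RegressionEvaluation() }); // regression metrics on output 1
    cg.evaluate(mdsIter, evals);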
@Test @Test
public void testMultiOutputEvalCG(){ @DisplayName("Test Multi Output Eval CG")
//Simple sanity check on evaluation void testMultiOutputEvalCG() {
// Simple sanity check on evaluation
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").layer("0", new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build(), "in").layer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0").layer("2", new LSTM.Builder().nIn(10).nOut(10).build(), "0").layer("out1", new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1").layer("out2", new RnnOutputLayer.Builder().nIn(10).nOut(20).activation(Activation.SOFTMAX).build(), "2").setOutputs("out1", "out2").build();
.graphBuilder()
.addInputs("in")
.layer("0", new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build(), "in")
.layer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0")
.layer("2", new LSTM.Builder().nIn(10).nOut(10).build(), "0")
.layer("out1", new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1")
.layer("out2", new RnnOutputLayer.Builder().nIn(10).nOut(20).activation(Activation.SOFTMAX).build(), "2")
.setOutputs("out1", "out2")
.build();
ComputationGraph cg = new ComputationGraph(conf); ComputationGraph cg = new ComputationGraph(conf);
cg.init(); cg.init();
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { Nd4j.create(10, 1, 10) }, new INDArray[] { Nd4j.create(10, 10, 10), Nd4j.create(10, 20, 10) });
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet( Map<Integer, org.nd4j.evaluation.IEvaluation[]> m = new HashMap<>();
new INDArray[]{Nd4j.create(10, 1, 10)}, m.put(0, new org.nd4j.evaluation.IEvaluation[] { new org.nd4j.evaluation.classification.Evaluation() });
new INDArray[]{Nd4j.create(10, 10, 10), Nd4j.create(10, 20, 10)}); m.put(1, new org.nd4j.evaluation.IEvaluation[] { new org.nd4j.evaluation.classification.Evaluation() });
Map<Integer,org.nd4j.evaluation.IEvaluation[]> m = new HashMap<>();
m.put(0, new org.nd4j.evaluation.IEvaluation[]{new org.nd4j.evaluation.classification.Evaluation()});
m.put(1, new org.nd4j.evaluation.IEvaluation[]{new org.nd4j.evaluation.classification.Evaluation()});
cg.evaluate(new SingletonMultiDataSetIterator(mds), m); cg.evaluate(new SingletonMultiDataSetIterator(mds), m);
} }
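    // Note: JUnit 5 no longer requires test classes or @Test methods to be public, which is why the
    // migrated methods here drop the public modifier; @DisplayName supplies the human-readable name
    // that test reports show in place of the raw method name.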

    @Test
    @DisplayName("Test Invalid Evaluation")
    void testInvalidEvaluation() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(10).build())
                .layer(new OutputLayer.Builder().nIn(10).nOut(3).lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.RELU).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
        try {
            net.evaluate(iter);
            fail("Expected exception");
        } catch (IllegalStateException e) {
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("Evaluation"));
        }
        try {
            net.evaluateROC(iter, 0);
            fail("Expected exception");
        } catch (IllegalStateException e) {
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
        }
        try {
            net.evaluateROCMultiClass(iter, 0);
            fail("Expected exception");
        } catch (IllegalStateException e) {
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
        }
        ComputationGraph cg = net.toComputationGraph();
        try {
            cg.evaluate(iter);
            fail("Expected exception");
        } catch (IllegalStateException e) {
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("Evaluation"));
        }
        try {
            cg.evaluateROC(iter, 0);
            fail("Expected exception");
        } catch (IllegalStateException e) {
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
        }
        try {
            cg.evaluateROCMultiClass(iter, 0);
            fail("Expected exception");
        } catch (IllegalStateException e) {
            assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
        }
        // Disable validation, and check same thing:
        net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false);
        net.evaluate(iter);
        net.evaluateROCMultiClass(iter, 0);
        cg.getConfiguration().setValidateOutputLayerConfig(false);
        cg.evaluate(iter);
        cg.evaluateROCMultiClass(iter, 0);
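        // The six try { ... fail(...) } catch (IllegalStateException e) blocks above are the JUnit 4
        // idiom; Jupiter's assertThrows expresses the same check more compactly. A minimal sketch,
        // assuming a static import of org.junit.jupiter.api.Assertions.assertThrows (illustrative
        // only, not part of this commit):
        //     IllegalStateException ex = assertThrows(IllegalStateException.class, () -> net.evaluate(iter));
        //     assertTrue(ex.getMessage().contains("Classifier") && ex.getMessage().contains("Evaluation"));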

@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.eval;

import org.deeplearning4j.BaseDL4JTest;
@@ -28,48 +27,53 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;

@DisplayName("Roc Test")
class ROCTest extends BaseDL4JTest {

    private static Map<Double, Double> expTPR;
    private static Map<Double, Double> expFPR;

    static {
        expTPR = new HashMap<>();
        double totalPositives = 5.0;
        // All 10 predicted as class 1, of which 5 of 5 are correct
        expTPR.put(0 / 10.0, 5.0 / totalPositives);
        expTPR.put(1 / 10.0, 5.0 / totalPositives);
        expTPR.put(2 / 10.0, 5.0 / totalPositives);
        expTPR.put(3 / 10.0, 5.0 / totalPositives);
        expTPR.put(4 / 10.0, 5.0 / totalPositives);
        expTPR.put(5 / 10.0, 5.0 / totalPositives);
        // Threshold: 0.4 -> last 4 predicted; last 5 actual
        expTPR.put(6 / 10.0, 4.0 / totalPositives);
        expTPR.put(7 / 10.0, 3.0 / totalPositives);
        expTPR.put(8 / 10.0, 2.0 / totalPositives);
        expTPR.put(9 / 10.0, 1.0 / totalPositives);
        expTPR.put(10 / 10.0, 0.0 / totalPositives);
        expFPR = new HashMap<>();
        double totalNegatives = 5.0;
        // All 10 predicted as class 1, but all 5 true negatives are predicted positive
        expFPR.put(0 / 10.0, 5.0 / totalNegatives);
        // 1 true negative is predicted as negative; 4 false positives
        expFPR.put(1 / 10.0, 4.0 / totalNegatives);
        // 2 true negatives are predicted as negative; 3 false positives
        expFPR.put(2 / 10.0, 3.0 / totalNegatives);
        expFPR.put(3 / 10.0, 2.0 / totalNegatives);
        expFPR.put(4 / 10.0, 1.0 / totalNegatives);
        expFPR.put(5 / 10.0, 0.0 / totalNegatives);
@@ -81,56 +85,41 @@ public class ROCTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Roc Eval Sanity Check")
    void RocEvalSanityCheck() {
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
        Nd4j.getRandom().setSeed(12345);
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).seed(12345).list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                .layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        NormalizerStandardize ns = new NormalizerStandardize();
        DataSet ds = iter.next();
        ns.fit(ds);
        ns.transform(ds);
        iter.setPreProcessor(ns);
        for (int i = 0; i < 10; i++) {
            net.fit(ds);
        }
        for (int steps : new int[] { 32, 0 }) {
            // Steps = 0: exact
            System.out.println("steps: " + steps);
            iter.reset();
            ds = iter.next();
            INDArray f = ds.getFeatures();
            INDArray l = ds.getLabels();
            INDArray out = net.output(f);
            // System.out.println(f);
            // System.out.println(out);
            ROCMultiClass manual = new ROCMultiClass(steps);
            manual.eval(l, out);
            iter.reset();
            ROCMultiClass roc = net.evaluateROCMultiClass(iter, steps);
            for (int i = 0; i < 3; i++) {
                double rocExp = manual.calculateAUC(i);
                double rocAct = roc.calculateAUC(i);
                assertEquals(rocExp, rocAct, 1e-6);
                RocCurve rc = roc.getRocCurve(i);
                RocCurve rm = manual.getRocCurve(i);
                assertEquals(rc, rm);
            }
        }
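        // How the expTPR/expFPR tables in the static block arise (a sketch, not part of this commit):
        // with prediction scores s_i and threshold t, an example is predicted positive when s_i > t, so
        //     TPR(t) = TP(t) / totalPositives;   // fraction of actual positives still predicted positive
        //     FPR(t) = FP(t) / totalNegatives;   // fraction of actual negatives wrongly predicted positive
        // e.g. expTPR.put(7 / 10.0, 3.0 / totalPositives) encodes that at threshold 0.7 only 3 of the
        // 5 actual positives score above the threshold.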

@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.eval;

import org.deeplearning4j.BaseDL4JTest;
@@ -29,59 +28,43 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Collections;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@DisplayName("Regression Eval Test")
class RegressionEvalTest extends BaseDL4JTest {

    @Test
    @DisplayName("Test Regression Eval Methods")
    void testRegressionEvalMethods() {
        // Basic sanity check
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).list()
                .layer(0, new OutputLayer.Builder().activation(Activation.TANH)
                        .lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(5).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        INDArray f = Nd4j.zeros(4, 10);
        INDArray l = Nd4j.ones(4, 5);
        DataSet ds = new DataSet(f, l);
        DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds));
        org.nd4j.evaluation.regression.RegressionEvaluation re = net.evaluateRegression(iter);
        for (int i = 0; i < 5; i++) {
            assertEquals(1.0, re.meanSquaredError(i), 1e-6);
            assertEquals(1.0, re.meanAbsoluteError(i), 1e-6);
        }
        ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).graphBuilder()
                .addInputs("in")
                .addLayer("0", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                        .activation(Activation.TANH).nIn(10).nOut(5).build(), "in")
                .setOutputs("0").build();
        ComputationGraph cg = new ComputationGraph(graphConf);
        cg.init();
        RegressionEvaluation re2 = cg.evaluateRegression(iter);
        for (int i = 0; i < 5; i++) {
            assertEquals(1.0, re2.meanSquaredError(i), 1e-6);
            assertEquals(1.0, re2.meanAbsoluteError(i), 1e-6);
@@ -89,25 +72,16 @@ public class RegressionEvalTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Test Regression Eval Per Output Masking")
    void testRegressionEvalPerOutputMasking() {
        INDArray l = Nd4j.create(new double[][] { { 1, 2, 3 }, { 10, 20, 30 }, { -5, -10, -20 } });
        INDArray predictions = Nd4j.zeros(l.shape());
        INDArray mask = Nd4j.create(new double[][] { { 0, 1, 1 }, { 1, 1, 0 }, { 0, 1, 0 } });
        RegressionEvaluation re = new RegressionEvaluation();
        re.eval(l, predictions, mask);
        double[] mse = new double[] { (10 * 10) / 1.0, (2 * 2 + 20 * 20 + 10 * 10) / 3, (3 * 3) / 1.0 };
        double[] mae = new double[] { 10.0, (2 + 20 + 10) / 3.0, 3.0 };
        double[] rmse = new double[] { 10.0, Math.sqrt((2 * 2 + 20 * 20 + 10 * 10) / 3.0), 3.0 };
        for (int i = 0; i < 3; i++) {
            assertEquals(mse[i], re.meanSquaredError(i), 1e-6);
            assertEquals(mae[i], re.meanAbsoluteError(i), 1e-6);
@@ -116,24 +90,19 @@ public class RegressionEvalTest extends BaseDL4JTest {
    }
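    // Worked check of the masked expectations above: the mask keeps only row 1 of column 0
    // (label 10, prediction 0), so MSE = 10^2 / 1 = 100, MAE = 10, RMSE = 10; column 1 keeps all
    // three rows (absolute errors 2, 20, 10), so MSE = (4 + 400 + 100) / 3 = 168, MAE = 32 / 3,
    // RMSE = sqrt(168); column 2 keeps only row 0 (label 3), giving MSE = 9, MAE = RMSE = 3.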

    @Test
    @DisplayName("Test Regression Eval Time Series Split")
    void testRegressionEvalTimeSeriesSplit() {
        INDArray out1 = Nd4j.rand(new int[] { 3, 5, 20 });
        INDArray outSub1 = out1.get(all(), all(), interval(0, 10));
        INDArray outSub2 = out1.get(all(), all(), interval(10, 20));
        INDArray label1 = Nd4j.rand(new int[] { 3, 5, 20 });
        INDArray labelSub1 = label1.get(all(), all(), interval(0, 10));
        INDArray labelSub2 = label1.get(all(), all(), interval(10, 20));
        RegressionEvaluation e1 = new RegressionEvaluation();
        RegressionEvaluation e2 = new RegressionEvaluation();
        e1.eval(label1, out1);
        e2.eval(labelSub1, outSub1);
        e2.eval(labelSub2, outSub2);
        assertEquals(e1, e2);
    }
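    // The split test works because RegressionEvaluation accumulates its statistics across eval(...)
    // calls: evaluating the two halves of the time series separately must therefore produce an object
    // equal to a single evaluation over the full series.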
}

@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.gradientcheck;

import org.deeplearning4j.BaseDL4JTest;
@@ -32,9 +31,9 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Disabled;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -42,13 +41,15 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@Disabled
@DisplayName("Attention Layer Test")
class AttentionLayerTest extends BaseDL4JTest {

    @Rule
    public ExpectedException exceptionRule = ExpectedException.none();
@@ -58,19 +59,18 @@ public class AttentionLayerTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Test Self Attention Layer")
    void testSelfAttentionLayer() {
        int nIn = 3;
        int nOut = 2;
        int tsLength = 4;
        int layerSize = 4;
        for (int mb : new int[] { 1, 3 }) {
            for (boolean inputMask : new boolean[] { false, true }) {
                for (boolean projectInput : new boolean[] { false, true }) {
                    INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
                    INDArray labels = TestUtils.randomOneHot(mb, nOut);
                    String maskType = (inputMask ? "inputMask" : "none");
                    INDArray inMask = null;
                    if (inputMask) {
                        inMask = Nd4j.ones(mb, tsLength);
@@ -84,54 +84,32 @@ public class AttentionLayerTest extends BaseDL4JTest {
                            }
                        }
                    }
                    String name = "testSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
                    System.out.println("Starting test: " + name);
                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                            .dataType(DataType.DOUBLE)
                            .activation(Activation.TANH)
                            .updater(new NoOp())
                            .weightInit(WeightInit.XAVIER)
                            .list()
                            .layer(new LSTM.Builder().nOut(layerSize).build())
                            .layer(projectInput
                                    ? new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build()
                                    : new SelfAttentionLayer.Builder().nHeads(1).projectInput(false).build())
                            .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
                            .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                            .setInputType(InputType.recurrent(nIn))
                            .build();
                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
                    net.init();
                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
                            .labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
                    assertTrue(gradOK, name);
                }
            }
        }
    }
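    // Note the reversed assertion arguments above: JUnit 4's assertTrue(String message, boolean condition)
    // becomes assertTrue(boolean condition, String message) in JUnit 5, i.e.
    //     assertTrue(name, gradOK);   // org.junit.Assert (JUnit 4)
    //     assertTrue(gradOK, name);   // org.junit.jupiter.api.Assertions (JUnit 5)
    // The failure message moves to the last parameter throughout the Assertions API.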

    @Test
    @DisplayName("Test Learned Self Attention Layer")
    void testLearnedSelfAttentionLayer() {
        int nIn = 3;
        int nOut = 2;
        int tsLength = 4;
        int layerSize = 4;
        int numQueries = 3;
        for (boolean inputMask : new boolean[] { false, true }) {
            for (int mb : new int[] { 3, 1 }) {
                for (boolean projectInput : new boolean[] { false, true }) {
                    INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
                    INDArray labels = TestUtils.randomOneHot(mb, nOut);
                    String maskType = (inputMask ? "inputMask" : "none");
                    INDArray inMask = null;
                    if (inputMask) {
                        inMask = Nd4j.ones(mb, tsLength);
@@ -145,75 +123,36 @@ public class AttentionLayerTest extends BaseDL4JTest {
                            }
                        }
                    }
                    String name = "testLearnedSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
                    System.out.println("Starting test: " + name);
                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                            .dataType(DataType.DOUBLE)
                            .activation(Activation.TANH)
                            .updater(new NoOp())
                            .weightInit(WeightInit.XAVIER)
                            .list()
                            .layer(new LSTM.Builder().nOut(layerSize).build())
                            .layer(projectInput
                                    ? new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
                                    : new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build())
                            .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
                            .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                            .setInputType(InputType.recurrent(nIn))
                            .build();
                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
                    net.init();
                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
                            .labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
                    assertTrue(gradOK, name);
                }
            }
        }
    }

    @Test
    @DisplayName("Test Learned Self Attention Layer _ different Mini Batch Sizes")
    void testLearnedSelfAttentionLayer_differentMiniBatchSizes() {
        int nIn = 3;
        int nOut = 2;
        int tsLength = 4;
        int layerSize = 4;
        int numQueries = 3;
        Random r = new Random(12345);
        for (boolean inputMask : new boolean[] { false, true }) {
            for (boolean projectInput : new boolean[] { false, true }) {
                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .dataType(DataType.DOUBLE)
                        .activation(Activation.TANH)
                        .updater(new NoOp())
                        .weightInit(WeightInit.XAVIER)
                        .list()
                        .layer(new LSTM.Builder().nOut(layerSize).build())
                        .layer(projectInput
                                ? new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
                                : new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build())
                        .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
                        .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
                                .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                        .setInputType(InputType.recurrent(nIn))
                        .build();
                MultiLayerNetwork net = new MultiLayerNetwork(conf);
                net.init();
                for (int mb : new int[] { 3, 1 }) {
                    INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
                    INDArray labels = TestUtils.randomOneHot(mb, nOut);
                    String maskType = (inputMask ? "inputMask" : "none");
                    INDArray inMask = null;
                    if (inputMask) {
                        inMask = Nd4j.ones(DataType.INT, mb, tsLength);
@@ -227,68 +166,47 @@ public class AttentionLayerTest extends BaseDL4JTest {
                            }
                        }
                    }
                    String name = "testLearnedSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
                    System.out.println("Starting test: " + name);
                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
                            .labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
                    assertTrue(gradOK, name);
                }
            }
        }
    }

    @Test
    @DisplayName("Test Recurrent Attention Layer _ differing Time Steps")
    void testRecurrentAttentionLayer_differingTimeSteps() {
        int nIn = 9;
        int nOut = 5;
        int layerSize = 8;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .dataType(DataType.DOUBLE)
                .activation(Activation.IDENTITY)
                .updater(new NoOp())
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new LSTM.Builder().nOut(layerSize).build())
                .layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
                .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build())
                .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .setInputType(InputType.recurrent(nIn))
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        final INDArray initialInput = Nd4j.rand(new int[] { 8, nIn, 7 });
        final INDArray goodNextInput = Nd4j.rand(new int[] { 8, nIn, 7 });
        final INDArray badNextInput = Nd4j.rand(new int[] { 8, nIn, 12 });
        final INDArray labels = Nd4j.rand(new int[] { 8, nOut });
        net.fit(initialInput, labels);
        net.fit(goodNextInput, labels);
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("This layer only supports fixed length mini-batches. Expected 7 time steps but got 12.");
        net.fit(badNextInput, labels);
    }
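    // The @Rule ExpectedException used above is a JUnit 4 construct that the Jupiter engine does not
    // run; an equivalent JUnit 5 formulation (a sketch, not part of this commit) would be:
    //     IllegalArgumentException ex = assertThrows(IllegalArgumentException.class,
    //             () -> net.fit(badNextInput, labels));
    //     assertTrue(ex.getMessage().contains("Expected 7 time steps but got 12."));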

    @Test
    @DisplayName("Test Recurrent Attention Layer")
    void testRecurrentAttentionLayer() {
        int nIn = 4;
        int nOut = 2;
        int tsLength = 3;
        int layerSize = 3;
        for (int mb : new int[] { 3, 1 }) {
            for (boolean inputMask : new boolean[] { true, false }) {
                INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
                INDArray labels = TestUtils.randomOneHot(mb, nOut);
                String maskType = (inputMask ? "inputMask" : "none");
                INDArray inMask = null;
                if (inputMask) {
                    inMask = Nd4j.ones(mb, tsLength);
@@ -302,51 +220,32 @@ public class AttentionLayerTest extends BaseDL4JTest {
                        }
                    }
                }
                String name = "testRecurrentAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType;
                System.out.println("Starting test: " + name);
                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .dataType(DataType.DOUBLE)
                        .activation(Activation.IDENTITY)
                        .updater(new NoOp())
                        .weightInit(WeightInit.XAVIER)
                        .list()
                        .layer(new LSTM.Builder().nOut(layerSize).build())
                        .layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
                        .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build())
                        .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
                                .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                        .setInputType(InputType.recurrent(nIn))
                        .build();
                MultiLayerNetwork net = new MultiLayerNetwork(conf);
                net.init();
                // System.out.println("Original");
                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
                        .labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
                assertTrue(gradOK, name);
            }
        }
    }

    @Test
    @DisplayName("Test Attention Vertex")
    void testAttentionVertex() {
        int nIn = 3;
        int nOut = 2;
        int tsLength = 3;
        int layerSize = 3;
        Random r = new Random(12345);
        for (boolean inputMask : new boolean[] { false, true }) {
            for (int mb : new int[] { 3, 1 }) {
                for (boolean projectInput : new boolean[] { false, true }) {
                    INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
                    INDArray labels = TestUtils.randomOneHot(mb, nOut);
                    String maskType = (inputMask ? "inputMask" : "none");
                    INDArray inMask = null;
                    if (inputMask) {
                        inMask = Nd4j.ones(mb, tsLength);
@@ -360,57 +259,32 @@ public class AttentionLayerTest extends BaseDL4JTest {
                            }
                        }
                    }
                    String name = "testAttentionVertex() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
                    System.out.println("Starting test: " + name);
                    ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder()
                            .dataType(DataType.DOUBLE)
                            .activation(Activation.TANH)
                            .updater(new NoOp())
                            .weightInit(WeightInit.XAVIER)
                            .graphBuilder()
                            .addInputs("input")
                            .addLayer("rnnKeys", new SimpleRnn.Builder().nOut(layerSize).build(), "input")
                            .addLayer("rnnQueries", new SimpleRnn.Builder().nOut(layerSize).build(), "input")
                            .addLayer("rnnValues", new SimpleRnn.Builder().nOut(layerSize).build(), "input")
                            .addVertex("attention", projectInput
                                    ? new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build()
                                    : new AttentionVertex.Builder().nOut(3).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(),
                                    "rnnQueries", "rnnKeys", "rnnValues")
                            .addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention")
                            .addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling")
                            .setOutputs("output")
                            .setInputTypes(InputType.recurrent(nIn))
                            .build();
                    ComputationGraph net = new ComputationGraph(graph);
                    net.init();
                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { in })
                            .labels(new INDArray[] { labels }).inputMask(inMask != null ? new INDArray[] { inMask } : null).subset(true).maxPerParam(100));
                    assertTrue(gradOK, name);
                }
            }
        }
    }

    @Test
    @DisplayName("Test Attention Vertex Same Input")
    void testAttentionVertexSameInput() {
        int nIn = 3;
        int nOut = 2;
        int tsLength = 4;
        int layerSize = 4;
        Random r = new Random(12345);
        for (boolean inputMask : new boolean[] { false, true }) {
            for (int mb : new int[] { 3, 1 }) {
                for (boolean projectInput : new boolean[] { false, true }) {
                    INDArray in = Nd4j.rand(new int[] { mb, nIn, tsLength });
                    INDArray labels = TestUtils.randomOneHot(mb, nOut);
                    String maskType = (inputMask ? "inputMask" : "none");
                    INDArray inMask = null;
                    if (inputMask) {
                        inMask = Nd4j.ones(mb, tsLength);
@@ -424,35 +298,13 @@ public class AttentionLayerTest extends BaseDL4JTest {
                            }
                        }
                    }
                    String name = "testAttentionVertex() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
                    System.out.println("Starting test: " + name);
                    ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder()
                            .dataType(DataType.DOUBLE)
                            .activation(Activation.TANH)
                            .updater(new NoOp())
                            .weightInit(WeightInit.XAVIER)
                            .graphBuilder()
                            .addInputs("input")
                            .addLayer("rnn", new SimpleRnn.Builder().activation(Activation.TANH).nOut(layerSize).build(), "input")
                            .addVertex("attention", projectInput
                                    ? new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build()
                                    : new AttentionVertex.Builder().nOut(4).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(),
                                    "rnn", "rnn", "rnn")
                            .addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention")
                            .addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling")
                            .setOutputs("output")
                            .setInputTypes(InputType.recurrent(nIn))
                            .build();
                    ComputationGraph net = new ComputationGraph(graph);
                    net.init();
                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { in })
                            .labels(new INDArray[] { labels }).inputMask(inMask != null ? new INDArray[] { inMask } : null));
                    assertTrue(gradOK, name);
                }
            }
        }
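        // Aside: the hand-rolled loops over mb/inputMask/projectInput in these tests could also be
        // expressed with Jupiter's parameterized tests (requires the junit-jupiter-params artifact);
        // a sketch, not part of this commit:
        //     @ParameterizedTest
        //     @CsvSource({ "1,false,false", "1,true,true", "3,false,true", "3,true,false" })
        //     void selfAttentionGradients(int mb, boolean inputMask, boolean projectInput) { ... }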

@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.gradientcheck;

import org.deeplearning4j.BaseDL4JTest;
@@ -34,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -48,18 +47,18 @@ import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@DisplayName("Bn Gradient Check Test")
class BNGradientCheckTest extends BaseDL4JTest {

    static {
        Nd4j.setDataType(DataType.DOUBLE);
@@ -71,7 +70,8 @@ public class BNGradientCheckTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Test Gradient 2 d Simple")
    void testGradient2dSimple() {
        DataNormalization scaler = new NormalizerMinMaxScaler();
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
        scaler.fit(iter);
@@ -79,181 +79,117 @@
        DataSet ds = iter.next();
        INDArray input = ds.getFeatures();
        INDArray labels = ds.getLabels();
        for (boolean useLogStd : new boolean[] { true, false }) {
            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp())
                    .dataType(DataType.DOUBLE)
                    .seed(12345L)
                    .dist(new NormalDistribution(0, 1)).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build())
                    .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).nOut(3).build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
                    .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nIn(3).nOut(3).build());
            MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
            mln.init();
            // for (int j = 0; j < mln.getnLayers(); j++)
            //     System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
            // Mean and variance vars are not gradient checkable; the mean/variance "gradient" is used to
            // implement the running mean/variance calc, i.e., runningMean = decay * runningMean + (1-decay) * batchMean.
            // However, the numerical gradient will be 0 as the forward pass doesn't depend on this "parameter".
            Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
                    .labels(labels).excludeParams(excludeParams));
            assertTrue(gradOK);
            TestUtils.testModelSerialization(mln);
        }
    }
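    // The excluded "gradient" above is really the running-statistics update, an exponential moving
    // average applied as a training side effect (a sketch of the update the comment describes):
    //     runningMean = decay * runningMean + (1 - decay) * batchMean;
    // Training-time forward passes normalize with batchMean, not runningMean, so the numerical
    // gradient w.r.t. the running statistics is 0 and they must be excluded from the check.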
@Test @Test
public void testGradientCnnSimple() { @DisplayName("Test Gradient Cnn Simple")
void testGradientCnnSimple() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int minibatch = 10; int minibatch = 10;
int depth = 1; int depth = 1;
int hw = 4; int hw = 4;
int nOut = 4; int nOut = 4;
INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}); INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw });
INDArray labels = Nd4j.zeros(minibatch, nOut); INDArray labels = Nd4j.zeros(minibatch, nOut);
Random r = new Random(12345); Random r = new Random(12345);
for (int i = 0; i < minibatch; i++) { for (int i = 0; i < minibatch; i++) {
labels.putScalar(i, r.nextInt(nOut), 1.0); labels.putScalar(i, r.nextInt(nOut), 1.0);
} }
for (boolean useLogStd : new boolean[] { true, false }) {
for (boolean useLogStd : new boolean[]{true, false}) { MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).seed(12345L).dist(new NormalDistribution(0, 2)).list().layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()).layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(nOut).build()).setInputType(InputType.convolutional(hw, hw, depth));
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).seed(12345L)
.dist(new NormalDistribution(0, 2)).list()
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
.activation(Activation.IDENTITY).build())
.layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build())
.layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(hw, hw, depth));
MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
mln.init(); mln.init();
// for (int j = 0; j < mln.getnLayers(); j++)
// for (int j = 0; j < mln.getnLayers(); j++) // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); // Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
// i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc // However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams));
.labels(labels).excludeParams(excludeParams));
assertTrue(gradOK); assertTrue(gradOK);
TestUtils.testModelSerialization(mln); TestUtils.testModelSerialization(mln);
} }
} }
    @Test
    @DisplayName("Test Gradient BN With CNN and Subsampling")
    void testGradientBNWithCNNandSubsampling() {
        // Parameterized test, testing combinations of:
        // (a) activation function
        // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
        // (c) Loss function (with specified output activations)
        // (d) l1 and l2 values
        Activation[] activFns = { Activation.SIGMOID, Activation.TANH, Activation.IDENTITY };
        // If true: run some backprop steps first
        boolean[] characteristic = { true };
        LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE };
        // i.e., lossFunctions[i] used with outputActivations[i] here
        Activation[] outputActivations = { Activation.SOFTMAX, Activation.TANH };
        double[] l2vals = { 0.0, 0.1, 0.1 };
        // i.e., use l2vals[j] with l1vals[j]
        double[] l1vals = { 0.0, 0.0, 0.2 };
        Nd4j.getRandom().setSeed(12345);
        int minibatch = 4;
        int depth = 2;
        int hw = 5;
        int nOut = 2;
        INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw }).muli(5).subi(2.5);
        INDArray labels = TestUtils.randomOneHot(minibatch, nOut);
        DataSet ds = new DataSet(input, labels);
        Random rng = new Random(12345);
        for (boolean useLogStd : new boolean[] { true, false }) {
            for (Activation afn : activFns) {
                for (boolean doLearningFirst : characteristic) {
                    for (int i = 0; i < lossFunctions.length; i++) {
                        for (int j = 0; j < l2vals.length; j++) {
                            // Skip 2 of every 3 tests: from 24 cases to 8, still with decent coverage
                            if (rng.nextInt(3) != 0)
                                continue;
                            LossFunctions.LossFunction lf = lossFunctions[i];
                            Activation outputActivation = outputActivations[i];
                            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345)
                                    .dataType(DataType.DOUBLE)
                                    .l2(l2vals[j])
                                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                                    .updater(new NoOp())
                                    .dist(new UniformDistribution(-2, 2)).seed(12345L).list()
                                    .layer(0, new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3)
                                            .activation(afn).build())
                                    .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build())
                                    .layer(2, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                                            .kernelSize(2, 2).stride(1, 1).build())
                                    .layer(3, new BatchNormalization())
                                    .layer(4, new ActivationLayer.Builder().activation(afn).build())
                                    .layer(5, new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut).build())
                                    .setInputType(InputType.convolutional(hw, hw, depth));
                            MultiLayerConfiguration conf = builder.build();
                            MultiLayerNetwork mln = new MultiLayerNetwork(conf);
                            mln.init();
                            String name = new Object() {
                            }.getClass().getEnclosingMethod().getName();
                            // System.out.println("Num params: " + mln.numParams());
                            if (doLearningFirst) {
                                // Run a number of iterations of learning
                                mln.setInput(ds.getFeatures());
                                mln.setLabels(ds.getLabels());
                                mln.computeGradientAndScore();
                                double scoreBefore = mln.score();
                                for (int k = 0; k < 20; k++)
                                    mln.fit(ds);
                                mln.computeGradientAndScore();
                                double scoreAfter = mln.score();
                                // Can't test in 'characteristic mode of operation' if not learning
                                String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn
                                        + ", lossFn=" + lf + ", outputActivation=" + outputActivation
                                        + ", doLearningFirst= " + doLearningFirst
                                        + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")";
                                assertTrue(scoreAfter < 0.9 * scoreBefore, msg);
                            }
                            System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
                                    + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst
                                    + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
                            // for (int k = 0; k < mln.getnLayers(); k++)
                            // System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
                            // Mean and variance vars are not gradient checkable; the mean/variance "gradient" is used to
                            // implement the running mean/variance calculation,
                            // i.e., runningMean = decay * runningMean + (1 - decay) * batchMean.
                            // However, the numerical gradient will be 0, as the forward pass doesn't depend on this "parameter"
                            Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
                            // Most params are in the output layer; only those should be skipped with this threshold
                            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
                                    .labels(labels).excludeParams(excludeParams).subset(true).maxPerParam(25));
                            assertTrue(gradOK);
                            TestUtils.testModelSerialization(mln);
                        }
@ -263,101 +199,68 @@ public class BNGradientCheckTest extends BaseDL4JTest {
        }
    }
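    // Migration note on the assertions in this class: JUnit 4's org.junit.Assert puts the message first,
    // assertTrue(msg, condition), while JUnit 5's org.junit.jupiter.api.Assertions puts it last:
    //
    //     assertTrue(scoreAfter < 0.9 * scoreBefore, msg);  // JUnit 5: condition first, message last
    //
    // The same reversal applies to assertEquals: JUnit 4's assertEquals(msg, expected, actual) becomes
    // assertEquals(expected, actual, msg) in JUnit 5.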
    @Test
    @DisplayName("Test Gradient Dense")
    void testGradientDense() {
        // Parameterized test, testing combinations of:
        // (a) activation function
        // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
        // (c) Loss function (with specified output activations)
        // (d) l1 and l2 values
        Activation[] activFns = { Activation.TANH, Activation.IDENTITY };
        // If true: run some backprop steps first
        boolean[] characteristic = { true };
        LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE };
        // i.e., lossFunctions[i] used with outputActivations[i] here
        Activation[] outputActivations = { Activation.SOFTMAX, Activation.TANH };
        double[] l2vals = { 0.0, 0.1 };
        // i.e., use l2vals[j] with l1vals[j]
        double[] l1vals = { 0.0, 0.2 };
        Nd4j.getRandom().setSeed(12345);
        int minibatch = 10;
        int nIn = 5;
        int nOut = 3;
        INDArray input = Nd4j.rand(new int[] { minibatch, nIn });
        INDArray labels = Nd4j.zeros(minibatch, nOut);
        Random r = new Random(12345);
        for (int i = 0; i < minibatch; i++) {
            labels.putScalar(i, r.nextInt(nOut), 1.0);
        }
        DataSet ds = new DataSet(input, labels);
        for (boolean useLogStd : new boolean[] { true, false }) {
            for (Activation afn : activFns) {
                for (boolean doLearningFirst : characteristic) {
                    for (int i = 0; i < lossFunctions.length; i++) {
                        for (int j = 0; j < l2vals.length; j++) {
                            LossFunctions.LossFunction lf = lossFunctions[i];
                            Activation outputActivation = outputActivations[i];
                            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                                    .dataType(DataType.DOUBLE)
                                    .l2(l2vals[j])
                                    .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
                                    .updater(new NoOp())
                                    .dist(new UniformDistribution(-2, 2)).seed(12345L).list()
                                    .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(4).activation(afn).build())
                                    .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build())
                                    .layer(2, new DenseLayer.Builder().nIn(4).nOut(4).build())
                                    .layer(3, new BatchNormalization.Builder().useLogStd(useLogStd).build())
                                    .layer(4, new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut).build());
                            MultiLayerConfiguration conf = builder.build();
                            MultiLayerNetwork mln = new MultiLayerNetwork(conf);
                            mln.init();
                            String name = new Object() {
                            }.getClass().getEnclosingMethod().getName();
                            if (doLearningFirst) {
                                // Run a number of iterations of learning
                                mln.setInput(ds.getFeatures());
                                mln.setLabels(ds.getLabels());
                                mln.computeGradientAndScore();
                                double scoreBefore = mln.score();
                                for (int k = 0; k < 10; k++)
                                    mln.fit(ds);
                                mln.computeGradientAndScore();
                                double scoreAfter = mln.score();
                                // Can't test in 'characteristic mode of operation' if not learning
                                String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn
                                        + ", lossFn=" + lf + ", outputActivation=" + outputActivation
                                        + ", doLearningFirst= " + doLearningFirst
                                        + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")";
                                assertTrue(scoreAfter < 0.8 * scoreBefore, msg);
                            }
                            System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
                                    + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst
                                    + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
                            // for (int k = 0; k < mln.getnLayers(); k++)
                            // System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
                            // Mean and variance vars are not gradient checkable; the mean/variance "gradient" is used to
                            // implement the running mean/variance calculation,
                            // i.e., runningMean = decay * runningMean + (1 - decay) * batchMean.
                            // However, the numerical gradient will be 0, as the forward pass doesn't depend on this "parameter"
                            Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
                            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
                                    .labels(labels).excludeParams(excludeParams));
                            assertTrue(gradOK);
                            TestUtils.testModelSerialization(mln);
                        }
@ -368,7 +271,8 @@ public class BNGradientCheckTest extends BaseDL4JTest {
    }
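    // The "new Object() { }.getClass().getEnclosingMethod().getName()" idiom used throughout these tests
    // fetches the name of the method that declares the anonymous class, so failure messages carry the test
    // name without hard-coding it. For example, inside testGradientDense() the expression evaluates to
    // "testGradientDense".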
    @Test
    @DisplayName("Test Gradient 2d Fixed Gamma Beta")
    void testGradient2dFixedGammaBeta() {
        DataNormalization scaler = new NormalizerMinMaxScaler();
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
        scaler.fit(iter);
@ -376,219 +280,142 @@ public class BNGradientCheckTest extends BaseDL4JTest {
        DataSet ds = iter.next();
        INDArray input = ds.getFeatures();
        INDArray labels = ds.getLabels();
        for (boolean useLogStd : new boolean[] { true, false }) {
            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp())
                    .dataType(DataType.DOUBLE)
                    .seed(12345L)
                    .dist(new NormalDistribution(0, 1)).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build())
                    .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true)
                            .gamma(2.0).beta(0.5).nOut(3).build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
                    .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nIn(3).nOut(3).build());
            MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
            mln.init();
            // for (int j = 0; j < mln.getnLayers(); j++)
            // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
            // Mean and variance vars are not gradient checkable; the mean/variance "gradient" is used to
            // implement the running mean/variance calculation,
            // i.e., runningMean = decay * runningMean + (1 - decay) * batchMean.
            // However, the numerical gradient will be 0, as the forward pass doesn't depend on this "parameter"
            Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
                    .labels(labels).excludeParams(excludeParams));
            assertTrue(gradOK);
            TestUtils.testModelSerialization(mln);
        }
    }
    @Test
    @DisplayName("Test Gradient Cnn Fixed Gamma Beta")
    void testGradientCnnFixedGammaBeta() {
        Nd4j.getRandom().setSeed(12345);
        int minibatch = 10;
        int depth = 1;
        int hw = 4;
        int nOut = 4;
        INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw });
        INDArray labels = Nd4j.zeros(minibatch, nOut);
        Random r = new Random(12345);
        for (int i = 0; i < minibatch; i++) {
            labels.putScalar(i, r.nextInt(nOut), 1.0);
        }
        for (boolean useLogStd : new boolean[] { true, false }) {
            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp())
                    .dataType(DataType.DOUBLE)
                    .seed(12345L)
                    .dist(new NormalDistribution(0, 2)).list()
                    .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
                            .activation(Activation.IDENTITY).build())
                    .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true)
                            .gamma(2.0).beta(0.5).build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
                    .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nOut(nOut).build())
                    .setInputType(InputType.convolutional(hw, hw, depth));
            MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
            mln.init();
            // for (int j = 0; j < mln.getnLayers(); j++)
            // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
            // Mean and variance vars are not gradient checkable; the mean/variance "gradient" is used to
            // implement the running mean/variance calculation,
            // i.e., runningMean = decay * runningMean + (1 - decay) * batchMean.
            // However, the numerical gradient will be 0, as the forward pass doesn't depend on this "parameter"
            Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
                    .labels(labels).excludeParams(excludeParams));
            assertTrue(gradOK);
            TestUtils.testModelSerialization(mln);
        }
    }
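    // lockGammaBeta(true) above fixes the batchnorm affine parameters instead of learning them, so the
    // layer applies y = gamma * xHat + beta with the supplied constants. Worked sketch with this test's
    // values (xHat is an illustrative normalized activation, not taken from the test):
    //
    //     double gamma = 2.0, beta = 0.5;
    //     double xHat = -0.25;
    //     double y = gamma * xHat + beta;  // = 0.0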
    @Test
    @DisplayName("Test Batch Norm Comp Graph Simple")
    void testBatchNormCompGraphSimple() {
        int numClasses = 2;
        int height = 3;
        int width = 3;
        int channels = 1;
        long seed = 123;
        int minibatchSize = 3;
        for (boolean useLogStd : new boolean[] { true, false }) {
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).updater(new NoOp())
                    .dataType(DataType.DOUBLE)
                    .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
                    .setInputTypes(InputType.convolutional(height, width, channels))
                    .addLayer("bn", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "in")
                    .addLayer("out", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nOut(numClasses).build(), "bn")
                    .setOutputs("out").build();
            ComputationGraph net = new ComputationGraph(conf);
            net.init();
            Random r = new Random(12345);
            // Order: examples, channels, height, width
            INDArray input = Nd4j.rand(new int[] { minibatchSize, channels, height, width });
            INDArray labels = Nd4j.zeros(minibatchSize, numClasses);
            for (int i = 0; i < minibatchSize; i++) {
                labels.putScalar(new int[] { i, r.nextInt(numClasses) }, 1.0);
            }
            // Mean and variance vars are not gradient checkable; the mean/variance "gradient" is used to
            // implement the running mean/variance calculation,
            // i.e., runningMean = decay * runningMean + (1 - decay) * batchMean.
            // However, the numerical gradient will be 0, as the forward pass doesn't depend on this "parameter"
            Set<String> excludeParams = new HashSet<>(Arrays.asList("bn_mean", "bn_var"));
            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net)
                    .inputs(new INDArray[] { input })
                    .labels(new INDArray[] { labels }).excludeParams(excludeParams));
            assertTrue(gradOK);
            TestUtils.testModelSerialization(net);
        }
    }
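    // Migration note: @DisplayName (org.junit.jupiter.api.DisplayName) only changes how a test is rendered
    // in reports and IDEs; discovery and execution still key off the method name. The display names in
    // these classes were generated by splitting the camel-case method names; JUnit 5 can also derive such
    // names automatically via @DisplayNameGeneration with a DisplayNameGenerator implementation, rather
    // than annotating every method.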
    @Test
    @DisplayName("Test Gradient BN With CNN and Subsampling Comp Graph")
    void testGradientBNWithCNNandSubsamplingCompGraph() {
        // Parameterized test, testing combinations of:
        // (a) activation function
        // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
        // (c) Loss function (with specified output activations)
        // (d) l1 and l2 values
        Activation[] activFns = { Activation.TANH, Activation.IDENTITY };
        boolean doLearningFirst = true;
        LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD };
        // i.e., lossFunctions[i] used with outputActivations[i] here
        Activation[] outputActivations = { Activation.SOFTMAX };
        double[] l2vals = { 0.0, 0.1 };
        // i.e., use l2vals[j] with l1vals[j]
        double[] l1vals = { 0.0, 0.2 };
        Nd4j.getRandom().setSeed(12345);
        int minibatch = 10;
        int depth = 2;
        int hw = 5;
        int nOut = 3;
        INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw });
        INDArray labels = Nd4j.zeros(minibatch, nOut);
        Random r = new Random(12345);
        for (int i = 0; i < minibatch; i++) {
            labels.putScalar(i, r.nextInt(nOut), 1.0);
        }
        DataSet ds = new DataSet(input, labels);
        for (boolean useLogStd : new boolean[] { true, false }) {
            for (Activation afn : activFns) {
                for (int i = 0; i < lossFunctions.length; i++) {
                    for (int j = 0; j < l2vals.length; j++) {
                        LossFunctions.LossFunction lf = lossFunctions[i];
                        Activation outputActivation = outputActivations[i];
                        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                                .dataType(DataType.DOUBLE)
                                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                                .updater(new NoOp())
                                .dist(new UniformDistribution(-2, 2)).seed(12345L).graphBuilder()
                                .addInputs("in")
                                .addLayer("0", new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3)
                                        .activation(afn).build(), "in")
                                .addLayer("1", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "0")
                                .addLayer("2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                                        .kernelSize(2, 2).stride(1, 1).build(), "1")
                                .addLayer("3", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "2")
                                .addLayer("4", new ActivationLayer.Builder().activation(afn).build(), "3")
                                .addLayer("5", new OutputLayer.Builder(lf).activation(outputActivation)
                                        .nOut(nOut).build(), "4")
                                .setOutputs("5").setInputTypes(InputType.convolutional(hw, hw, depth))
                                .build();
                        ComputationGraph net = new ComputationGraph(conf);
                        net.init();
                        String name = new Object() {
                        }.getClass().getEnclosingMethod().getName();
                        if (doLearningFirst) {
                            // Run a number of iterations of learning
                            net.setInput(0, ds.getFeatures());
                            net.setLabels(ds.getLabels());
                            net.computeGradientAndScore();
                            double scoreBefore = net.score();
                            for (int k = 0; k < 20; k++)
                                net.fit(ds);
                            net.computeGradientAndScore();
                            double scoreAfter = net.score();
                            // Can't test in 'characteristic mode of operation' if not learning
                            String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn
                                    + ", lossFn=" + lf + ", outputActivation=" + outputActivation
                                    + ", doLearningFirst= " + doLearningFirst
                                    + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")";
                            assertTrue(scoreAfter < 0.9 * scoreBefore, msg);
                        }
                        System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
                                + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst
                                + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
                        // for (int k = 0; k < net.getNumLayers(); k++)
                        // System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams());
                        // Mean and variance vars are not gradient checkable; the mean/variance "gradient" is used to
                        // implement the running mean/variance calculation,
                        // i.e., runningMean = decay * runningMean + (1 - decay) * batchMean.
                        // However, the numerical gradient will be 0, as the forward pass doesn't depend on this "parameter"
                        Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
                        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net)
                                .inputs(new INDArray[] { input })
                                .labels(new INDArray[] { labels }).excludeParams(excludeParams));
                        assertTrue(gradOK);
                        TestUtils.testModelSerialization(net);
                    }
@ -596,5 +423,4 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                }
            }
        }
    }
}


@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.gradientcheck;

import lombok.extern.slf4j.Slf4j;
@ -35,7 +34,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.Convolution1DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -44,18 +43,24 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@Slf4j
@DisplayName("Cnn 1D Gradient Check Test")
class CNN1DGradientCheckTest extends BaseDL4JTest {

    private static final boolean PRINT_RESULTS = true;
    private static final boolean RETURN_ON_FIRST_FAILURE = false;
    private static final double DEFAULT_EPS = 1e-6;
    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;

    static {
@ -68,148 +73,91 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
    }
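    // The constants above parameterize a standard central-difference gradient check: the numerical gradient
    // is (f(x + eps) - f(x - eps)) / (2 * eps), accepted when its relative error against the backprop
    // gradient is below DEFAULT_MAX_REL_ERROR, or the absolute difference is below DEFAULT_MIN_ABS_ERROR.
    // A minimal sketch under those assumptions (illustrative only, not the GradientCheckUtil internals):
    private static boolean gradientsMatchSketch(java.util.function.DoubleUnaryOperator f, double x, double backpropGrad) {
        // central difference around x with the class's epsilon
        double numGrad = (f.applyAsDouble(x + DEFAULT_EPS) - f.applyAsDouble(x - DEFAULT_EPS)) / (2 * DEFAULT_EPS);
        double absDiff = Math.abs(numGrad - backpropGrad);
        double relError = absDiff / (Math.abs(numGrad) + Math.abs(backpropGrad));
        return relError < DEFAULT_MAX_REL_ERROR || absDiff < DEFAULT_MIN_ABS_ERROR;
    }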
    @Test
    @DisplayName("Test Cnn 1D With Locally Connected 1D")
    void testCnn1DWithLocallyConnected1D() {
        Nd4j.getRandom().setSeed(1337);
        int[] minibatchSizes = { 2, 3 };
        int length = 7;
        int convNIn = 2;
        int convNOut1 = 3;
        int convNOut2 = 4;
        int finalNOut = 4;
        int[] kernels = { 1 };
        int stride = 1;
        int padding = 0;
        Activation[] activations = { Activation.SIGMOID };
        for (Activation afn : activations) {
            for (int minibatchSize : minibatchSizes) {
                for (int kernel : kernels) {
                    INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length });
                    INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
                    for (int i = 0; i < minibatchSize; i++) {
                        for (int j = 0; j < length; j++) {
                            labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0);
                        }
                    }
                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                            .dataType(DataType.DOUBLE)
                            .updater(new NoOp())
                            .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list()
                            .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
                                    .stride(stride).padding(padding).nIn(convNIn).nOut(convNOut1)
                                    .rnnDataFormat(RNNFormat.NCW).build())
                            .layer(new LocallyConnected1D.Builder().activation(afn).kernelSize(kernel)
                                    .stride(stride).padding(padding).nIn(convNOut1).nOut(convNOut2)
                                    .hasBias(false).build())
                            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nOut(finalNOut).build())
                            .setInputType(InputType.recurrent(convNIn, length)).build();
                    String json = conf.toJson();
                    MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
                    assertEquals(conf, c2);
                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
                    net.init();
                    String msg = "Minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel;
                    if (PRINT_RESULTS) {
                        System.out.println(msg);
                        // for (int j = 0; j < net.getnLayers(); j++)
                        // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                    }
                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                            DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
                    assertTrue(gradOK, msg);
                    TestUtils.testModelSerialization(net);
                }
            }
        }
    }
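    // RNNFormat.NCW in these 1D tests means activations are laid out as [minibatch, channels, width], i.e.
    // [examples, inputSize, sequenceLength] -- matching the inputs the tests create. Illustrative shape
    // check (values mirror the test's variables, not new behavior):
    //
    //     INDArray in = Nd4j.rand(new int[] { 2, 2, 7 });  // NCW: 2 examples, 2 channels, length 7
    //     // in.size(0) == 2, in.size(1) == 2, in.size(2) == 7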
    @Test
    @DisplayName("Test Cnn 1D With Cropping 1D")
    void testCnn1DWithCropping1D() {
        Nd4j.getRandom().setSeed(1337);
        int[] minibatchSizes = { 1, 3 };
        int length = 7;
        int convNIn = 2;
        int convNOut1 = 3;
        int convNOut2 = 4;
        int finalNOut = 4;
        int[] kernels = { 1, 2, 4 };
        int stride = 1;
        int padding = 0;
        int cropping = 1;
        int croppedLength = length - 2 * cropping;
        Activation[] activations = { Activation.SIGMOID };
        SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX,
                SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM };
        for (Activation afn : activations) {
            for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
                for (int minibatchSize : minibatchSizes) {
                    for (int kernel : kernels) {
                        INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length });
                        INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, croppedLength);
                        for (int i = 0; i < minibatchSize; i++) {
                            for (int j = 0; j < croppedLength; j++) {
                                labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0);
                            }
                        }
                        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                                .dataType(DataType.DOUBLE)
                                .updater(new NoOp())
                                .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list()
                                .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
                                        .stride(stride).padding(padding).nOut(convNOut1).build())
                                .layer(new Cropping1D.Builder(cropping).build())
                                .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
                                        .stride(stride).padding(padding).nOut(convNOut2).build())
                                .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                        .activation(Activation.SOFTMAX).nOut(finalNOut).build())
                                .setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build();
                        String json = conf.toJson();
                        MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
                        assertEquals(conf, c2);
                        MultiLayerNetwork net = new MultiLayerNetwork(conf);
                        net.init();
                        String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
                                + afn + ", kernel = " + kernel;
                        if (PRINT_RESULTS) {
                            System.out.println(msg);
                            // for (int j = 0; j < net.getnLayers(); j++)
                            // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                        }
                        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
                        assertTrue(gradOK, msg);
                        TestUtils.testModelSerialization(net);
                    }
@ -218,82 +166,50 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                }
            }
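    // Label-length arithmetic for the cropping test above and the zero-padding test below: Cropping1D
    // removes 'cropping' steps from each end, and ZeroPadding1DLayer adds 'zeroPadding' steps to each end,
    // so with length = 7:
    //
    //     int croppedLength = 7 - 2 * 1;  // = 5, label length in the cropping test
    //     int paddedLength = 7 + 2 * 2;   // = 11, label length in the zero-padding test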
    @Test
    @DisplayName("Test Cnn 1D With Zero Padding 1D")
    void testCnn1DWithZeroPadding1D() {
        Nd4j.getRandom().setSeed(1337);
        int[] minibatchSizes = { 1, 3 };
        int length = 7;
        int convNIn = 2;
        int convNOut1 = 3;
        int convNOut2 = 4;
        int finalNOut = 4;
        int[] kernels = { 1, 2, 4 };
        int stride = 1;
        int pnorm = 2;
        int padding = 0;
        int zeroPadding = 2;
        int paddedLength = length + 2 * zeroPadding;
        Activation[] activations = { Activation.SIGMOID };
        SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX,
                SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM };
        for (Activation afn : activations) {
            for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
                for (int minibatchSize : minibatchSizes) {
                    for (int kernel : kernels) {
                        INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length });
                        INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, paddedLength);
                        for (int i = 0; i < minibatchSize; i++) {
                            for (int j = 0; j < paddedLength; j++) {
                                labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0);
                            }
                        }
                        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                                .dataType(DataType.DOUBLE)
                                .updater(new NoOp())
                                .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list()
                                .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
                                        .stride(stride).padding(padding).nOut(convNOut1).build())
                                .layer(new ZeroPadding1DLayer.Builder(zeroPadding).build())
                                .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
                                        .stride(stride).padding(padding).nOut(convNOut2).build())
                                .layer(new ZeroPadding1DLayer.Builder(0).build())
                                .layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel)
                                        .stride(stride).padding(padding).pnorm(pnorm).build())
                                .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                        .activation(Activation.SOFTMAX).nOut(finalNOut).build())
                                .setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build();
                        String json = conf.toJson();
                        MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
                        assertEquals(conf, c2);
                        MultiLayerNetwork net = new MultiLayerNetwork(conf);
                        net.init();
                        String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
                                + afn + ", kernel = " + kernel;
                        if (PRINT_RESULTS) {
                            System.out.println(msg);
                            // for (int j = 0; j < net.getnLayers(); j++)
                            // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                        }
                        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
                        assertTrue(gradOK, msg);
                        TestUtils.testModelSerialization(net);
                    }
                }
@ -301,76 +217,48 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
            }
        }
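    // PoolingType.PNORM used in these subsampling tests pools each window as (sum_i |x_i|^p)^(1/p), with
    // p = pnorm (2 here, i.e. an L2-style pooling). A minimal sketch of that formula (illustrative only,
    // not the dl4j implementation; the method name is hypothetical):
    private static double pNormPoolSketch(double[] window, double p) {
        double acc = 0.0;
        for (double v : window) {
            acc += Math.pow(Math.abs(v), p); // accumulate |x|^p over the pooling window
        }
        return Math.pow(acc, 1.0 / p);       // e.g. {1, -2, 2} with p = 2 -> sqrt(9) = 3
    }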
    @Test
    @DisplayName("Test Cnn 1D With Subsampling 1D")
    void testCnn1DWithSubsampling1D() {
        Nd4j.getRandom().setSeed(12345);
        int[] minibatchSizes = { 1, 3 };
        int length = 7;
        int convNIn = 2;
        int convNOut1 = 3;
        int convNOut2 = 4;
        int finalNOut = 4;
        int[] kernels = { 1, 2, 4 };
        int stride = 1;
        int padding = 0;
        int pnorm = 2;
        Activation[] activations = { Activation.SIGMOID, Activation.TANH };
        SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX,
                SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM };
        for (Activation afn : activations) {
            for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
                for (int minibatchSize : minibatchSizes) {
                    for (int kernel : kernels) {
                        INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length });
                        INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
                        for (int i = 0; i < minibatchSize; i++) {
                            for (int j = 0; j < length; j++) {
                                labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0);
                            }
                        }
                        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                                .dataType(DataType.DOUBLE)
                                .updater(new NoOp())
                                .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list()
                                .layer(0, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
                                        .stride(stride).padding(padding).nOut(convNOut1).build())
                                .layer(1, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
                                        .stride(stride).padding(padding).nOut(convNOut2).build())
                                .layer(2, new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel)
                                        .stride(stride).padding(padding).pnorm(pnorm).build())
                                .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                        .activation(Activation.SOFTMAX).nOut(finalNOut).build())
                                .setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build();
                        String json = conf.toJson();
                        MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
                        assertEquals(conf, c2);
                        MultiLayerNetwork net = new MultiLayerNetwork(conf);
                        net.init();
                        String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
                                + afn + ", kernel = " + kernel;
                        if (PRINT_RESULTS) {
                            System.out.println(msg);
                            // for (int j = 0; j < net.getnLayers(); j++)
                            // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                        }
                        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
                        assertTrue(gradOK, msg);
                        TestUtils.testModelSerialization(net);
                    }
                }
@ -379,66 +267,34 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
} }
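
The mechanical change repeated throughout this commit is the assertion message position: org.junit.Assert in JUnit 4 takes the message as its first argument, while org.junit.jupiter.api.Assertions in JUnit 5 takes it last, which is why every assertTrue(msg, gradOK) above becomes assertTrue(gradOK, msg). A minimal sketch (class and variable names are illustrative):

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

class AssertionOrderSketch {
    void example() {
        boolean gradOK = true;
        String msg = "context printed on failure";
        // JUnit 4: assertTrue(msg, gradOK);   message first
        assertTrue(gradOK, msg);               // JUnit 5: message last
        // assertEquals keeps (expected, actual); only the optional message
        // moves to the trailing position.
        assertEquals(4, 2 + 2, msg);
    }
}
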
@Test @Test
public void testCnn1dWithMasking(){ @DisplayName("Test Cnn 1 d With Masking")
void testCnn1dWithMasking() {
int length = 12; int length = 12;
int convNIn = 2; int convNIn = 2;
int convNOut1 = 3; int convNOut1 = 3;
int convNOut2 = 4; int convNOut2 = 4;
int finalNOut = 3; int finalNOut = 3;
int pnorm = 2; int pnorm = 2;
SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG };
SubsamplingLayer.PoolingType[] poolingTypes =
new SubsamplingLayer.PoolingType[] {SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG};
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for(ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Same, ConvolutionMode.Truncate}) { for (ConvolutionMode cm : new ConvolutionMode[] { ConvolutionMode.Same, ConvolutionMode.Truncate }) {
for( int stride : new int[]{1, 2}){ for (int stride : new int[] { 1, 2 }) {
String s = cm + ", stride=" + stride + ", pooling=" + poolingType; String s = cm + ", stride=" + stride + ", pooling=" + poolingType;
log.info("Starting test: " + s); log.info("Starting test: " + s);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).activation(Activation.TANH).dist(new NormalDistribution(0, 1)).convolutionMode(cm).seed(12345).list().layer(new Convolution1DLayer.Builder().kernelSize(2).rnnDataFormat(RNNFormat.NCW).stride(stride).nIn(convNIn).nOut(convNOut1).build()).layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(2).stride(stride).pnorm(pnorm).build()).layer(new Convolution1DLayer.Builder().kernelSize(2).rnnDataFormat(RNNFormat.NCW).stride(stride).nIn(convNOut1).nOut(convNOut2).build()).layer(new GlobalPoolingLayer(PoolingType.AVG)).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 1)).convolutionMode(cm)
.seed(12345)
.list()
.layer(new Convolution1DLayer.Builder().kernelSize(2)
.rnnDataFormat(RNNFormat.NCW)
.stride(stride).nIn(convNIn).nOut(convNOut1)
.build())
.layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(2)
.stride(stride).pnorm(pnorm).build())
.layer(new Convolution1DLayer.Builder().kernelSize(2)
.rnnDataFormat(RNNFormat.NCW)
.stride(stride).nIn(convNOut1).nOut(convNOut2)
.build())
.layer(new GlobalPoolingLayer(PoolingType.AVG))
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
INDArray f = Nd4j.rand(new int[] { 2, convNIn, length });
INDArray f = Nd4j.rand(new int[]{2, convNIn, length});
INDArray fm = Nd4j.create(2, length); INDArray fm = Nd4j.create(2, length);
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1); fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0,6)).assign(1); fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, 6)).assign(1);
INDArray label = TestUtils.randomOneHot(2, finalNOut); INDArray label = TestUtils.randomOneHot(2, finalNOut);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) assertTrue(gradOK,s);
.labels(label).inputMask(fm));
assertTrue(s, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
// TODO also check that masked step values don't impact forward pass, score or gradients
//TODO also check that masked step values don't impact forward pass, score or gradients DataSet ds = new DataSet(f, label, fm, null);
DataSet ds = new DataSet(f,label,fm,null);
double scoreBefore = net.score(ds); double scoreBefore = net.score(ds);
net.setInput(f); net.setInput(f);
net.setLabels(label); net.setLabels(label);
@ -453,7 +309,6 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
net.setLayerMaskArrays(fm, null); net.setLayerMaskArrays(fm, null);
net.computeGradientAndScore(); net.computeGradientAndScore();
INDArray gradAfter = net.getFlattenedGradients().dup(); INDArray gradAfter = net.getFlattenedGradients().dup();
assertEquals(scoreBefore, scoreAfter, 1e-6); assertEquals(scoreBefore, scoreAfter, 1e-6);
assertEquals(gradBefore, gradAfter); assertEquals(gradBefore, gradAfter);
} }
@ -462,18 +317,18 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
} }
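
For reference while reading the ConvolutionMode/stride loops above: with Same mode the output length depends only on the stride, while Truncate drops windows that do not fit. This is a sketch of the standard formulas, not a call into DL4J's Convolution1DUtils:

class Conv1dOutputLengthSketch {
    static long same(long length, long stride) {
        return (length + stride - 1) / stride;       // ceil(length / stride)
    }
    static long truncate(long length, long kernel, long stride) {
        return (length - kernel) / stride + 1;
    }
    public static void main(String[] args) {
        // length = 12, kernel = 2 as in testCnn1dWithMasking:
        System.out.println(same(12, 1));             // 12
        System.out.println(truncate(12, 2, 1));      // 11
        System.out.println(same(12, 2));             // 6
        System.out.println(truncate(12, 2, 2));      // 6
    }
}
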
@Test @Test
public void testCnn1Causal() throws Exception { @DisplayName("Test Cnn 1 Causal")
void testCnn1Causal() throws Exception {
int convNIn = 2; int convNIn = 2;
int convNOut1 = 3; int convNOut1 = 3;
int convNOut2 = 4; int convNOut2 = 4;
int finalNOut = 3; int finalNOut = 3;
int[] lengths = { 11, 12, 13, 9, 10, 11 };
int[] lengths = {11, 12, 13, 9, 10, 11}; int[] kernels = { 2, 3, 2, 4, 2, 3 };
int[] kernels = {2, 3, 2, 4, 2, 3}; int[] dilations = { 1, 1, 2, 1, 2, 1 };
int[] dilations = {1, 1, 2, 1, 2, 1}; int[] strides = { 1, 2, 1, 2, 1, 1 };
int[] strides = {1, 2, 1, 2, 1, 1}; boolean[] masks = { false, true, false, true, false, true };
boolean[] masks = {false, true, false, true, false, true}; boolean[] hasB = { true, false, true, false, true, true };
boolean[] hasB = {true, false, true, false, true, true};
for (int i = 0; i < lengths.length; i++) { for (int i = 0; i < lengths.length; i++) {
int length = lengths[i]; int length = lengths[i];
int k = kernels[i]; int k = kernels[i];
@ -481,36 +336,13 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
int st = strides[i]; int st = strides[i];
boolean mask = masks[i]; boolean mask = masks[i];
boolean hasBias = hasB[i]; boolean hasBias = hasB[i];
//TODO has bias // TODO has bias
String s = "k=" + k + ", s=" + st + " d=" + d + ", seqLen=" + length; String s = "k=" + k + ", s=" + st + " d=" + d + ", seqLen=" + length;
log.info("Starting test: " + s); log.info("Starting test: " + s);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).activation(Activation.TANH).weightInit(new NormalDistribution(0, 1)).seed(12345).list().layer(new Convolution1DLayer.Builder().kernelSize(k).dilation(d).hasBias(hasBias).convolutionMode(ConvolutionMode.Causal).stride(st).nOut(convNOut1).build()).layer(new Convolution1DLayer.Builder().kernelSize(k).dilation(d).convolutionMode(ConvolutionMode.Causal).stride(st).nOut(convNOut2).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.activation(Activation.TANH)
.weightInit(new NormalDistribution(0, 1))
.seed(12345)
.list()
.layer(new Convolution1DLayer.Builder().kernelSize(k)
.dilation(d)
.hasBias(hasBias)
.convolutionMode(ConvolutionMode.Causal)
.stride(st).nOut(convNOut1)
.build())
.layer(new Convolution1DLayer.Builder().kernelSize(k)
.dilation(d)
.convolutionMode(ConvolutionMode.Causal)
.stride(st).nOut(convNOut2)
.build())
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length); INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
INDArray fm = null; INDArray fm = null;
if (mask) { if (mask) {
@ -518,16 +350,11 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1); fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length - 2)).assign(1); fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length - 2)).assign(1);
} }
long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d); long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d); long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int) outSize2);
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2); boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
assertTrue(gradOK,s);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
.labels(label).inputMask(fm));
assertTrue(s, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
} }
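
testCnn1Causal gets its label length from Convolution1DUtils.getOutputSize(...) with ConvolutionMode.Causal. Assuming causal mode left-pads by (kernel - 1) * dilation so that no output position sees future timesteps (an assumption for illustration, not a quote of the utility's internals), the length reduces to ceil(length / stride), independent of kernel and dilation:

class CausalConv1dSketch {
    static long causalOut(long length, long stride) {
        return (length + stride - 1) / stride;
    }
    public static void main(String[] args) {
        // First three parameter rows of the test above:
        System.out.println(causalOut(11, 1));  // k=2, d=1 -> 11
        System.out.println(causalOut(12, 2));  // k=3, d=1 -> 6
        System.out.println(causalOut(13, 1));  // k=2, d=2 -> 13
    }
}
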

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.gradientcheck; package org.deeplearning4j.gradientcheck;
import lombok.extern.java.Log; import lombok.extern.java.Log;
@ -33,7 +32,7 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -41,18 +40,24 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays; import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.Assert.assertTrue; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Log @Log
public class CNN3DGradientCheckTest extends BaseDL4JTest { @DisplayName("Cnn 3 D Gradient Check Test")
class CNN3DGradientCheckTest extends BaseDL4JTest {
private static final boolean PRINT_RESULTS = true; private static final boolean PRINT_RESULTS = true;
private static final boolean RETURN_ON_FIRST_FAILURE = false; private static final boolean RETURN_ON_FIRST_FAILURE = false;
private static final double DEFAULT_EPS = 1e-6; private static final double DEFAULT_EPS = 1e-6;
private static final double DEFAULT_MAX_REL_ERROR = 1e-3; private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
static { static {
@ -65,30 +70,23 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
} }
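
The constants above parameterize a numerical gradient check: perturb each parameter by DEFAULT_EPS, compare the central-difference estimate against the analytic gradient, and tolerate DEFAULT_MAX_REL_ERROR relative error unless the absolute difference is below DEFAULT_MIN_ABS_ERROR. A minimal sketch of the idea; this is not GradientCheckUtil's implementation, and the names are illustrative:

import java.util.function.DoubleUnaryOperator;

class CentralDifferenceSketch {
    // Numeric derivative of a scalar loss with respect to one parameter.
    static double numericGradient(DoubleUnaryOperator loss, double param, double eps) {
        return (loss.applyAsDouble(param + eps) - loss.applyAsDouble(param - eps)) / (2 * eps);
    }
    static boolean matches(double analytic, double numeric, double maxRelError, double minAbsError) {
        double diff = Math.abs(analytic - numeric);
        if (diff < minAbsError) return true;         // ignore tiny absolute differences
        return diff / (Math.abs(analytic) + Math.abs(numeric)) < maxRelError;
    }
    public static void main(String[] args) {
        DoubleUnaryOperator loss = x -> x * x;       // d/dx = 2x
        double g = numericGradient(loss, 3.0, 1e-6);
        System.out.println(matches(6.0, g, 1e-3, 1e-8));  // true
    }
}
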
@Test @Test
public void testCnn3DPlain() { @DisplayName("Test Cnn 3 D Plain")
void testCnn3DPlain() {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
// Note: we checked this with a variety of parameters, but it takes a lot of time. // Note: we checked this with a variety of parameters, but it takes a lot of time.
int[] depths = {6}; int[] depths = { 6 };
int[] heights = {6}; int[] heights = { 6 };
int[] widths = {6}; int[] widths = { 6 };
int[] minibatchSizes = { 3 };
int[] minibatchSizes = {3};
int convNIn = 2; int convNIn = 2;
int convNOut1 = 3; int convNOut1 = 3;
int convNOut2 = 4; int convNOut2 = 4;
int denseNOut = 5; int denseNOut = 5;
int finalNOut = 42; int finalNOut = 42;
int[][] kernels = { { 2, 2, 2 } };
int[][] strides = { { 1, 1, 1 } };
int[][] kernels = {{2, 2, 2}}; Activation[] activations = { Activation.SIGMOID };
int[][] strides = {{1, 1, 1}}; ConvolutionMode[] modes = { ConvolutionMode.Truncate, ConvolutionMode.Same };
Activation[] activations = {Activation.SIGMOID};
ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same};
for (Activation afn : activations) { for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) { for (int miniBatchSize : minibatchSizes) {
for (int depth : depths) { for (int depth : depths) {
@ -98,71 +96,34 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
for (int[] kernel : kernels) { for (int[] kernel : kernels) {
for (int[] stride : strides) { for (int[] stride : strides) {
for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) { for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) {
int outDepth = mode == ConvolutionMode.Same ? depth / stride[0] : (depth - kernel[0]) / stride[0] + 1;
int outDepth = mode == ConvolutionMode.Same ? int outHeight = mode == ConvolutionMode.Same ? height / stride[1] : (height - kernel[1]) / stride[1] + 1;
depth / stride[0] : (depth - kernel[0]) / stride[0] + 1; int outWidth = mode == ConvolutionMode.Same ? width / stride[2] : (width - kernel[2]) / stride[2] + 1;
int outHeight = mode == ConvolutionMode.Same ?
height / stride[1] : (height - kernel[1]) / stride[1] + 1;
int outWidth = mode == ConvolutionMode.Same ?
width / stride[2] : (width - kernel[2]) / stride[2] + 1;
INDArray input; INDArray input;
if(df == Convolution3D.DataFormat.NDHWC){ if (df == Convolution3D.DataFormat.NDHWC) {
input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn}); input = Nd4j.rand(new int[] { miniBatchSize, depth, height, width, convNIn });
} else { } else {
input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width });
} }
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) { for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0); labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
} }
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNOut1).nOut(convNOut2).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(2, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, df == Convolution3D.DataFormat.NCDHW)).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
.dist(new NormalDistribution(0, 1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel)
.stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNOut1).nOut(convNOut2).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(2, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(2,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,
convNOut2, df == Convolution3D.DataFormat.NCDHW))
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
String json = conf.toJson(); String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2); assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", stride = " + Arrays.toString(stride) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", stride = "
+ Arrays.toString(stride) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
log.info(msg); log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) { // for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); // log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// } // }
} }
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(128));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) assertTrue(gradOK,msg);
.labels(labels).subset(true).maxPerParam(128));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
} }
@ -176,186 +137,98 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
} }
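
testCnn3DPlain runs every case under both Convolution3D.DataFormat values; the only difference is the position of the channel axis. A small sketch of the two input shapes, using the test's own parameters (minibatch 3, 2 channels, 6x6x6 volume):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;

class DataFormat3dSketch {
    public static void main(String[] args) {
        INDArray ncdhw = Nd4j.rand(3, 2, 6, 6, 6);   // channels-first: [mb, C, D, H, W]
        INDArray ndhwc = Nd4j.rand(3, 6, 6, 6, 2);   // channels-last:  [mb, D, H, W, C]
        System.out.println(Arrays.toString(ncdhw.shape()));  // [3, 2, 6, 6, 6]
        System.out.println(Arrays.toString(ndhwc.shape()));  // [3, 6, 6, 6, 2]
    }
}
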
@Test @Test
public void testCnn3DZeroPadding() { @DisplayName("Test Cnn 3 D Zero Padding")
void testCnn3DZeroPadding() {
Nd4j.getRandom().setSeed(42); Nd4j.getRandom().setSeed(42);
int depth = 4; int depth = 4;
int height = 4; int height = 4;
int width = 4; int width = 4;
int[] minibatchSizes = { 3 };
int[] minibatchSizes = {3};
int convNIn = 2; int convNIn = 2;
int convNOut1 = 3; int convNOut1 = 3;
int convNOut2 = 4; int convNOut2 = 4;
int denseNOut = 5; int denseNOut = 5;
int finalNOut = 42; int finalNOut = 42;
int[] kernel = { 2, 2, 2 };
int[] zeroPadding = { 1, 1, 2, 2, 3, 3 };
int[] kernel = {2, 2, 2}; Activation[] activations = { Activation.SIGMOID };
int[] zeroPadding = {1, 1, 2, 2, 3, 3}; ConvolutionMode[] modes = { ConvolutionMode.Truncate, ConvolutionMode.Same };
Activation[] activations = {Activation.SIGMOID};
ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same};
for (Activation afn : activations) { for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) { for (int miniBatchSize : minibatchSizes) {
for (ConvolutionMode mode : modes) { for (ConvolutionMode mode : modes) {
int outDepth = mode == ConvolutionMode.Same ? depth : (depth - kernel[0]) + 1;
int outDepth = mode == ConvolutionMode.Same ? int outHeight = mode == ConvolutionMode.Same ? height : (height - kernel[1]) + 1;
depth : (depth - kernel[0]) + 1; int outWidth = mode == ConvolutionMode.Same ? width : (width - kernel[2]) + 1;
int outHeight = mode == ConvolutionMode.Same ?
height : (height - kernel[1]) + 1;
int outWidth = mode == ConvolutionMode.Same ?
width : (width - kernel[2]) + 1;
outDepth += zeroPadding[0] + zeroPadding[1]; outDepth += zeroPadding[0] + zeroPadding[1];
outHeight += zeroPadding[2] + zeroPadding[3]; outHeight += zeroPadding[2] + zeroPadding[3];
outWidth += zeroPadding[4] + zeroPadding[5]; outWidth += zeroPadding[4] + zeroPadding[5];
INDArray input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width });
INDArray input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width});
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) { for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0); labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
} }
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).nIn(convNIn).nOut(convNOut1).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNOut1).nOut(convNOut2).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(2, new ZeroPadding3DLayer.Builder(zeroPadding).build()).layer(3, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(3, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, true)).setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
.dist(new NormalDistribution(0, 1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel)
.nIn(convNIn).nOut(convNOut1).hasBias(false)
.convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW)
.build())
.layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNOut1).nOut(convNOut2).hasBias(false)
.convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW)
.build())
.layer(2, new ZeroPadding3DLayer.Builder(zeroPadding).build())
.layer(3, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(3,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,
convNOut2, true))
.setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build();
String json = conf.toJson(); String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2); assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
log.info(msg); log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) { // for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); // log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// } // }
} }
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(512));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) assertTrue(gradOK,msg);
.labels(labels).subset(true).maxPerParam(512));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
} }
} }
} }
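
In testCnn3DZeroPadding the preprocessor dimensions are the convolution output plus the padding added on each side of each axis. A worked check with the test's values (4x4x4 input, 2x2x2 kernel, Truncate, padding {1, 1, 2, 2, 3, 3}), plain arithmetic only:

class ZeroPadding3dSketch {
    public static void main(String[] args) {
        int[] pad = {1, 1, 2, 2, 3, 3};          // (left, right) pairs for depth, height, width
        int base = (4 - 2) + 1;                  // Truncate conv on a size-4 axis, kernel 2 -> 3
        System.out.println(base + pad[0] + pad[1]);  // outDepth  = 5
        System.out.println(base + pad[2] + pad[3]);  // outHeight = 7
        System.out.println(base + pad[4] + pad[5]);  // outWidth  = 9
    }
}
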
@Test @Test
public void testCnn3DPooling() { @DisplayName("Test Cnn 3 D Pooling")
void testCnn3DPooling() {
Nd4j.getRandom().setSeed(42); Nd4j.getRandom().setSeed(42);
int depth = 4; int depth = 4;
int height = 4; int height = 4;
int width = 4; int width = 4;
int[] minibatchSizes = { 3 };
int[] minibatchSizes = {3};
int convNIn = 2; int convNIn = 2;
int convNOut = 4; int convNOut = 4;
int denseNOut = 5; int denseNOut = 5;
int finalNOut = 42; int finalNOut = 42;
int[] kernel = { 2, 2, 2 };
int[] kernel = {2, 2, 2}; Activation[] activations = { Activation.SIGMOID };
Subsampling3DLayer.PoolingType[] poolModes = { Subsampling3DLayer.PoolingType.AVG };
Activation[] activations = {Activation.SIGMOID}; ConvolutionMode[] modes = { ConvolutionMode.Truncate };
Subsampling3DLayer.PoolingType[] poolModes = {Subsampling3DLayer.PoolingType.AVG};
ConvolutionMode[] modes = {ConvolutionMode.Truncate};
for (Activation afn : activations) { for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) { for (int miniBatchSize : minibatchSizes) {
for (Subsampling3DLayer.PoolingType pool : poolModes) { for (Subsampling3DLayer.PoolingType pool : poolModes) {
for (ConvolutionMode mode : modes) { for (ConvolutionMode mode : modes) {
for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) { for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) {
int outDepth = depth / kernel[0]; int outDepth = depth / kernel[0];
int outHeight = height / kernel[1]; int outHeight = height / kernel[1];
int outWidth = width / kernel[2]; int outWidth = width / kernel[2];
INDArray input = Nd4j.rand(df == Convolution3D.DataFormat.NCDHW ? new int[] { miniBatchSize, convNIn, depth, height, width } : new int[] { miniBatchSize, depth, height, width, convNIn });
INDArray input = Nd4j.rand(
df == Convolution3D.DataFormat.NCDHW ? new int[]{miniBatchSize, convNIn, depth, height, width}
: new int[]{miniBatchSize, depth, height, width, convNIn});
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) { for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0); labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
} }
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.XAVIER).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNIn).nOut(convNOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Subsampling3DLayer.Builder(kernel).poolingType(pool).convolutionMode(mode).dataFormat(df).build()).layer(2, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut, df)).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.dist(new NormalDistribution(0, 1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNIn).nOut(convNOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(1, new Subsampling3DLayer.Builder(kernel)
.poolingType(pool).convolutionMode(mode).dataFormat(df).build())
.layer(2, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(2,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,convNOut, df))
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
String json = conf.toJson(); String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2); assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width + ", dataFormat=" + df;
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width + ", dataFormat=" + df;
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
log.info(msg); log.info(msg);
} }
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, assertTrue(gradOK,msg);
DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
} }
@ -365,87 +238,47 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
} }
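
The Cnn3DToFeedForwardPreProcessor arguments in testCnn3DPooling are the post-pooling spatial sizes plus the channel count; assuming the preprocessor flattens each example's activations into a single row for the dense layer, the feed-forward input size is simply their product:

class Cnn3dFlattenSketch {
    public static void main(String[] args) {
        int outDepth = 4 / 2, outHeight = 4 / 2, outWidth = 4 / 2;  // 2x2x2 pooling on a 4x4x4 volume
        int channels = 4;                                           // convNOut in the test
        System.out.println(outDepth * outHeight * outWidth * channels);  // 32 inputs per example
    }
}
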
@Test @Test
public void testCnn3DUpsampling() { @DisplayName("Test Cnn 3 D Upsampling")
void testCnn3DUpsampling() {
Nd4j.getRandom().setSeed(42); Nd4j.getRandom().setSeed(42);
int depth = 2; int depth = 2;
int height = 2; int height = 2;
int width = 2; int width = 2;
int[] minibatchSizes = { 3 };
int[] minibatchSizes = {3};
int convNIn = 2; int convNIn = 2;
int convNOut = 4; int convNOut = 4;
int denseNOut = 5; int denseNOut = 5;
int finalNOut = 42; int finalNOut = 42;
int[] upsamplingSize = { 2, 2, 2 };
Activation[] activations = { Activation.SIGMOID };
int[] upsamplingSize = {2, 2, 2}; ConvolutionMode[] modes = { ConvolutionMode.Truncate };
Activation[] activations = {Activation.SIGMOID};
ConvolutionMode[] modes = {ConvolutionMode.Truncate};
for (Activation afn : activations) { for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) { for (int miniBatchSize : minibatchSizes) {
for (ConvolutionMode mode : modes) { for (ConvolutionMode mode : modes) {
for(Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) { for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) {
int outDepth = depth * upsamplingSize[0]; int outDepth = depth * upsamplingSize[0];
int outHeight = height * upsamplingSize[1]; int outHeight = height * upsamplingSize[1];
int outWidth = width * upsamplingSize[2]; int outWidth = width * upsamplingSize[2];
INDArray input = df == Convolution3D.DataFormat.NCDHW ? Nd4j.rand(miniBatchSize, convNIn, depth, height, width) : Nd4j.rand(miniBatchSize, depth, height, width, convNIn); INDArray input = df == Convolution3D.DataFormat.NCDHW ? Nd4j.rand(miniBatchSize, convNIn, depth, height, width) : Nd4j.rand(miniBatchSize, depth, height, width, convNIn);
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) { for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0); labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
} }
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).seed(12345).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNIn).nOut(convNOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Upsampling3D.Builder(upsamplingSize[0]).dataFormat(df).build()).layer(2, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut, true)).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
.dist(new NormalDistribution(0, 1))
.seed(12345)
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNIn).nOut(convNOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(1, new Upsampling3D.Builder(upsamplingSize[0]).dataFormat(df).build())
.layer(2, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(2,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,
convNOut, true))
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
String json = conf.toJson(); String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2); assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(upsamplingSize) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(upsamplingSize) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
log.info(msg); log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) { // for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); // log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// } // }
} }
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, assertTrue(gradOK,msg);
DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
} }
@ -454,126 +287,74 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
} }
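
Every test in this file builds its labels with the same putScalar loop; for clarity, a standalone sketch of the matrix it produces (minibatch 3, 4 classes):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class OneHotLabelSketch {
    public static void main(String[] args) {
        int minibatch = 3, nOut = 4;
        INDArray labels = Nd4j.zeros(minibatch, nOut);
        for (int i = 0; i < minibatch; i++) {
            labels.putScalar(new int[] {i, i % nOut}, 1.0);  // row i is one-hot at class i % nOut
        }
        System.out.println(labels);
        // [[1, 0, 0, 0],
        //  [0, 1, 0, 0],
        //  [0, 0, 1, 0]]
    }
}
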
@Test @Test
public void testCnn3DCropping() { @DisplayName("Test Cnn 3 D Cropping")
void testCnn3DCropping() {
Nd4j.getRandom().setSeed(42); Nd4j.getRandom().setSeed(42);
int depth = 6; int depth = 6;
int height = 6; int height = 6;
int width = 6; int width = 6;
int[] minibatchSizes = { 3 };
int[] minibatchSizes = {3};
int convNIn = 2; int convNIn = 2;
int convNOut1 = 3; int convNOut1 = 3;
int convNOut2 = 4; int convNOut2 = 4;
int denseNOut = 5; int denseNOut = 5;
int finalNOut = 8; int finalNOut = 8;
int[] kernel = { 1, 1, 1 };
int[] cropping = { 0, 0, 1, 1, 2, 2 };
int[] kernel = {1, 1, 1}; Activation[] activations = { Activation.SIGMOID };
int[] cropping = {0, 0, 1, 1, 2, 2}; ConvolutionMode[] modes = { ConvolutionMode.Same };
Activation[] activations = {Activation.SIGMOID};
ConvolutionMode[] modes = {ConvolutionMode.Same};
for (Activation afn : activations) { for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) { for (int miniBatchSize : minibatchSizes) {
for (ConvolutionMode mode : modes) { for (ConvolutionMode mode : modes) {
int outDepth = mode == ConvolutionMode.Same ? depth : (depth - kernel[0]) + 1;
int outDepth = mode == ConvolutionMode.Same ? int outHeight = mode == ConvolutionMode.Same ? height : (height - kernel[1]) + 1;
depth : (depth - kernel[0]) + 1; int outWidth = mode == ConvolutionMode.Same ? width : (width - kernel[2]) + 1;
int outHeight = mode == ConvolutionMode.Same ?
height : (height - kernel[1]) + 1;
int outWidth = mode == ConvolutionMode.Same ?
width : (width - kernel[2]) + 1;
outDepth -= cropping[0] + cropping[1]; outDepth -= cropping[0] + cropping[1];
outHeight -= cropping[2] + cropping[3]; outHeight -= cropping[2] + cropping[3];
outWidth -= cropping[4] + cropping[5]; outWidth -= cropping[4] + cropping[5];
INDArray input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width });
INDArray input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width});
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) { for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0); labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
} }
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).nIn(convNIn).nOut(convNOut1).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNOut1).nOut(convNOut2).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(2, new Cropping3D.Builder(cropping).build()).layer(3, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(3, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, true)).setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
.dist(new NormalDistribution(0, 1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel)
.nIn(convNIn).nOut(convNOut1).hasBias(false)
.convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW)
.build())
.layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNOut1).nOut(convNOut2).hasBias(false)
.convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW)
.build())
.layer(2, new Cropping3D.Builder(cropping).build())
.layer(3, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(3,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,
convNOut2, true))
.setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build();
String json = conf.toJson(); String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2); assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
log.info(msg); log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) { // for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); // log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// } // }
} }
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, assertTrue(gradOK,msg);
DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
} }
} }
} }
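
Cropping3D is the mirror image of the zero-padding case sketched earlier: dimension i loses cropping[2i] + cropping[2i+1] entries. With testCnn3DCropping's 6x6x6 input, size-preserving Same-mode convolutions, and cropping {0, 0, 1, 1, 2, 2}:

class Cropping3dSketch {
    public static void main(String[] args) {
        int[] crop = {0, 0, 1, 1, 2, 2};
        System.out.println(6 - crop[0] - crop[1]);   // outDepth  = 6
        System.out.println(6 - crop[2] - crop[3]);   // outHeight = 4
        System.out.println(6 - crop[4] - crop[5]);   // outWidth  = 2
    }
}
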
@Test @Test
public void testDeconv3d() { @DisplayName("Test Deconv 3 d")
void testDeconv3d() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
// Note: we checked this with a variety of parameters, but it takes a lot of time. // Note: we checked this with a variety of parameters, but it takes a lot of time.
int[] depths = {8, 8, 9}; int[] depths = { 8, 8, 9 };
int[] heights = {8, 9, 9}; int[] heights = { 8, 9, 9 };
int[] widths = {8, 8, 9}; int[] widths = { 8, 8, 9 };
int[][] kernels = { { 2, 2, 2 }, { 3, 3, 3 }, { 2, 3, 2 } };
int[][] strides = { { 1, 1, 1 }, { 1, 1, 1 }, { 2, 2, 2 } };
int[][] kernels = {{2, 2, 2}, {3, 3, 3}, {2, 3, 2}}; Activation[] activations = { Activation.SIGMOID, Activation.TANH, Activation.IDENTITY };
int[][] strides = {{1, 1, 1}, {1, 1, 1}, {2, 2, 2}}; ConvolutionMode[] modes = { ConvolutionMode.Truncate, ConvolutionMode.Same, ConvolutionMode.Same };
int[] mbs = { 1, 3, 2 };
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.IDENTITY}; Convolution3D.DataFormat[] dataFormats = new Convolution3D.DataFormat[] { Convolution3D.DataFormat.NCDHW, Convolution3D.DataFormat.NDHWC, Convolution3D.DataFormat.NCDHW };
ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same, ConvolutionMode.Same};
int[] mbs = {1, 3, 2};
Convolution3D.DataFormat[] dataFormats = new Convolution3D.DataFormat[]{Convolution3D.DataFormat.NCDHW, Convolution3D.DataFormat.NDHWC, Convolution3D.DataFormat.NCDHW};
int convNIn = 2; int convNIn = 2;
int finalNOut = 2; int finalNOut = 2;
int[] deconvOut = {2, 3, 4}; int[] deconvOut = { 2, 3, 4 };
for (int i = 0; i < activations.length; i++) { for (int i = 0; i < activations.length; i++) {
Activation afn = activations[i]; Activation afn = activations[i];
int miniBatchSize = mbs[i]; int miniBatchSize = mbs[i];
@ -585,57 +366,28 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
int[] stride = strides[i]; int[] stride = strides[i];
Convolution3D.DataFormat df = dataFormats[i]; Convolution3D.DataFormat df = dataFormats[i];
int dOut = deconvOut[i]; int dOut = deconvOut[i];
INDArray input; INDArray input;
if (df == Convolution3D.DataFormat.NDHWC) { if (df == Convolution3D.DataFormat.NDHWC) {
input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn}); input = Nd4j.rand(new int[] { miniBatchSize, depth, height, width, convNIn });
} else { } else {
input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width });
} }
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int j = 0; j < miniBatchSize; j++) { for (int j = 0; j < miniBatchSize; j++) {
labels.putScalar(new int[]{j, j % finalNOut}, 1.0); labels.putScalar(new int[] { j, j % finalNOut }, 1.0);
} }
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(new NormalDistribution(0, 0.1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).stride(stride).nIn(convNIn).nOut(dOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Deconvolution3D.Builder().activation(afn).kernelSize(kernel).stride(stride).nOut(dOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.weightInit(new NormalDistribution(0, 0.1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel)
.stride(stride).nIn(convNIn).nOut(dOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(1, new Deconvolution3D.Builder().activation(afn).kernelSize(kernel)
.stride(stride).nOut(dOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
String json = conf.toJson(); String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2); assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", stride = " + Arrays.toString(stride) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", stride = "
+ Arrays.toString(stride) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
log.info(msg); log.info(msg);
} }
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(64));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) assertTrue(gradOK,msg);
.labels(labels).subset(true).maxPerParam(64));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
} }
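
testDeconv3d pairs each convolution with a Deconvolution3D using the same kernel and stride. Assuming the standard transposed-convolution size formula (an illustration, not a quote of DL4J's internals), Truncate gives out = (in - 1) * stride + kernel and Same gives out = in * stride, so the deconvolution undoes the convolution's shrinkage:

class Deconv3dOutputSizeSketch {
    static int deconvOut(int in, int k, int s, boolean same) {
        return same ? in * s : (in - 1) * s + k;
    }
    public static void main(String[] args) {
        int convOut = (8 - 2) / 1 + 1;                        // first row: 8^3 input, kernel 2, stride 1 -> 7
        System.out.println(deconvOut(convOut, 2, 1, false));  // 8: original size restored
    }
}
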

View File

@ -17,11 +17,9 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.gradientcheck; package org.deeplearning4j.gradientcheck;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -35,19 +33,21 @@ import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules; import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import java.util.Random; import java.util.Random;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Ignore @Disabled
public class CapsnetGradientCheckTest extends BaseDL4JTest { @DisplayName("Capsnet Gradient Check Test")
class CapsnetGradientCheckTest extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
@ -55,71 +55,39 @@ public class CapsnetGradientCheckTest extends BaseDL4JTest {
} }
@Test @Test
public void testCapsNet() { @DisplayName("Test Caps Net")
void testCapsNet() {
int[] minibatchSizes = {8, 16}; int[] minibatchSizes = { 8, 16 };
int width = 6; int width = 6;
int height = 6; int height = 6;
int inputDepth = 4; int inputDepth = 4;
int[] primaryCapsDims = { 2, 4 };
int[] primaryCapsDims = {2, 4}; int[] primaryCapsChannels = { 8 };
int[] primaryCapsChannels = {8}; int[] capsules = { 5 };
int[] capsules = {5}; int[] capsuleDims = { 4, 8 };
int[] capsuleDims = {4, 8}; int[] routings = { 1 };
int[] routings = {1};
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for (int routing : routings) { for (int routing : routings) {
for (int primaryCapsDim : primaryCapsDims) { for (int primaryCapsDim : primaryCapsDims) {
for (int primarpCapsChannel : primaryCapsChannels) { for (int primarpCapsChannel : primaryCapsChannels) {
for (int capsule : capsules) { for (int capsule : capsules) {
for (int capsuleDim : capsuleDims) { for (int capsuleDim : capsuleDims) {
for (int minibatchSize : minibatchSizes) { for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, inputDepth * height * width).mul(10).reshape(-1, inputDepth, height, width);
INDArray input = Nd4j.rand(minibatchSize, inputDepth * height * width).mul(10)
.reshape(-1, inputDepth, height, width);
INDArray labels = Nd4j.zeros(minibatchSize, capsule); INDArray labels = Nd4j.zeros(minibatchSize, capsule);
for (int i = 0; i < minibatchSize; i++) { for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % capsule}, 1.0); labels.putScalar(new int[] { i, i % capsule }, 1.0);
} }
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).seed(123).updater(new NoOp()).weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6))).list().layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel).kernelSize(3, 3).stride(2, 2).build()).layer(new CapsuleLayer.Builder(capsule, capsuleDim, routing).build()).layer(new CapsuleStrengthLayer.Builder().build()).layer(new ActivationLayer.Builder(new ActivationSoftmax()).build()).layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build()).setInputType(InputType.convolutional(height, width, inputDepth)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.seed(123)
.updater(new NoOp())
.weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6)))
.list()
.layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel)
.kernelSize(3, 3)
.stride(2, 2)
.build())
.layer(new CapsuleLayer.Builder(capsule, capsuleDim, routing).build())
.layer(new CapsuleStrengthLayer.Builder().build())
.layer(new ActivationLayer.Builder(new ActivationSoftmax()).build())
.layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build())
.setInputType(InputType.convolutional(height, width, inputDepth))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
} }
String msg = "minibatch=" + minibatchSize + ", PrimaryCaps: " + primarpCapsChannel + " channels, " + primaryCapsDim + " dimensions, Capsules: " + capsule + " capsules with " + capsuleDim + " dimensions and " + routing + " routings";
String msg = "minibatch=" + minibatchSize +
", PrimaryCaps: " + primarpCapsChannel +
" channels, " + primaryCapsDim + " dimensions, Capsules: " + capsule +
" capsules with " + capsuleDim + " dimensions and " + routing + " routings";
System.out.println(msg); System.out.println(msg);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(100));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) assertTrue(gradOK,msg);
.labels(labels).subset(true).maxPerParam(100));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
} }
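
Besides the assertion swaps, this file shows the commit's annotation mapping: org.junit.Ignore becomes org.junit.jupiter.api.Disabled, test classes and methods drop their public modifiers, and the generated @DisplayName annotations carry a human-readable name. A condensed, compilable sketch (the class here is a stand-in, not the real test):

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@Disabled("skipped by the runner, exactly like JUnit 4's @Ignore")
@DisplayName("Annotation Migration Sketch")
class AnnotationMigrationSketch {

    @Test
    @DisplayName("Disabled tests are reported as skipped")
    void skippedTest() {
        // never executes while the class-level @Disabled is present
    }
}
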

View File

@ -17,34 +17,34 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.adapters; package org.deeplearning4j.nn.adapters;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*; @DisplayName("Argmax Adapter Test")
public class ArgmaxAdapterTest extends BaseDL4JTest {
class ArgmaxAdapterTest extends BaseDL4JTest {
@Test @Test
public void testSoftmax_2D_1() { @DisplayName("Test Softmax _ 2 D _ 1")
val in = new double[][] {{1, 3, 2}, { 4, 5, 6}}; void testSoftmax_2D_1() {
val in = new double[][] { { 1, 3, 2 }, { 4, 5, 6 } };
val adapter = new ArgmaxAdapter(); val adapter = new ArgmaxAdapter();
val result = adapter.apply(Nd4j.create(in)); val result = adapter.apply(Nd4j.create(in));
assertArrayEquals(new int[] { 1, 2 }, result);
assertArrayEquals(new int[]{1, 2}, result);
} }
@Test @Test
public void testSoftmax_1D_1() { @DisplayName("Test Softmax _ 1 D _ 1")
val in = new double[] {1, 3, 2}; void testSoftmax_1D_1() {
val in = new double[] { 1, 3, 2 };
val adapter = new ArgmaxAdapter(); val adapter = new ArgmaxAdapter();
val result = adapter.apply(Nd4j.create(in)); val result = adapter.apply(Nd4j.create(in));
assertArrayEquals(new int[] { 1 }, result);
assertArrayEquals(new int[]{1}, result);
} }
} }
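The class-level pattern above recurs throughout the commit: JUnit 5 discovers package-private test classes and methods, so the public modifiers are dropped, and a @DisplayName (presumably generated from the old camel-case identifiers) carries the readable test name. A minimal sketch of the shape, with illustrative names:

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

@DisplayName("Adapter Test")        // name shown in reports, decoupled from the class name
class AdapterSketchTest {           // package-private visibility is sufficient in JUnit 5

    @Test
    @DisplayName("max picks the larger value")
    void maxPicksLarger() {         // @Test methods no longer need to be public
        assertEquals(5, Math.max(5, 3));
    }
}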

View File

@@ -17,35 +17,37 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.adapters;

import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.util.ArrayUtil;

import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@DisplayName("Regression 2 d Adapter Test")
class Regression2dAdapterTest extends BaseDL4JTest {

    @Test
    @DisplayName("Test Regression Adapter _ 2 D _ 1")
    void testRegressionAdapter_2D_1() throws Exception {
        val in = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
        val adapter = new Regression2dAdapter();
        val result = adapter.apply(Nd4j.create(in));
        assertArrayEquals(ArrayUtil.flatten(in), ArrayUtil.flatten(result), 1e-5);
    }

    @Test
    @DisplayName("Test Regression Adapter _ 2 D _ 2")
    void testRegressionAdapter_2D_2() throws Exception {
        val in = new double[] { 1, 2, 3 };
        val adapter = new Regression2dAdapter();
        val result = adapter.apply(Nd4j.create(in));
        assertArrayEquals(in, ArrayUtil.flatten(result), 1e-5);
    }
}
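One thing this file shows staying put: the tolerance-based array assertion. assertArrayEquals(double[] expected, double[] actual, double delta) has the same parameter order in org.junit.Assert and org.junit.jupiter.api.Assertions, so only the import moves. A self-contained sketch with illustrative values:

import static org.junit.jupiter.api.Assertions.assertArrayEquals;

class DeltaAssertSketch {
    void compareWithTolerance() {
        double[] expected = { 1.0, 2.0, 3.0 };
        double[] actual = { 1.0, 2.0, 3.0000001 };
        // Same argument order as the JUnit 4 overload; only the import changed.
        assertArrayEquals(expected, actual, 1e-5);
    }
}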

View File

@@ -17,10 +17,8 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.conf;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
@@ -43,296 +41,158 @@ import org.deeplearning4j.nn.conf.misc.TestGraphVertex;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import static org.junit.jupiter.api.Assertions.*;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@Slf4j
@DisplayName("Computation Graph Configuration Test")
class ComputationGraphConfigurationTest extends BaseDL4JTest {

    @Test
    @DisplayName("Test JSON Basic")
    void testJSONBasic() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .dist(new NormalDistribution(0, 1)).updater(new NoOp())
                .graphBuilder().addInputs("input")
                .appendLayer("firstLayer",
                        new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.TANH).build())
                .addLayer("outputLayer",
                        new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                                .activation(Activation.SOFTMAX).nIn(5).nOut(3).build(),
                        "firstLayer")
                .setOutputs("outputLayer").build();

        String json = conf.toJson();
        ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
        assertEquals(json, conf2.toJson());
        assertEquals(conf, conf2);
    }
    @Test
    @DisplayName("Test JSON Basic 2")
    void testJSONBasic2() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .graphBuilder().addInputs("input")
                .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input")
                .addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input")
                .addLayer("max1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).build(), "cnn1", "cnn2")
                .addLayer("dnn1", new DenseLayer.Builder().nOut(7).build(), "max1")
                .addLayer("max2", new SubsamplingLayer.Builder().build(), "max1")
                .addLayer("output", new OutputLayer.Builder().nIn(7).nOut(10).activation(Activation.SOFTMAX).build(), "dnn1", "max2")
                .setOutputs("output")
                .inputPreProcessor("cnn1", new FeedForwardToCnnPreProcessor(32, 32, 3))
                .inputPreProcessor("cnn2", new FeedForwardToCnnPreProcessor(32, 32, 3))
                .inputPreProcessor("dnn1", new CnnToFeedForwardPreProcessor(8, 8, 5))
                .build();

        String json = conf.toJson();
        ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
        assertEquals(json, conf2.toJson());
        assertEquals(conf, conf2);
    }
    @Test
    @DisplayName("Test JSON With Graph Nodes")
    void testJSONWithGraphNodes() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .graphBuilder().addInputs("input1", "input2")
                .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input1")
                .addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input2")
                .addVertex("merge1", new MergeVertex(), "cnn1", "cnn2")
                .addVertex("subset1", new SubsetVertex(0, 1), "merge1")
                .addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
                .addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
                .addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2")
                .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "add")
                .setOutputs("out").build();

        String json = conf.toJson();
        // System.out.println(json);
        ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
        assertEquals(json, conf2.toJson());
        assertEquals(conf, conf2);
    }
    @Test
    @DisplayName("Test Invalid Configurations")
    void testInvalidConfigurations() {
        // Test no inputs for a layer:
        try {
            new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1")
                    .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1")
                    .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build())
                    .setOutputs("out").build();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            // OK - exception is good
            log.info(e.toString());
        }

        // Use appendLayer on first layer
        try {
            new NeuralNetConfiguration.Builder().graphBuilder()
                    .appendLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build())
                    .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build())
                    .setOutputs("out").build();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            // OK - exception is good
            log.info(e.toString());
        }

        // Test no network inputs
        try {
            new NeuralNetConfiguration.Builder().graphBuilder()
                    .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1")
                    .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "dense1")
                    .setOutputs("out").build();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            // OK - exception is good
            log.info(e.toString());
        }

        // Test no network outputs
        try {
            new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1")
                    .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1")
                    .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "dense1").build();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            // OK - exception is good
            log.info(e.toString());
        }

        // Test: invalid input
        try {
            new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1")
                    .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1")
                    .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "thisDoesntExist")
                    .setOutputs("out").build();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            // OK - exception is good
            log.info(e.toString());
        }

        // Test: graph with cycles
        try {
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1")
                    .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1", "dense3")
                    .addLayer("dense2", new DenseLayer.Builder().nIn(2).nOut(2).build(), "dense1")
                    .addLayer("dense3", new DenseLayer.Builder().nIn(2).nOut(2).build(), "dense2")
                    .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).lossFunction(LossFunctions.LossFunction.MSE).build(), "dense1")
                    .setOutputs("out").build();
            // Cycle detection happens in ComputationGraph.init()
            ComputationGraph graph = new ComputationGraph(conf);
            graph.init();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            // OK - exception is good
            log.info(e.toString());
        }

        // Test: input != inputType count mismatch
        try {
            new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2")
                    .setInputTypes(new InputType.InputTypeRecurrent(10, 12))
                    .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input1")
                    .addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input2")
                    .addVertex("merge1", new MergeVertex(), "cnn1", "cnn2")
                    .addVertex("subset1", new SubsetVertex(0, 1), "merge1")
                    .addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
                    .addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
                    .addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2")
                    .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "add")
                    .setOutputs("out").build();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalArgumentException e) {
            // OK - exception is good
            log.info(e.toString());
        }
    }
    @Test
    @DisplayName("Test Configuration With Runtime JSON Subtypes")
    void testConfigurationWithRuntimeJSONSubtypes() {
        // Idea: suppose someone wants to use a ComputationGraph with a custom GraphVertex
        // (i.e., one not built into DL4J). Check that this works for JSON serialization
        // using runtime/reflection subtype mechanism in ComputationGraphConfiguration.fromJson()
        // Check a standard GraphVertex implementation, plus a static inner graph vertex
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
                .addVertex("test", new TestGraphVertex(3, 7), "in")
                .addVertex("test2", new StaticInnerGraphVertex(4, 5), "in")
                .setOutputs("test", "test2").build();

        String json = conf.toJson();
        // System.out.println(json);
        ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
        assertEquals(conf, conf2);
        assertEquals(json, conf2.toJson());

        TestGraphVertex tgv = (TestGraphVertex) conf2.getVertices().get("test");
        assertEquals(3, tgv.getFirstVal());
        assertEquals(7, tgv.getSecondVal());

        StaticInnerGraphVertex sigv = (StaticInnerGraphVertex) conf.getVertices().get("test2");
        assertEquals(4, sigv.getFirstVal());
        assertEquals(5, sigv.getSecondVal());
    }
    @Test
    @DisplayName("Test Output Order Doesnt Change When Cloning")
    void testOutputOrderDoesntChangeWhenCloning() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
                .addLayer("out1", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in")
                .addLayer("out2", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in")
                .addLayer("out3", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in")
                .validateOutputLayerConfig(false)
                .setOutputs("out1", "out2", "out3").build();

        ComputationGraphConfiguration cloned = conf.clone();
        String json = conf.toJson();
        String jsonCloned = cloned.toJson();
        assertEquals(json, jsonCloned);
    }
    @Test
    @DisplayName("Test Allow Disconnected Layers")
    void testAllowDisconnectedLayers() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
                .addLayer("bidirectional",
                        new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()), "in")
                .addLayer("out", new RnnOutputLayer.Builder().nOut(6)
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .build(), "bidirectional")
                .addLayer("disconnected_layer",
                        new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()), "in")
                .setOutputs("out")
                .setInputTypes(new InputType.InputTypeRecurrent(10, 12))
                .allowDisconnected(true)
                .build();

        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();
    }
    @Test
    @DisplayName("Test Bidirectional Graph Summary")
    void testBidirectionalGraphSummary() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
                .addLayer("bidirectional",
                        new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()), "in")
                .addLayer("out", new RnnOutputLayer.Builder().nOut(6)
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .build(), "bidirectional")
                .setOutputs("out")
                .setInputTypes(new InputType.InputTypeRecurrent(10, 12))
                .build();

        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();
        graph.summary();
@@ -342,9 +202,11 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
    @NoArgsConstructor
    @Data
    @EqualsAndHashCode(callSuper = false)
    @DisplayName("Static Inner Graph Vertex")
    static class StaticInnerGraphVertex extends GraphVertex {

        private int firstVal;
        private int secondVal;

        @Override
@@ -368,8 +230,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
        }

        @Override
        public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
                INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
            throw new UnsupportedOperationException("Not supported");
        }
@@ -384,9 +245,9 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
        }
    }

    @Test
    @DisplayName("Test Invalid Output Layer")
    void testInvalidOutputLayer() {
        /*
        Test case (invalid configs)
        1. nOut=1 + softmax
@@ -395,35 +256,24 @@
        4. xent + relu
        5. mcxent + sigmoid
        */
        LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[] {
                LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.XENT,
                LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT };
        int[] nOut = new int[] { 1, 3, 3, 3, 3 };
        Activation[] activations = new Activation[] { Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID };

        for (int i = 0; i < lf.length; i++) {
            for (boolean lossLayer : new boolean[] { false, true }) {
                for (boolean validate : new boolean[] { true, false }) {
                    String s = "nOut=" + nOut[i] + ",lossFn=" + lf[i] + ",lossLayer=" + lossLayer + ",validate=" + validate;
                    if (nOut[i] == 1 && lossLayer)
                        continue;   // nOuts are not available in loss layer, can't expect it to detect this case
                    try {
                        new NeuralNetConfiguration.Builder()
                                .graphBuilder()
                                .addInputs("in")
                                .layer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
                                .layer("1",
                                        !lossLayer ? new OutputLayer.Builder().nIn(10).nOut(nOut[i]).activation(activations[i]).lossFunction(lf[i]).build()
                                                : new LossLayer.Builder().activation(activations[i]).lossFunction(lf[i]).build(), "0")
                                .setOutputs("1")
                                .validateOutputLayerConfig(validate)
                                .build();
                        if (validate) {
                            fail("Expected exception: " + s);
                        }
                    } catch (DL4JInvalidConfigException e) {
                        if (validate) {
                            assertTrue(e.getMessage().toLowerCase().contains("invalid output"), s);
                        } else {
                            fail("Validation should not be enabled");
                        }
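The try/fail/catch blocks in this file migrate verbatim, but JUnit 5 also offers assertThrows, which expresses the same intent more compactly. A sketch of how the "no inputs for a layer" case could look; this is an alternative, not what the commit does, and it assumes the same DL4J classes imported above:

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;

import static org.junit.jupiter.api.Assertions.assertThrows;

class AssertThrowsSketch {
    void noInputsForLayer() {
        // assertThrows returns the caught exception, replacing the fail() + catch idiom.
        IllegalStateException e = assertThrows(IllegalStateException.class,
                () -> new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1")
                        .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1")
                        .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build())
                        .setOutputs("out").build());
        System.out.println(e.toString());
    }
}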

View File

@@ -17,102 +17,86 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.conf;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.*;

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@DisplayName("Json Test")
class JsonTest extends BaseDL4JTest {

    @Test
    @DisplayName("Test Json Loss Functions")
    void testJsonLossFunctions() {
        ILossFunction[] lossFunctions = new ILossFunction[] { new LossBinaryXENT(), new LossBinaryXENT(),
                new LossCosineProximity(), new LossHinge(), new LossKLD(), new LossKLD(), new LossL1(),
                new LossL1(), new LossL2(), new LossL2(), new LossMAE(), new LossMAE(), new LossMAPE(),
                new LossMAPE(), new LossMCXENT(), new LossMSE(), new LossMSE(), new LossMSLE(), new LossMSLE(),
                new LossNegativeLogLikelihood(), new LossNegativeLogLikelihood(), new LossPoisson(),
                new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0) };

        Activation[] outputActivationFn = new Activation[] {
                Activation.SIGMOID,  // xent
                Activation.SIGMOID,  // xent
                Activation.TANH,     // cosine
                Activation.TANH,     // hinge -> trying to predict 1 or -1
                Activation.SIGMOID,  // kld -> probab so should be between 0 and 1
                Activation.SOFTMAX,  // kld + softmax
                Activation.TANH,     // l1
                Activation.SOFTMAX,  // l1 + softmax
                Activation.TANH,     // l2
                Activation.SOFTMAX,  // l2 + softmax
                Activation.IDENTITY, // mae
                Activation.SOFTMAX,  // mae + softmax
                Activation.IDENTITY, // mape
                Activation.SOFTMAX,  // mape + softmax
                Activation.SOFTMAX,  // mcxent
                Activation.IDENTITY, // mse
                Activation.SOFTMAX,  // mse + softmax
                Activation.SIGMOID,  // msle - requires positive labels/activations due to log
                Activation.SOFTMAX,  // msle + softmax
                Activation.SIGMOID,  // nll
                Activation.SOFTMAX,  // nll + softmax
                Activation.SIGMOID,  // poisson - requires positive predictions due to log... not sure if this is the best option
                Activation.TANH,     // squared hinge
                Activation.SIGMOID,  // f-measure (binary, single sigmoid output)
                Activation.SOFTMAX   // f-measure (binary, 2-label softmax output)
        };

        int[] nOut = new int[] {
                1, // xent
                3, // xent
                5, // cosine
                3, // hinge
                3, // kld
                3, // kld + softmax
                3, // l1
                3, // l1 + softmax
                3, // l2
                3, // l2 + softmax
                3, // mae
                3, // mae + softmax
                3, // mape
                3, // mape + softmax
                3, // mcxent
                3, // mse
                3, // mse + softmax
                3, // msle
                3, // msle + softmax
                3, // nll
                3, // nll + softmax
                3, // poisson
                3, // squared hinge
                1, // f-measure (binary, single sigmoid output)
                2  // f-measure (binary, 2-label softmax output)
        };

        for (int i = 0; i < lossFunctions.length; i++) {
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(Updater.ADAM).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(nOut[i]).activation(Activation.TANH).build())
                    .layer(1, new LossLayer.Builder().lossFunction(lossFunctions[i])
                            .activation(outputActivationFn[i]).build())
                    .validateOutputLayerConfig(false).build();

            String json = conf.toJson();
            String yaml = conf.toYaml();
            MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
            MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml);
            assertEquals(conf, fromJson);
            assertEquals(conf, fromYaml);
        }
    }
}

View File

@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.conf;

import lombok.extern.slf4j.Slf4j;
@@ -34,41 +33,40 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.*;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Properties;

import static org.junit.jupiter.api.Assertions.*;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@Slf4j
@DisplayName("Multi Layer Neural Net Configuration Test")
class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {

    @TempDir
    public Path testDir;

    @Test
    @DisplayName("Test Json")
    void testJson() throws Exception {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                .layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build())
                .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build();

        String json = conf.toJson();
        MultiLayerConfiguration from = MultiLayerConfiguration.fromJson(json);
        assertEquals(conf.getConf(0), from.getConf(0));

        Properties props = new Properties();
        props.put("json", json);
        String key = props.getProperty("json");
        assertEquals(json, key);

        File f = testDir.resolve("props").toFile();
        f.deleteOnExit();
        BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f));
        props.store(bos, "");
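The hunk above also shows the resource-rule migration: the JUnit 4 @Rule TemporaryFolder field becomes a @TempDir-injected java.nio.file.Path, and testDir.newFile("props") becomes testDir.resolve("props").toFile(). A minimal sketch of the JUnit 5 side, with an illustrative file name:

import java.io.File;
import java.nio.file.Path;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirSketch {

    @TempDir
    Path testDir;   // created before and cleaned up after each test by JUnit 5

    @Test
    void writesIntoTempDir() {
        // Replaces the JUnit 4 idiom: testDir.newFile("props")
        File f = testDir.resolve("props").toFile();
    }
}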
@@ -82,36 +80,18 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
        String json2 = props2.getProperty("json");
        MultiLayerConfiguration conf3 = MultiLayerConfiguration.fromJson(json2);
        assertEquals(conf.getConf(0), conf3.getConf(0));
    }
    @Test
    @DisplayName("Test Convnet Json")
    void testConvnetJson() {
        final int numRows = 76;
        final int numColumns = 76;
        int nChannels = 3;
        int outputNum = 6;
        int seed = 123;

        // setup the network
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
                .l1(1e-1).l2(2e-4).weightNoise(new DropConnect(0.5)).miniBatch(true)
                .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list()
                .layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build())
                .layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build())
                .layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build())
                .setInputType(InputType.convolutional(numRows, numColumns, nChannels));

        MultiLayerConfiguration conf = builder.build();
        String json = conf.toJson();
        MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json);
@@ -119,30 +99,15 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Test Upsampling Convnet Json")
    void testUpsamplingConvnetJson() {
        final int numRows = 76;
        final int numColumns = 76;
        int nChannels = 3;
        int outputNum = 6;
        int seed = 123;

        // setup the network
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
                .l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true)
                .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list()
                .layer(new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(new Upsampling2D.Builder().size(2).build())
                .layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU).build())
                .layer(new Upsampling2D.Builder().size(2).build())
                .layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build())
                .setInputType(InputType.convolutional(numRows, numColumns, nChannels));

        MultiLayerConfiguration conf = builder.build();
        String json = conf.toJson();
        MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json);
@@ -150,36 +115,26 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Test Global Pooling Json")
    void testGlobalPoolingJson() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp())
                .dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
                .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(5).build())
                .layer(1, new GlobalPoolingLayer.Builder().poolingType(PoolingType.PNORM).pnorm(3).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nOut(3).build())
                .setInputType(InputType.convolutional(32, 32, 1)).build();

        String str = conf.toJson();
        MultiLayerConfiguration fromJson = conf.fromJson(str);
        assertEquals(conf, fromJson);
    }

    @Test
    @DisplayName("Test Yaml")
    void testYaml() throws Exception {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                .layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build())
                .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build();

        String json = conf.toYaml();
        MultiLayerConfiguration from = MultiLayerConfiguration.fromYaml(json);
        assertEquals(conf.getConf(0), from.getConf(0));

        Properties props = new Properties();
        props.put("json", json);
        String key = props.getProperty("json");
        assertEquals(json, key);

        File f = testDir.resolve("props").toFile();
        f.deleteOnExit();
        BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f));
        props.store(bos, "");
@@ -193,17 +148,13 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
        String yaml = props2.getProperty("json");
        MultiLayerConfiguration conf3 = MultiLayerConfiguration.fromYaml(yaml);
        assertEquals(conf.getConf(0), conf3.getConf(0));
    }

    @Test
    @DisplayName("Test Clone")
    void testClone() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                .layer(0, new DenseLayer.Builder().build())
                .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).build())
                .inputPreProcessor(1, new CnnToFeedForwardPreProcessor()).build();

        MultiLayerConfiguration conf2 = conf.clone();
        assertEquals(conf, conf2);
        assertNotSame(conf, conf2);
        assertNotSame(conf.getConfs(), conf2.getConfs());
@@ -217,174 +168,125 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Test Random Weight Init")
    void testRandomWeightInit() {
        MultiLayerNetwork model1 = new MultiLayerNetwork(getConf());
        model1.init();

        Nd4j.getRandom().setSeed(12345L);
        MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
        model2.init();

        float[] p1 = model1.params().data().asFloat();
        float[] p2 = model2.params().data().asFloat();
        System.out.println(Arrays.toString(p1));
        System.out.println(Arrays.toString(p2));
        assertArrayEquals(p1, p2, 0.0f);
    }

    @Test
    @DisplayName("Test Training Listener")
    void testTrainingListener() {
        MultiLayerNetwork model1 = new MultiLayerNetwork(getConf());
        model1.init();
        model1.addListeners(new ScoreIterationListener(1));

        MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
        model2.addListeners(new ScoreIterationListener(1));
        model2.init();

        Layer[] l1 = model1.getLayers();
        for (int i = 0; i < l1.length; i++)
            assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1);

        Layer[] l2 = model2.getLayers();
        for (int i = 0; i < l2.length; i++)
            assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1);
    }
    private static MultiLayerConfiguration getConf() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list()
                .layer(0, new DenseLayer.Builder().nIn(2).nOut(2)
                        .dist(new NormalDistribution(0, 1)).build())
                .layer(1, new OutputLayer.Builder().nIn(2).nOut(1)
                        .activation(Activation.TANH)
                        .dist(new NormalDistribution(0, 1)).lossFunction(LossFunctions.LossFunction.MSE).build())
                .build();
        return conf;
    }
    @Test
    @DisplayName("Test Invalid Config")
    void testInvalidConfig() {
        try {
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            // OK
            log.error("", e);
        } catch (Throwable e) {
            log.error("", e);
            fail("Unexpected exception thrown for invalid config");
        }

        try {
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
                    .layer(1, new DenseLayer.Builder().nIn(3).nOut(4).build())
                    .layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build())
                    .build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            // OK
            log.info(e.toString());
        } catch (Throwable e) {
            log.error("", e);
            fail("Unexpected exception thrown for invalid config");
        }

        try {
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
                    .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build())
                    .layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build())
                    .build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            fail("No exception thrown for invalid configuration");
        } catch (IllegalStateException e) {
            // OK
            log.info(e.toString());
        } catch (Throwable e) {
            log.error("", e);
            fail("Unexpected exception thrown for invalid config");
        }
    }
    @Test
    @DisplayName("Test List Overloads")
    void testListOverloads() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
                .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build())
                .layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DenseLayer dl = (DenseLayer) conf.getConf(0).getLayer();
        assertEquals(3, dl.getNIn());
        assertEquals(4, dl.getNOut());
        OutputLayer ol = (OutputLayer) conf.getConf(1).getLayer();
        assertEquals(4, ol.getNIn());
        assertEquals(5, ol.getNOut());

        MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list()
                .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build())
                .layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
        net2.init();

        MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder().seed(12345)
                .list(new DenseLayer.Builder().nIn(3).nOut(4).build(),
                        new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
        net3.init();

        assertEquals(conf, conf2);
        assertEquals(conf, conf3);
    }
@Test @Test
public void testBiasLr() { @DisplayName("Test Bias Lr")
//setup the network void testBiasLr() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(1e-2)) // setup the network
.biasUpdater(new Adam(0.5)).list() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(1e-2)).biasUpdater(new Adam(0.5)).list().layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(2, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)).build();
.layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build())
.layer(1, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build())
.layer(2, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build())
.setInputType(InputType.convolutional(28, 28, 1)).build();
org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer(); org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer();
org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer(); org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer();
org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer(); org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer();
org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer(); org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer();
assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam)l0.getUpdaterByParam("b")).getLearningRate(), 1e-6); assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam)l0.getUpdaterByParam("W")).getLearningRate(), 1e-6); assertEquals(0.5, ((Adam) l1.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam) l1.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam)l1.getUpdaterByParam("b")).getLearningRate(), 1e-6); assertEquals(0.5, ((Adam) l2.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam)l1.getUpdaterByParam("W")).getLearningRate(), 1e-6); assertEquals(1e-2, ((Adam) l2.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam) l3.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam)l2.getUpdaterByParam("b")).getLearningRate(), 1e-6); assertEquals(1e-2, ((Adam) l3.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam)l2.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam)l3.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam)l3.getUpdaterByParam("W")).getLearningRate(), 1e-6);
} }
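// Recap of the assertions above (a sketch of the resolution rule as exercised here):
// the global biasUpdater(new Adam(0.5)) covers every layer's "b" parameter, while the
// global updater(new Adam(1e-2)) covers the remaining parameters such as "W".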
@Test @Test
public void testInvalidOutputLayer(){ @DisplayName("Test Invalid Output Layer")
void testInvalidOutputLayer() {
/* /*
Test case (invalid configs) Test case (invalid configs)
1. nOut=1 + softmax 1. nOut=1 + softmax
@ -393,32 +295,24 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
4. xent + relu 4. xent + relu
5. mcxent + sigmoid 5. mcxent + sigmoid
*/ */
LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[] { LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT };
LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[]{ int[] nOut = new int[] { 1, 3, 3, 3, 3 };
LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.XENT, Activation[] activations = new Activation[] { Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID };
LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT}; for (int i = 0; i < lf.length; i++) {
int[] nOut = new int[]{1, 3, 3, 3, 3}; for (boolean lossLayer : new boolean[] { false, true }) {
Activation[] activations = new Activation[]{Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID}; for (boolean validate : new boolean[] { true, false }) {
for( int i=0; i<lf.length; i++ ){
for(boolean lossLayer : new boolean[]{false, true}) {
for (boolean validate : new boolean[]{true, false}) {
String s = "nOut=" + nOut[i] + ",lossFn=" + lf[i] + ",lossLayer=" + lossLayer + ",validate=" + validate; String s = "nOut=" + nOut[i] + ",lossFn=" + lf[i] + ",lossLayer=" + lossLayer + ",validate=" + validate;
if(nOut[i] == 1 && lossLayer) if (nOut[i] == 1 && lossLayer)
continue; //nOuts are not available in loss layer, can't expect it to detect this case // nOuts are not available in loss layer, can't expect it to detect this case
continue;
try { try {
new NeuralNetConfiguration.Builder() new NeuralNetConfiguration.Builder().list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(!lossLayer ? new OutputLayer.Builder().nIn(10).nOut(nOut[i]).activation(activations[i]).lossFunction(lf[i]).build() : new LossLayer.Builder().activation(activations[i]).lossFunction(lf[i]).build()).validateOutputLayerConfig(validate).build();
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(!lossLayer ? new OutputLayer.Builder().nIn(10).nOut(nOut[i]).activation(activations[i]).lossFunction(lf[i]).build()
: new LossLayer.Builder().activation(activations[i]).lossFunction(lf[i]).build())
.validateOutputLayerConfig(validate)
.build();
if (validate) { if (validate) {
fail("Expected exception: " + s); fail("Expected exception: " + s);
} }
} catch (DL4JInvalidConfigException e) { } catch (DL4JInvalidConfigException e) {
if (validate) { if (validate) {
assertTrue(s, e.getMessage().toLowerCase().contains("invalid output")); assertTrue(e.getMessage().toLowerCase().contains("invalid output"), s);
} else { } else {
fail("Validation should not be enabled"); fail("Validation should not be enabled");
} }
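// A standalone sketch of one case from the list above (case 4: XENT + RELU), using the
// same builder API as this test. With output validation enabled, build() is expected to
// throw DL4JInvalidConfigException with a message mentioning "invalid output":
//
// new NeuralNetConfiguration.Builder().list()
//         .layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
//         .layer(new OutputLayer.Builder().nIn(10).nOut(3)
//                 .activation(Activation.RELU)
//                 .lossFunction(LossFunctions.LossFunction.XENT).build())
//         .validateOutputLayerConfig(true)
//         .build();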
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.conf; package org.deeplearning4j.nn.conf;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -29,61 +28,70 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer.PoolingType;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.convolution.Convolution; import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.Assert.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertFalse; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/** /**
* @author Jeffrey Tang. * @author Jeffrey Tang.
*/ */
public class MultiNeuralNetConfLayerBuilderTest extends BaseDL4JTest { @DisplayName("Multi Neural Net Conf Layer Builder Test")
class MultiNeuralNetConfLayerBuilderTest extends BaseDL4JTest {
int numIn = 10; int numIn = 10;
int numOut = 5; int numOut = 5;
double drop = 0.3; double drop = 0.3;
Activation act = Activation.SOFTMAX; Activation act = Activation.SOFTMAX;
PoolingType poolType = PoolingType.MAX; PoolingType poolType = PoolingType.MAX;
int[] filterSize = new int[] {2, 2};
int[] filterSize = new int[] { 2, 2 };
int filterDepth = 6; int filterDepth = 6;
int[] stride = new int[] {2, 2};
int[] stride = new int[] { 2, 2 };
int k = 1; int k = 1;
Convolution.Type convType = Convolution.Type.FULL; Convolution.Type convType = Convolution.Type.FULL;
LossFunction loss = LossFunction.MCXENT; LossFunction loss = LossFunction.MCXENT;
WeightInit weight = WeightInit.XAVIER; WeightInit weight = WeightInit.XAVIER;
double corrupt = 0.4; double corrupt = 0.4;
double sparsity = 0.3; double sparsity = 0.3;
@Test @Test
public void testNeuralNetConfigAPI() { @DisplayName("Test Neural Net Config API")
void testNeuralNetConfigAPI() {
LossFunction newLoss = LossFunction.SQUARED_LOSS; LossFunction newLoss = LossFunction.SQUARED_LOSS;
int newNumIn = numIn + 1; int newNumIn = numIn + 1;
int newNumOut = numOut + 1; int newNumOut = numOut + 1;
WeightInit newWeight = WeightInit.UNIFORM; WeightInit newWeight = WeightInit.UNIFORM;
double newDrop = 0.5; double newDrop = 0.5;
int[] newFS = new int[] {3, 3}; int[] newFS = new int[] { 3, 3 };
int newFD = 7; int newFD = 7;
int[] newStride = new int[] {3, 3}; int[] newStride = new int[] { 3, 3 };
Convolution.Type newConvType = Convolution.Type.SAME; Convolution.Type newConvType = Convolution.Type.SAME;
PoolingType newPoolType = PoolingType.AVG; PoolingType newPoolType = PoolingType.AVG;
double newCorrupt = 0.5; double newCorrupt = 0.5;
double newSparsity = 0.5; double newSparsity = 0.5;
MultiLayerConfiguration multiConf1 = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(newNumIn).nOut(newNumOut).activation(act).build()).layer(1, new DenseLayer.Builder().nIn(newNumIn + 1).nOut(newNumOut + 1).activation(act).build()).build();
MultiLayerConfiguration multiConf1 =
new NeuralNetConfiguration.Builder().list()
.layer(0, new DenseLayer.Builder().nIn(newNumIn).nOut(newNumOut).activation(act)
.build())
.layer(1, new DenseLayer.Builder().nIn(newNumIn + 1).nOut(newNumOut + 1)
.activation(act).build())
.build();
NeuralNetConfiguration firstLayer = multiConf1.getConf(0); NeuralNetConfiguration firstLayer = multiConf1.getConf(0);
NeuralNetConfiguration secondLayer = multiConf1.getConf(1); NeuralNetConfiguration secondLayer = multiConf1.getConf(1);
assertFalse(firstLayer.equals(secondLayer)); assertFalse(firstLayer.equals(secondLayer));
} }
} }
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.conf; package org.deeplearning4j.nn.conf;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -37,7 +36,7 @@ import org.deeplearning4j.nn.weights.*;
import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.solvers.StochasticGradientDescent; import org.deeplearning4j.optimize.solvers.StochasticGradientDescent;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction; import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
@ -46,65 +45,61 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; @DisplayName("Neural Net Configuration Test")
import static org.junit.Assert.assertNotSame; class NeuralNetConfigurationTest extends BaseDL4JTest {
import static org.junit.Assert.assertTrue;
public class NeuralNetConfigurationTest extends BaseDL4JTest {
final DataSet trainingSet = createData(); final DataSet trainingSet = createData();
public DataSet createData() { public DataSet createData() {
int numFeatures = 40; int numFeatures = 40;
// must be at least two, or else the output layer gradient is a scalar and causes an exception
INDArray input = Nd4j.create(2, numFeatures); // must be at least two, or else the output layer gradient is a scalar and causes an exception INDArray input = Nd4j.create(2, numFeatures);
INDArray labels = Nd4j.create(2, 2); INDArray labels = Nd4j.create(2, 2);
INDArray row0 = Nd4j.create(1, numFeatures); INDArray row0 = Nd4j.create(1, numFeatures);
row0.assign(0.1); row0.assign(0.1);
input.putRow(0, row0); input.putRow(0, row0);
labels.put(0, 1, 1); // set the second column // set the second column
labels.put(0, 1, 1);
INDArray row1 = Nd4j.create(1, numFeatures); INDArray row1 = Nd4j.create(1, numFeatures);
row1.assign(0.2); row1.assign(0.2);
input.putRow(1, row1); input.putRow(1, row1);
labels.put(1, 0, 1); // set the first column // set the first column
labels.put(1, 0, 1);
return new DataSet(input, labels); return new DataSet(input, labels);
} }
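// Resulting toy data (a sketch): two examples with constant features and one-hot labels.
//   input  = [[0.1, ..., 0.1], [0.2, ..., 0.2]]   (2 x 40)
//   labels = [[0, 1], [1, 0]]                     (2 x 2)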
@Test @Test
public void testJson() { @DisplayName("Test Json")
void testJson() {
NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitXavier(), true); NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitXavier(), true);
String json = conf.toJson(); String json = conf.toJson();
NeuralNetConfiguration read = NeuralNetConfiguration.fromJson(json); NeuralNetConfiguration read = NeuralNetConfiguration.fromJson(json);
assertEquals(conf, read); assertEquals(conf, read);
} }
@Test @Test
public void testYaml() { @DisplayName("Test Yaml")
void testYaml() {
NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitXavier(), true); NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitXavier(), true);
String json = conf.toYaml(); String json = conf.toYaml();
NeuralNetConfiguration read = NeuralNetConfiguration.fromYaml(json); NeuralNetConfiguration read = NeuralNetConfiguration.fromYaml(json);
assertEquals(conf, read); assertEquals(conf, read);
} }
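// The same round-trip pattern applies one level up (a sketch, assuming the analogous
// MultiLayerConfiguration serialization API):
//
// MultiLayerConfiguration mlc = ...; // any valid multi-layer configuration
// assertEquals(mlc, MultiLayerConfiguration.fromJson(mlc.toJson()));
// assertEquals(mlc, MultiLayerConfiguration.fromYaml(mlc.toYaml()));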
@Test @Test
public void testClone() { @DisplayName("Test Clone")
void testClone() {
NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true); NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true);
BaseLayer bl = (BaseLayer) conf.getLayer(); BaseLayer bl = (BaseLayer) conf.getLayer();
conf.setStepFunction(new DefaultStepFunction()); conf.setStepFunction(new DefaultStepFunction());
NeuralNetConfiguration conf2 = conf.clone(); NeuralNetConfiguration conf2 = conf.clone();
assertEquals(conf, conf2); assertEquals(conf, conf2);
assertNotSame(conf, conf2); assertNotSame(conf, conf2);
assertNotSame(conf.getLayer(), conf2.getLayer()); assertNotSame(conf.getLayer(), conf2.getLayer());
@ -112,97 +107,74 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
} }
@Test @Test
public void testRNG() { @DisplayName("Test RNG")
DenseLayer layer = new DenseLayer.Builder().nIn(trainingSet.numInputs()).nOut(trainingSet.numOutcomes()) void testRNG() {
.weightInit(WeightInit.UNIFORM).activation(Activation.TANH).build(); DenseLayer layer = new DenseLayer.Builder().nIn(trainingSet.numInputs()).nOut(trainingSet.numOutcomes()).weightInit(WeightInit.UNIFORM).activation(Activation.TANH).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer).build();
long numParams = conf.getLayer().initializer().numParams(conf); long numParams = conf.getLayer().initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
Layer model = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); Layer model = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType());
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
DenseLayer layer2 = new DenseLayer.Builder().nIn(trainingSet.numInputs()).nOut(trainingSet.numOutcomes()).weightInit(WeightInit.UNIFORM).activation(Activation.TANH).build();
NeuralNetConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer2).build();
DenseLayer layer2 = new DenseLayer.Builder().nIn(trainingSet.numInputs()).nOut(trainingSet.numOutcomes())
.weightInit(WeightInit.UNIFORM).activation(Activation.TANH).build();
NeuralNetConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(123)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer2).build();
long numParams2 = conf2.getLayer().initializer().numParams(conf); long numParams2 = conf2.getLayer().initializer().numParams(conf);
INDArray params2 = Nd4j.create(1, numParams); INDArray params2 = Nd4j.create(1, numParams);
Layer model2 = conf2.getLayer().instantiate(conf2, null, 0, params2, true, params.dataType()); Layer model2 = conf2.getLayer().instantiate(conf2, null, 0, params2, true, params.dataType());
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2); assertEquals(modelWeights, modelWeights2);
} }
@Test @Test
public void testSetSeedSize() { @DisplayName("Test Set Seed Size")
void testSetSeedSize() {
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true); Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2);
}
@Test
public void testSetSeedNormalized() {
Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true); Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2); assertEquals(modelWeights, modelWeights2);
} }
@Test @Test
public void testSetSeedXavier() { @DisplayName("Test Set Seed Normalized")
void testSetSeedNormalized() {
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2);
}
@Test
@DisplayName("Test Set Seed Xavier")
void testSetSeedXavier() {
Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitUniform(), true); Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitUniform(), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitUniform(), true); Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitUniform(), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2); assertEquals(modelWeights, modelWeights2);
} }
@Test @Test
public void testSetSeedDistribution() { @DisplayName("Test Set Seed Distribution")
void testSetSeedDistribution() {
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitDistribution(new NormalDistribution(1, 1)), true);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(),
new WeightInitDistribution(new NormalDistribution(1, 1)), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitDistribution(new NormalDistribution(1, 1)), true);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(),
new WeightInitDistribution(new NormalDistribution(1, 1)), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2); assertEquals(modelWeights, modelWeights2);
} }
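// Note: despite their names, testSetSeedXavier above uses WeightInitUniform and
// testSetSeedNormalized uses WeightInitXavier; all of these seed tests check the same
// property: re-seeding the RNG with 123 reproduces identical initial weights.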
private static NeuralNetConfiguration getConfig(int nIn, int nOut, IWeightInit weightInit, boolean pretrain) { private static NeuralNetConfiguration getConfig(int nIn, int nOut, IWeightInit weightInit, boolean pretrain) {
DenseLayer layer = new DenseLayer.Builder().nIn(nIn).nOut(nOut).weightInit(weightInit) DenseLayer layer = new DenseLayer.Builder().nIn(nIn).nOut(nOut).weightInit(weightInit).activation(Activation.TANH).build();
.activation(Activation.TANH).build(); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer)
.build();
return conf; return conf;
} }
@ -213,94 +185,75 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType());
} }
@Test @Test
public void testLearningRateByParam() { @DisplayName("Test Learning Rate By Param")
void testLearningRateByParam() {
double lr = 0.01; double lr = 0.01;
double biasLr = 0.02; double biasLr = 0.02;
int[] nIns = {4, 3, 3}; int[] nIns = { 4, 3, 3 };
int[] nOuts = {3, 3, 3}; int[] nOuts = { 3, 3, 3 };
int oldScore = 1; int oldScore = 1;
int newScore = 1; int newScore = 1;
int iteration = 3; int iteration = 3;
INDArray gradientW = Nd4j.ones(nIns[0], nOuts[0]); INDArray gradientW = Nd4j.ones(nIns[0], nOuts[0]);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list().layer(0, new DenseLayer.Builder().nIn(nIns[0]).nOut(nOuts[0]).updater(new Sgd(lr)).biasUpdater(new Sgd(biasLr)).build()).layer(1, new BatchNormalization.Builder().nIn(nIns[1]).nOut(nOuts[1]).updater(new Sgd(0.7)).build()).layer(2, new OutputLayer.Builder().nIn(nIns[2]).nOut(nOuts[2]).lossFunction(LossFunctions.LossFunction.MSE).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list()
.layer(0, new DenseLayer.Builder().nIn(nIns[0]).nOut(nOuts[0])
.updater(new Sgd(lr)).biasUpdater(new Sgd(biasLr)).build())
.layer(1, new BatchNormalization.Builder().nIn(nIns[1]).nOut(nOuts[1]).updater(new Sgd(0.7)).build())
.layer(2, new OutputLayer.Builder().nIn(nIns[2]).nOut(nOuts[2]).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
ConvexOptimizer opt = new StochasticGradientDescent(net.getDefaultConfiguration(), new NegativeDefaultStepFunction(), null, net);
ConvexOptimizer opt = new StochasticGradientDescent(net.getDefaultConfiguration(), assertEquals(lr, ((Sgd) net.getLayer(0).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4);
new NegativeDefaultStepFunction(), null, net); assertEquals(biasLr, ((Sgd) net.getLayer(0).conf().getLayer().getUpdaterByParam("b")).getLearningRate(), 1e-4);
assertEquals(lr, ((Sgd)net.getLayer(0).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4); assertEquals(0.7, ((Sgd) net.getLayer(1).conf().getLayer().getUpdaterByParam("gamma")).getLearningRate(), 1e-4);
assertEquals(biasLr, ((Sgd)net.getLayer(0).conf().getLayer().getUpdaterByParam("b")).getLearningRate(), 1e-4); // From global LR
assertEquals(0.7, ((Sgd)net.getLayer(1).conf().getLayer().getUpdaterByParam("gamma")).getLearningRate(), 1e-4); assertEquals(0.3, ((Sgd) net.getLayer(2).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4);
assertEquals(0.3, ((Sgd)net.getLayer(2).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4); //From global LR // From global LR
assertEquals(0.3, ((Sgd)net.getLayer(2).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4); //From global LR assertEquals(0.3, ((Sgd) net.getLayer(2).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4);
} }
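// Per-parameter learning-rate resolution exercised above (a sketch): the layer-level
// updater(new Sgd(lr)) covers "W", biasUpdater(new Sgd(biasLr)) covers "b", the
// BatchNormalization layer's own updater (Sgd(0.7)) covers its "gamma", and layers with
// nothing set fall back to the global updater (Sgd(0.3)).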
@Test @Test
public void testLeakyreluAlpha() { @DisplayName("Test Leakyrelu Alpha")
//FIXME: Make more generic to use neuralnetconfs void testLeakyreluAlpha() {
// FIXME: Make more generic to use neuralnetconfs
int sizeX = 4; int sizeX = 4;
int scaleX = 10; int scaleX = 10;
System.out.println("Here is a leaky vector.."); System.out.println("Here is a leaky vector..");
INDArray leakyVector = Nd4j.linspace(-1, 1, sizeX, Nd4j.dataType()); INDArray leakyVector = Nd4j.linspace(-1, 1, sizeX, Nd4j.dataType());
leakyVector = leakyVector.mul(scaleX); leakyVector = leakyVector.mul(scaleX);
System.out.println(leakyVector); System.out.println(leakyVector);
double myAlpha = 0.5; double myAlpha = 0.5;
System.out.println("======================"); System.out.println("======================");
System.out.println("Exec and Return: Leaky Relu transformation with alpha = 0.5 .."); System.out.println("Exec and Return: Leaky Relu transformation with alpha = 0.5 ..");
System.out.println("======================"); System.out.println("======================");
INDArray outDef = Nd4j.getExecutioner().exec(new LeakyReLU(leakyVector.dup(), myAlpha)); INDArray outDef = Nd4j.getExecutioner().exec(new LeakyReLU(leakyVector.dup(), myAlpha));
System.out.println(outDef); System.out.println(outDef);
String confActivation = "leakyrelu"; String confActivation = "leakyrelu";
Object[] confExtra = {myAlpha}; Object[] confExtra = { myAlpha };
INDArray outMine = Nd4j.getExecutioner().exec(new LeakyReLU(leakyVector.dup(), myAlpha)); INDArray outMine = Nd4j.getExecutioner().exec(new LeakyReLU(leakyVector.dup(), myAlpha));
System.out.println("======================"); System.out.println("======================");
System.out.println("Exec and Return: Leaky Relu transformation with a value via getOpFactory"); System.out.println("Exec and Return: Leaky Relu transformation with a value via getOpFactory");
System.out.println("======================"); System.out.println("======================");
System.out.println(outMine); System.out.println(outMine);
// Test equality for ndarray elementwise, e.g.:
// assertArrayEquals(outDef.toDoubleVector(), outMine.toDoubleVector(), 1e-6);
} }
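// LeakyReLU reference (a sketch): f(x) = x for x > 0 and f(x) = alpha * x otherwise, so
// with alpha = 0.5 the scaled linspace above keeps positive entries and halves negative
// ones; outDef and outMine are computed identically and should match elementwise.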
@Test @Test
public void testL1L2ByParam() { @DisplayName("Test L 1 L 2 By Param")
void testL1L2ByParam() {
double l1 = 0.01; double l1 = 0.01;
double l2 = 0.07; double l2 = 0.07;
int[] nIns = {4, 3, 3}; int[] nIns = { 4, 3, 3 };
int[] nOuts = {3, 3, 3}; int[] nOuts = { 3, 3, 3 };
int oldScore = 1; int oldScore = 1;
int newScore = 1; int newScore = 1;
int iteration = 3; int iteration = 3;
INDArray gradientW = Nd4j.ones(nIns[0], nOuts[0]); INDArray gradientW = Nd4j.ones(nIns[0], nOuts[0]);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(l1).l2(l2).list().layer(0, new DenseLayer.Builder().nIn(nIns[0]).nOut(nOuts[0]).build()).layer(1, new BatchNormalization.Builder().nIn(nIns[1]).nOut(nOuts[1]).l2(0.5).build()).layer(2, new OutputLayer.Builder().nIn(nIns[2]).nOut(nOuts[2]).lossFunction(LossFunctions.LossFunction.MSE).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(l1)
.l2(l2).list()
.layer(0, new DenseLayer.Builder().nIn(nIns[0]).nOut(nOuts[0]).build())
.layer(1, new BatchNormalization.Builder().nIn(nIns[1]).nOut(nOuts[1]).l2(0.5).build())
.layer(2, new OutputLayer.Builder().nIn(nIns[2]).nOut(nOuts[2]).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
ConvexOptimizer opt = new StochasticGradientDescent(net.getDefaultConfiguration(), new NegativeDefaultStepFunction(), null, net);
ConvexOptimizer opt = new StochasticGradientDescent(net.getDefaultConfiguration(),
new NegativeDefaultStepFunction(), null, net);
assertEquals(l1, TestUtils.getL1(net.getLayer(0).conf().getLayer().getRegularizationByParam("W")), 1e-4); assertEquals(l1, TestUtils.getL1(net.getLayer(0).conf().getLayer().getRegularizationByParam("W")), 1e-4);
List<Regularization> r = net.getLayer(0).conf().getLayer().getRegularizationByParam("b"); List<Regularization> r = net.getLayer(0).conf().getLayer().getRegularizationByParam("b");
assertEquals(0, r.size()); assertEquals(0, r.size());
r = net.getLayer(1).conf().getLayer().getRegularizationByParam("beta"); r = net.getLayer(1).conf().getLayer().getRegularizationByParam("beta");
assertTrue(r == null || r.isEmpty()); assertTrue(r == null || r.isEmpty());
r = net.getLayer(1).conf().getLayer().getRegularizationByParam("gamma"); r = net.getLayer(1).conf().getLayer().getRegularizationByParam("gamma");
@ -315,14 +268,10 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
} }
@Test @Test
public void testLayerPretrainConfig() { @DisplayName("Test Layer Pretrain Config")
void testLayerPretrainConfig() {
boolean pretrain = true; boolean pretrain = true;
VariationalAutoencoder layer = new VariationalAutoencoder.Builder().nIn(10).nOut(5).updater(new Sgd(1e-1)).lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build();
VariationalAutoencoder layer = new VariationalAutoencoder.Builder()
.nIn(10).nOut(5).updater(new Sgd(1e-1))
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(42).layer(layer).build(); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(42).layer(layer).build();
} }
} }
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.conf.graph; package org.deeplearning4j.nn.conf.graph;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -30,8 +29,8 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Assert; import org.junit.jupiter.api.Assertions;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.activations.impl.ActivationTanH;
@ -43,194 +42,99 @@ import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import java.util.Map; import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertArrayEquals; @DisplayName("Element Wise Vertex Test")
class ElementWiseVertexTest extends BaseDL4JTest {
public class ElementWiseVertexTest extends BaseDL4JTest {
@Test @Test
public void testElementWiseVertexNumParams() { @DisplayName("Test Element Wise Vertex Num Params")
void testElementWiseVertexNumParams() {
/* /*
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386 * https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
* from @agibsonccc: check for the basics: like 0 numParams * from @agibsonccc: check for the basics: like 0 numParams
*/ */
ElementWiseVertex.Op[] ops = new ElementWiseVertex.Op[] { ElementWiseVertex.Op.Add, ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Product };
ElementWiseVertex.Op ops[] = new ElementWiseVertex.Op[] {ElementWiseVertex.Op.Add,
ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Product};
for (ElementWiseVertex.Op op : ops) { for (ElementWiseVertex.Op op : ops) {
ElementWiseVertex ewv = new ElementWiseVertex(op); ElementWiseVertex ewv = new ElementWiseVertex(op);
Assert.assertEquals(0, ewv.numParams(true)); Assertions.assertEquals(0, ewv.numParams(true));
Assert.assertEquals(0, ewv.numParams(false)); Assertions.assertEquals(0, ewv.numParams(false));
} }
} }
@Test @Test
public void testElementWiseVertexForwardAdd() { @DisplayName("Test Element Wise Vertex Forward Add")
void testElementWiseVertexForwardAdd() {
int batchsz = 24; int batchsz = 24;
int featuresz = 17; int featuresz = 17;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder() ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2", "input3").addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY).build(), "input1").addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "input1", "input2", "input3").addLayer("Add", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "elementwiseAdd").setOutputs("Add", "denselayer").build();
.addInputs("input1", "input2", "input3")
.addLayer("denselayer",
new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY)
.build(),
"input1")
/* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get
* Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
* at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341)
*/
.addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "input1",
"input2", "input3")
.addLayer("Add", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(),
"elementwiseAdd")
.setOutputs("Add", "denselayer").build();
ComputationGraph cg = new ComputationGraph(cgc); ComputationGraph cg = new ComputationGraph(cgc);
cg.init(); cg.init();
INDArray input1 = Nd4j.rand(batchsz, featuresz); INDArray input1 = Nd4j.rand(batchsz, featuresz);
INDArray input2 = Nd4j.rand(batchsz, featuresz); INDArray input2 = Nd4j.rand(batchsz, featuresz);
INDArray input3 = Nd4j.rand(batchsz, featuresz); INDArray input3 = Nd4j.rand(batchsz, featuresz);
INDArray target = input1.dup().addi(input2).addi(input3); INDArray target = input1.dup().addi(input2).addi(input3);
INDArray output = cg.output(input1, input2, input3)[0]; INDArray output = cg.output(input1, input2, input3)[0];
INDArray squared = output.sub(target.castTo(output.dataType())); INDArray squared = output.sub(target.castTo(output.dataType()));
double rms = squared.mul(squared).sumNumber().doubleValue(); double rms = squared.mul(squared).sumNumber().doubleValue();
Assert.assertEquals(0.0, rms, this.epsilon); Assertions.assertEquals(0.0, rms, this.epsilon);
} }
@Test @Test
public void testElementWiseVertexForwardProduct() { @DisplayName("Test Element Wise Vertex Forward Product")
void testElementWiseVertexForwardProduct() {
int batchsz = 24; int batchsz = 24;
int featuresz = 17; int featuresz = 17;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder() ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2", "input3").addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY).build(), "input1").addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "input1", "input2", "input3").addLayer("Product", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "elementwiseProduct").setOutputs("Product", "denselayer").build();
.addInputs("input1", "input2", "input3")
.addLayer("denselayer",
new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY)
.build(),
"input1")
/* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get
* Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
* at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341)
*/
.addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "input1",
"input2", "input3")
.addLayer("Product", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(),
"elementwiseProduct")
.setOutputs("Product", "denselayer").build();
ComputationGraph cg = new ComputationGraph(cgc); ComputationGraph cg = new ComputationGraph(cgc);
cg.init(); cg.init();
INDArray input1 = Nd4j.rand(batchsz, featuresz); INDArray input1 = Nd4j.rand(batchsz, featuresz);
INDArray input2 = Nd4j.rand(batchsz, featuresz); INDArray input2 = Nd4j.rand(batchsz, featuresz);
INDArray input3 = Nd4j.rand(batchsz, featuresz); INDArray input3 = Nd4j.rand(batchsz, featuresz);
INDArray target = input1.dup().muli(input2).muli(input3); INDArray target = input1.dup().muli(input2).muli(input3);
INDArray output = cg.output(input1, input2, input3)[0]; INDArray output = cg.output(input1, input2, input3)[0];
INDArray squared = output.sub(target.castTo(output.dataType())); INDArray squared = output.sub(target.castTo(output.dataType()));
double rms = squared.mul(squared).sumNumber().doubleValue(); double rms = squared.mul(squared).sumNumber().doubleValue();
Assert.assertEquals(0.0, rms, this.epsilon); Assertions.assertEquals(0.0, rms, this.epsilon);
} }
@Test @Test
public void testElementWiseVertexForwardSubtract() { @DisplayName("Test Element Wise Vertex Forward Subtract")
void testElementWiseVertexForwardSubtract() {
int batchsz = 24; int batchsz = 24;
int featuresz = 17; int featuresz = 17;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder() ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2").addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY).build(), "input1").addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), "input1", "input2").addLayer("Subtract", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "elementwiseSubtract").setOutputs("Subtract", "denselayer").build();
.addInputs("input1", "input2")
.addLayer("denselayer",
new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY)
.build(),
"input1")
/* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get
* Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
* at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341)
*/
.addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract),
"input1", "input2")
.addLayer("Subtract", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(),
"elementwiseSubtract")
.setOutputs("Subtract", "denselayer").build();
ComputationGraph cg = new ComputationGraph(cgc); ComputationGraph cg = new ComputationGraph(cgc);
cg.init(); cg.init();
INDArray input1 = Nd4j.rand(batchsz, featuresz); INDArray input1 = Nd4j.rand(batchsz, featuresz);
INDArray input2 = Nd4j.rand(batchsz, featuresz); INDArray input2 = Nd4j.rand(batchsz, featuresz);
INDArray target = input1.dup().subi(input2); INDArray target = input1.dup().subi(input2);
INDArray output = cg.output(input1, input2)[0]; INDArray output = cg.output(input1, input2)[0];
INDArray squared = output.sub(target); INDArray squared = output.sub(target);
double rms = Math.sqrt(squared.mul(squared).sumNumber().doubleValue()); double rms = Math.sqrt(squared.mul(squared).sumNumber().doubleValue());
Assert.assertEquals(0.0, rms, this.epsilon); Assertions.assertEquals(0.0, rms, this.epsilon);
} }
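// Note: in the Add and Product tests above, "rms" is actually the sum of squared errors
// (no square root), while this Subtract test does take the root; since every assertion
// compares against 0.0, the three checks are equivalent in practice.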
@Test @Test
public void testElementWiseVertexFullAdd() { @DisplayName("Test Element Wise Vertex Full Add")
void testElementWiseVertexFullAdd() {
int batchsz = 24; int batchsz = 24;
int featuresz = 17; int featuresz = 17;
int midsz = 13; int midsz = 13;
int outputsz = 11; int outputsz = 11;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).biasInit(0.0).updater(new Sgd()).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2", "input3").addLayer("dense1", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input1").addLayer("dense2", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input2").addLayer("dense3", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input3").addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2", "dense3").addLayer("output", new OutputLayer.Builder().nIn(midsz).nOut(outputsz).activation(new ActivationSigmoid()).lossFunction(LossFunction.MSE).build(), "elementwiseAdd").setOutputs("output").build();
.dataType(DataType.DOUBLE)
.biasInit(0.0).updater(new Sgd())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
.addInputs("input1", "input2", "input3")
.addLayer("dense1",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input1")
.addLayer("dense2",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input2")
.addLayer("dense3",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input3")
.addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1",
"dense2", "dense3")
.addLayer("output",
new OutputLayer.Builder().nIn(midsz).nOut(outputsz)
.activation(new ActivationSigmoid())
.lossFunction(LossFunction.MSE).build(),
"elementwiseAdd")
.setOutputs("output").build();
ComputationGraph cg = new ComputationGraph(cgc); ComputationGraph cg = new ComputationGraph(cgc);
cg.init(); cg.init();
INDArray input1 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); INDArray input1 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input2 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); INDArray input2 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input3 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); INDArray input3 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray target = nullsafe(Nd4j.rand(new int[] {batchsz, outputsz}, new UniformDistribution(0, 1))); INDArray target = nullsafe(Nd4j.rand(new int[] { batchsz, outputsz }, new UniformDistribution(0, 1)));
cg.setInputs(input1, input2, input3); cg.setInputs(input1, input2, input3);
cg.setLabels(target); cg.setLabels(target);
cg.computeGradientAndScore(); cg.computeGradientAndScore();
// Let's figure out what our params are now. // Let's figure out what our params are now.
Map<String, INDArray> params = cg.paramTable(); Map<String, INDArray> params = cg.paramTable();
INDArray dense1_W = nullsafe(params.get("dense1_W")); INDArray dense1_W = nullsafe(params.get("dense1_W"));
@ -241,35 +145,22 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dense3_b = nullsafe(params.get("dense3_b")); INDArray dense3_b = nullsafe(params.get("dense3_b"));
INDArray output_W = nullsafe(params.get("output_W")); INDArray output_W = nullsafe(params.get("output_W"));
INDArray output_b = nullsafe(params.get("output_b")); INDArray output_b = nullsafe(params.get("output_b"));
// Now, let's calculate what we expect the output to be. // Now, let's calculate what we expect the output to be.
INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1)); INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1));
INDArray m = (Transforms.tanh(mh)); INDArray m = (Transforms.tanh(mh));
INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1)); INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1));
INDArray n = (Transforms.tanh(nh)); INDArray n = (Transforms.tanh(nh));
INDArray oh = input3.mmul(dense3_W).addi(dense3_b.repmat(batchsz, 1)); INDArray oh = input3.mmul(dense3_W).addi(dense3_b.repmat(batchsz, 1));
INDArray o = (Transforms.tanh(oh)); INDArray o = (Transforms.tanh(oh));
INDArray middle = Nd4j.zeros(batchsz, midsz); INDArray middle = Nd4j.zeros(batchsz, midsz);
middle.addi(m).addi(n).addi(o); middle.addi(m).addi(n).addi(o);
INDArray expect = Nd4j.zeros(batchsz, outputsz); INDArray expect = Nd4j.zeros(batchsz, outputsz);
expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1)))); expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1))));
INDArray output = nullsafe(cg.output(input1, input2, input3)[0]); INDArray output = nullsafe(cg.output(input1, input2, input3)[0]);
Assertions.assertEquals(0.0, mse(output, expect), this.epsilon);
Assert.assertEquals(0.0, mse(output, expect), this.epsilon);
Pair<Gradient, Double> pgd = cg.gradientAndScore(); Pair<Gradient, Double> pgd = cg.gradientAndScore();
double score = pgd.getSecond(); double score = pgd.getSecond();
Assert.assertEquals(score, mse(output, target), this.epsilon); Assertions.assertEquals(score, mse(output, target), this.epsilon);
Map<String, INDArray> gradients = pgd.getFirst().gradientForVariable(); Map<String, INDArray> gradients = pgd.getFirst().gradientForVariable();
/* /*
* So. Let's say we have inputs a, b, c * So. Let's say we have inputs a, b, c
@ -305,27 +196,23 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
* dmh/db1 = Nd4j.ones(1, batchsz) * dmh/db1 = Nd4j.ones(1, batchsz)
* *
*/ */
INDArray y = output; INDArray y = output;
INDArray s = middle; INDArray s = middle;
INDArray W4 = output_W; INDArray W4 = output_W;
INDArray dEdy = Nd4j.zeros(target.shape()); INDArray dEdy = Nd4j.zeros(target.shape());
dEdy.addi(y).subi(target).muli(2); // This should be of size batchsz x outputsz // This should be of size batchsz x outputsz
dEdy.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output. dEdy.addi(y).subi(target).muli(2);
// Why? Because the LossFunction divides by the _element size_ of the output.
INDArray dydyh = y.mul(y.mul(-1).add(1)); // This is of size batchsz x outputsz dEdy.divi(target.shape()[1]);
// This is of size batchsz x outputsz
INDArray dydyh = y.mul(y.mul(-1).add(1));
INDArray dEdyh = dydyh.mul(dEdy); INDArray dEdyh = dydyh.mul(dEdy);
INDArray dyhdW4 = s.transpose(); INDArray dyhdW4 = s.transpose();
INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh)); INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh));
INDArray dyhdb4 = Nd4j.ones(1, batchsz); INDArray dyhdb4 = Nd4j.ones(1, batchsz);
INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh)); INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh));
INDArray dyhds = W4.transpose(); INDArray dyhds = W4.transpose();
INDArray dEds = dEdyh.mmul(dyhds); INDArray dEds = dEdyh.mmul(dyhds);
INDArray dsdm = Nd4j.ones(batchsz, midsz); INDArray dsdm = Nd4j.ones(batchsz, midsz);
INDArray dEdm = dsdm.mul(dEds); INDArray dEdm = dsdm.mul(dEds);
INDArray dmdmh = (m.mul(m)).mul(-1).add(1); INDArray dmdmh = (m.mul(m)).mul(-1).add(1);
@ -334,7 +221,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh)); INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh));
INDArray dmhdb1 = Nd4j.ones(1, batchsz); INDArray dmhdb1 = Nd4j.ones(1, batchsz);
INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh)); INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh));
INDArray dsdn = Nd4j.ones(batchsz, midsz); INDArray dsdn = Nd4j.ones(batchsz, midsz);
INDArray dEdn = dsdn.mul(dEds); INDArray dEdn = dsdn.mul(dEds);
INDArray dndnh = (n.mul(n)).mul(-1).add(1); INDArray dndnh = (n.mul(n)).mul(-1).add(1);
@ -343,7 +229,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh)); INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh));
INDArray dnhdb2 = Nd4j.ones(1, batchsz); INDArray dnhdb2 = Nd4j.ones(1, batchsz);
INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh)); INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh));
INDArray dsdo = Nd4j.ones(batchsz, midsz); INDArray dsdo = Nd4j.ones(batchsz, midsz);
INDArray dEdo = dsdo.mul(dEds); INDArray dEdo = dsdo.mul(dEds);
INDArray dodoh = (o.mul(o)).mul(-1).add(1); INDArray dodoh = (o.mul(o)).mul(-1).add(1);
@ -352,61 +237,33 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW3 = nullsafe(dohdW3.mmul(dEdoh)); INDArray dEdW3 = nullsafe(dohdW3.mmul(dEdoh));
INDArray dohdb3 = Nd4j.ones(1, batchsz); INDArray dohdb3 = Nd4j.ones(1, batchsz);
INDArray dEdb3 = nullsafe(dohdb3.mmul(dEdoh)); INDArray dEdb3 = nullsafe(dohdb3.mmul(dEdoh));
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon);
} }
@Test @Test
public void testElementWiseVertexFullProduct() { @DisplayName("Test Element Wise Vertex Full Product")
void testElementWiseVertexFullProduct() {
int batchsz = 24; int batchsz = 24;
int featuresz = 17; int featuresz = 17;
int midsz = 13; int midsz = 13;
int outputsz = 11; int outputsz = 11;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).biasInit(0.0).updater(new Sgd()).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2", "input3").addLayer("dense1", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input1").addLayer("dense2", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input2").addLayer("dense3", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input3").addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1", "dense2", "dense3").addLayer("output", new OutputLayer.Builder().nIn(midsz).nOut(outputsz).activation(new ActivationSigmoid()).lossFunction(LossFunction.MSE).build(), "elementwiseProduct").setOutputs("output").build();
.dataType(DataType.DOUBLE)
.biasInit(0.0).updater(new Sgd())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
.addInputs("input1", "input2", "input3")
.addLayer("dense1",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input1")
.addLayer("dense2",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input2")
.addLayer("dense3",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input3")
.addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1",
"dense2", "dense3")
.addLayer("output",
new OutputLayer.Builder().nIn(midsz).nOut(outputsz)
.activation(new ActivationSigmoid())
.lossFunction(LossFunction.MSE).build(),
"elementwiseProduct")
.setOutputs("output").build();
ComputationGraph cg = new ComputationGraph(cgc); ComputationGraph cg = new ComputationGraph(cgc);
cg.init(); cg.init();
INDArray input1 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); INDArray input1 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input2 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); INDArray input2 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input3 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); INDArray input3 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray target = nullsafe(Nd4j.rand(new int[] {batchsz, outputsz}, new UniformDistribution(0, 1))); INDArray target = nullsafe(Nd4j.rand(new int[] { batchsz, outputsz }, new UniformDistribution(0, 1)));
cg.setInputs(input1, input2, input3); cg.setInputs(input1, input2, input3);
cg.setLabels(target); cg.setLabels(target);
cg.computeGradientAndScore(); cg.computeGradientAndScore();
// Let's figure out what our params are now. // Let's figure out what our params are now.
Map<String, INDArray> params = cg.paramTable(); Map<String, INDArray> params = cg.paramTable();
INDArray dense1_W = nullsafe(params.get("dense1_W")); INDArray dense1_W = nullsafe(params.get("dense1_W"));
@ -417,35 +274,22 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dense3_b = nullsafe(params.get("dense3_b")); INDArray dense3_b = nullsafe(params.get("dense3_b"));
INDArray output_W = nullsafe(params.get("output_W")); INDArray output_W = nullsafe(params.get("output_W"));
INDArray output_b = nullsafe(params.get("output_b")); INDArray output_b = nullsafe(params.get("output_b"));
// Now, let's calculate what we expect the output to be. // Now, let's calculate what we expect the output to be.
INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1)); INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1));
INDArray m = (Transforms.tanh(mh)); INDArray m = (Transforms.tanh(mh));
INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1)); INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1));
INDArray n = (Transforms.tanh(nh)); INDArray n = (Transforms.tanh(nh));
INDArray oh = input3.mmul(dense3_W).addi(dense3_b.repmat(batchsz, 1)); INDArray oh = input3.mmul(dense3_W).addi(dense3_b.repmat(batchsz, 1));
INDArray o = (Transforms.tanh(oh)); INDArray o = (Transforms.tanh(oh));
INDArray middle = Nd4j.ones(batchsz, midsz); INDArray middle = Nd4j.ones(batchsz, midsz);
middle.muli(m).muli(n).muli(o); middle.muli(m).muli(n).muli(o);
INDArray expect = Nd4j.zeros(batchsz, outputsz); INDArray expect = Nd4j.zeros(batchsz, outputsz);
expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1)))); expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1))));
INDArray output = nullsafe(cg.output(input1, input2, input3)[0]); INDArray output = nullsafe(cg.output(input1, input2, input3)[0]);
Assertions.assertEquals(0.0, mse(output, expect), this.epsilon);
Assert.assertEquals(0.0, mse(output, expect), this.epsilon);
Pair<Gradient, Double> pgd = cg.gradientAndScore(); Pair<Gradient, Double> pgd = cg.gradientAndScore();
double score = pgd.getSecond(); double score = pgd.getSecond();
Assert.assertEquals(score, mse(output, target), this.epsilon); Assertions.assertEquals(score, mse(output, target), this.epsilon);
Map<String, INDArray> gradients = pgd.getFirst().gradientForVariable(); Map<String, INDArray> gradients = pgd.getFirst().gradientForVariable();
/* /*
* So. Let's say we have inputs a, b, c * So. Let's say we have inputs a, b, c
@ -481,27 +325,23 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
* dmh/db1 = Nd4j.ones(1, batchsz) * dmh/db1 = Nd4j.ones(1, batchsz)
* *
*/ */
INDArray y = output; INDArray y = output;
INDArray s = middle; INDArray s = middle;
INDArray W4 = output_W; INDArray W4 = output_W;
INDArray dEdy = Nd4j.zeros(target.shape()); INDArray dEdy = Nd4j.zeros(target.shape());
dEdy.addi(y).subi(target).muli(2); // This should be of size batchsz x outputsz // This should be of size batchsz x outputsz
dEdy.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output. dEdy.addi(y).subi(target).muli(2);
// Why? Because the LossFunction divides by the _element size_ of the output.
INDArray dydyh = y.mul(y.mul(-1).add(1)); // This is of size batchsz x outputsz dEdy.divi(target.shape()[1]);
// This is of size batchsz x outputsz
INDArray dydyh = y.mul(y.mul(-1).add(1));
INDArray dEdyh = dydyh.mul(dEdy); INDArray dEdyh = dydyh.mul(dEdy);
INDArray dyhdW4 = s.transpose(); INDArray dyhdW4 = s.transpose();
INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh)); INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh));
INDArray dyhdb4 = Nd4j.ones(1, batchsz); INDArray dyhdb4 = Nd4j.ones(1, batchsz);
INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh)); INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh));
INDArray dyhds = W4.transpose(); INDArray dyhds = W4.transpose();
INDArray dEds = dEdyh.mmul(dyhds); INDArray dEds = dEdyh.mmul(dyhds);
INDArray dsdm = Nd4j.ones(batchsz, midsz).muli(n).muli(o); INDArray dsdm = Nd4j.ones(batchsz, midsz).muli(n).muli(o);
INDArray dEdm = dsdm.mul(dEds); INDArray dEdm = dsdm.mul(dEds);
INDArray dmdmh = (m.mul(m)).mul(-1).add(1); INDArray dmdmh = (m.mul(m)).mul(-1).add(1);
@ -510,7 +350,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh)); INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh));
INDArray dmhdb1 = Nd4j.ones(1, batchsz); INDArray dmhdb1 = Nd4j.ones(1, batchsz);
INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh)); INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh));
INDArray dsdn = Nd4j.ones(batchsz, midsz).muli(m).muli(o); INDArray dsdn = Nd4j.ones(batchsz, midsz).muli(m).muli(o);
INDArray dEdn = dsdn.mul(dEds); INDArray dEdn = dsdn.mul(dEds);
INDArray dndnh = (n.mul(n)).mul(-1).add(1); INDArray dndnh = (n.mul(n)).mul(-1).add(1);
@ -519,7 +358,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh)); INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh));
INDArray dnhdb2 = Nd4j.ones(1, batchsz); INDArray dnhdb2 = Nd4j.ones(1, batchsz);
INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh)); INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh));
INDArray dsdo = Nd4j.ones(batchsz, midsz).muli(m).muli(n); INDArray dsdo = Nd4j.ones(batchsz, midsz).muli(m).muli(n);
INDArray dEdo = dsdo.mul(dEds); INDArray dEdo = dsdo.mul(dEds);
INDArray dodoh = (o.mul(o)).mul(-1).add(1); INDArray dodoh = (o.mul(o)).mul(-1).add(1);
@ -528,55 +366,32 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW3 = nullsafe(dohdW3.mmul(dEdoh)); INDArray dEdW3 = nullsafe(dohdW3.mmul(dEdoh));
INDArray dohdb3 = Nd4j.ones(1, batchsz); INDArray dohdb3 = Nd4j.ones(1, batchsz);
INDArray dEdb3 = nullsafe(dohdb3.mmul(dEdoh)); INDArray dEdb3 = nullsafe(dohdb3.mmul(dEdoh));
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon);
} }
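The dEdy lines above are the whole trick of the manual check: LossFunction.MSE's gradient divides by the element size of the output, and the sigmoid contributes y * (1 - y). A minimal standalone sketch of that arithmetic, not part of the commit, with y and t as hypothetical stand-ins for the network output and target:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MseSigmoidGradientSketch {
    public static void main(String[] args) {
        int batchsz = 4, outputsz = 3;
        INDArray y = Nd4j.rand(batchsz, outputsz);   // stand-in for the sigmoid output
        INDArray t = Nd4j.rand(batchsz, outputsz);   // stand-in for the target
        // E = sum((y - t)^2) / (batchsz * outputsz); per element the test uses
        // dE/dy = 2 * (y - t) / outputsz, because the LossFunction divides by the
        // element size of the output (see the comment in the test above).
        INDArray dEdy = y.sub(t).muli(2).divi(outputsz);
        // Sigmoid derivative: dy/dyh = y * (1 - y), elementwise.
        INDArray dEdyh = y.mul(y.rsub(1.0)).mul(dEdy);
        System.out.println(dEdyh);
    }
}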
@Test @Test
public void testElementWiseVertexFullSubtract() { @DisplayName("Test Element Wise Vertex Full Subtract")
void testElementWiseVertexFullSubtract() {
int batchsz = 24; int batchsz = 24;
int featuresz = 17; int featuresz = 17;
int midsz = 13; int midsz = 13;
int outputsz = 11; int outputsz = 11;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).biasInit(0.0).updater(new Sgd()).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2").addLayer("dense1", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input1").addLayer("dense2", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input2").addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), "dense1", "dense2").addLayer("output", new OutputLayer.Builder().nIn(midsz).nOut(outputsz).activation(new ActivationSigmoid()).lossFunction(LossFunction.MSE).build(), "elementwiseSubtract").setOutputs("output").build();
.dataType(DataType.DOUBLE)
.biasInit(0.0).updater(new Sgd())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
.addInputs("input1", "input2")
.addLayer("dense1",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input1")
.addLayer("dense2",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input2")
.addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract),
"dense1", "dense2")
.addLayer("output",
new OutputLayer.Builder().nIn(midsz).nOut(outputsz)
.activation(new ActivationSigmoid())
.lossFunction(LossFunction.MSE).build(),
"elementwiseSubtract")
.setOutputs("output").build();
ComputationGraph cg = new ComputationGraph(cgc); ComputationGraph cg = new ComputationGraph(cgc);
cg.init(); cg.init();
INDArray input1 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); INDArray input1 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input2 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); INDArray input2 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray target = nullsafe(Nd4j.rand(new int[] {batchsz, outputsz}, new UniformDistribution(0, 1))); INDArray target = nullsafe(Nd4j.rand(new int[] { batchsz, outputsz }, new UniformDistribution(0, 1)));
cg.setInputs(input1, input2); cg.setInputs(input1, input2);
cg.setLabels(target); cg.setLabels(target);
cg.computeGradientAndScore(); cg.computeGradientAndScore();
// Let's figure out what our params are now. // Let's figure out what our params are now.
Map<String, INDArray> params = cg.paramTable(); Map<String, INDArray> params = cg.paramTable();
INDArray dense1_W = nullsafe(params.get("dense1_W")); INDArray dense1_W = nullsafe(params.get("dense1_W"));
@ -585,32 +400,20 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dense2_b = nullsafe(params.get("dense2_b")); INDArray dense2_b = nullsafe(params.get("dense2_b"));
INDArray output_W = nullsafe(params.get("output_W")); INDArray output_W = nullsafe(params.get("output_W"));
INDArray output_b = nullsafe(params.get("output_b")); INDArray output_b = nullsafe(params.get("output_b"));
// Now, let's calculate what we expect the output to be. // Now, let's calculate what we expect the output to be.
INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1)); INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1));
INDArray m = (Transforms.tanh(mh)); INDArray m = (Transforms.tanh(mh));
INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1)); INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1));
INDArray n = (Transforms.tanh(nh)); INDArray n = (Transforms.tanh(nh));
INDArray middle = Nd4j.zeros(batchsz, midsz); INDArray middle = Nd4j.zeros(batchsz, midsz);
middle.addi(m).subi(n); middle.addi(m).subi(n);
INDArray expect = Nd4j.zeros(batchsz, outputsz); INDArray expect = Nd4j.zeros(batchsz, outputsz);
expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1)))); expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1))));
INDArray output = nullsafe(cg.output(input1, input2)[0]); INDArray output = nullsafe(cg.output(input1, input2)[0]);
Assertions.assertEquals(0.0, mse(output, expect), this.epsilon);
Assert.assertEquals(0.0, mse(output, expect), this.epsilon);
Pair<Gradient, Double> pgd = cg.gradientAndScore(); Pair<Gradient, Double> pgd = cg.gradientAndScore();
double score = pgd.getSecond(); double score = pgd.getSecond();
Assert.assertEquals(score, mse(output, target), this.epsilon); Assertions.assertEquals(score, mse(output, target), this.epsilon);
Map<String, INDArray> gradients = pgd.getFirst().gradientForVariable(); Map<String, INDArray> gradients = pgd.getFirst().gradientForVariable();
/* /*
* So. Let's say we have inputs a, b, c * So. Let's say we have inputs a, b, c
@ -644,27 +447,23 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
* dmh/db1 = Nd4j.ones(1, batchsz) * dmh/db1 = Nd4j.ones(1, batchsz)
* *
*/ */
INDArray y = output; INDArray y = output;
INDArray s = middle; INDArray s = middle;
INDArray W4 = output_W; INDArray W4 = output_W;
INDArray dEdy = Nd4j.zeros(target.shape()); INDArray dEdy = Nd4j.zeros(target.shape());
dEdy.addi(y).subi(target).muli(2); // This should be of size batchsz x outputsz // This should be of size batchsz x outputsz
dEdy.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output. dEdy.addi(y).subi(target).muli(2);
// Why? Because the LossFunction divides by the _element size_ of the output.
INDArray dydyh = y.mul(y.mul(-1).add(1)); // This is of size batchsz x outputsz dEdy.divi(target.shape()[1]);
// This is of size batchsz x outputsz
INDArray dydyh = y.mul(y.mul(-1).add(1));
INDArray dEdyh = dydyh.mul(dEdy); INDArray dEdyh = dydyh.mul(dEdy);
INDArray dyhdW4 = s.transpose(); INDArray dyhdW4 = s.transpose();
INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh)); INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh));
INDArray dyhdb4 = Nd4j.ones(1, batchsz); INDArray dyhdb4 = Nd4j.ones(1, batchsz);
INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh)); INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh));
INDArray dyhds = W4.transpose(); INDArray dyhds = W4.transpose();
INDArray dEds = dEdyh.mmul(dyhds); INDArray dEds = dEdyh.mmul(dyhds);
INDArray dsdm = Nd4j.ones(batchsz, midsz); INDArray dsdm = Nd4j.ones(batchsz, midsz);
INDArray dEdm = dsdm.mul(dEds); INDArray dEdm = dsdm.mul(dEds);
INDArray dmdmh = (m.mul(m)).mul(-1).add(1); INDArray dmdmh = (m.mul(m)).mul(-1).add(1);
@ -673,7 +472,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh)); INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh));
INDArray dmhdb1 = Nd4j.ones(1, batchsz); INDArray dmhdb1 = Nd4j.ones(1, batchsz);
INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh)); INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh));
INDArray dsdn = Nd4j.ones(batchsz, midsz).muli(-1); INDArray dsdn = Nd4j.ones(batchsz, midsz).muli(-1);
INDArray dEdn = dsdn.mul(dEds); INDArray dEdn = dsdn.mul(dEds);
INDArray dndnh = (n.mul(n)).mul(-1).add(1); INDArray dndnh = (n.mul(n)).mul(-1).add(1);
@ -682,20 +480,16 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh)); INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh));
INDArray dnhdb2 = Nd4j.ones(1, batchsz); INDArray dnhdb2 = Nd4j.ones(1, batchsz);
INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh)); INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh));
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon); Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
} }
private static double mse(INDArray output, INDArray target) { private static double mse(INDArray output, INDArray target) {
double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue() double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue() / (output.columns() * output.rows());
/ (output.columns() * output.rows());
return mse_expect; return mse_expect;
} }
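Across the Product and Subtract variants above (and the Add variant before them), the only term that changes in the hand-computed gradients is the local derivative of the merged activation s with respect to each input branch. A short illustrative sketch of the three cases, assuming a two-branch merge with n standing in for the sibling branch:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ElementWiseLocalDerivativeSketch {
    public static void main(String[] args) {
        int batchsz = 2, midsz = 3;
        INDArray n = Nd4j.rand(batchsz, midsz);                       // sibling branch activation
        INDArray dsdmAdd = Nd4j.ones(batchsz, midsz);                 // s = m + n  =>  ds/dm = 1
        INDArray dsdmProduct = Nd4j.ones(batchsz, midsz).muli(n);     // s = m * n  =>  ds/dm = n
        INDArray dsdnSubtract = Nd4j.ones(batchsz, midsz).muli(-1);   // s = m - n  =>  ds/dn = -1
        System.out.println(dsdmAdd + "\n" + dsdmProduct + "\n" + dsdnSubtract);
    }
}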
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.conf.graph; package org.deeplearning4j.nn.conf.graph;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -30,8 +29,8 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Assert; import org.junit.jupiter.api.Assertions;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
@ -42,86 +41,70 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import java.util.Map; import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Shift Vertex Test")
class ShiftVertexTest extends BaseDL4JTest {
public class ShiftVertexTest extends BaseDL4JTest {
@Test @Test
public void testShiftVertexNumParamsTrue() { @DisplayName("Test Shift Vertex Num Params True")
void testShiftVertexNumParamsTrue() {
/* /*
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386 * https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
* from @agibsonccc: check for the basics: like 0 numParams * from @agibsonccc: check for the basics: like 0 numParams
*/ */
// The 0.7 doesn't really matter.
ShiftVertex sv = new ShiftVertex(0.7); // The 0.7 doesn't really matter.
Assert.assertEquals(0, sv.numParams(true));
}
@Test
public void testShiftVertexNumParamsFalse() {
/*
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
* from @agibsonccc: check for the basics: like 0 numParams
*/
ShiftVertex sv = new ShiftVertex(0.7); // The 0.7 doesn't really matter.
Assert.assertEquals(0, sv.numParams(false));
}
@Test
public void testGet() {
ShiftVertex sv = new ShiftVertex(0.7); ShiftVertex sv = new ShiftVertex(0.7);
Assert.assertEquals(0.7, sv.getShiftFactor(), this.epsilon); Assertions.assertEquals(0, sv.numParams(true));
} }
@Test @Test
public void testSimple() { @DisplayName("Test Shift Vertex Num Params False")
void testShiftVertexNumParamsFalse() {
/*
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
* from @agibsonccc: check for the basics: like 0 numParams
*/
// The 0.7 doesn't really matter.
ShiftVertex sv = new ShiftVertex(0.7);
Assertions.assertEquals(0, sv.numParams(false));
}
@Test
@DisplayName("Test Get")
void testGet() {
ShiftVertex sv = new ShiftVertex(0.7);
Assertions.assertEquals(0.7, sv.getShiftFactor(), this.epsilon);
}
@Test
@DisplayName("Test Simple")
void testSimple() {
/* /*
* This function _simply_ tests whether ShiftVertex is _in fact_ adding the shift value to its inputs. * This function _simply_ tests whether ShiftVertex is _in fact_ adding the shift value to its inputs.
*/ */
// Just first n primes / 10. // Just first n primes / 10.
INDArray input = Nd4j INDArray input = Nd4j.create(new double[][] { { 0.2, 0.3, 0.5 }, { 0.7, 1.1, 1.3 }, { 1.7, 1.9, 2.3 }, { 2.9, 3.1, 3.7 } });
.create(new double[][] {{0.2, 0.3, 0.5}, {0.7, 1.1, 1.3}, {1.7, 1.9, 2.3}, {2.9, 3.1, 3.7}});
double sf = 4.1; double sf = 4.1;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input") ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input").addLayer("denselayer", new DenseLayer.Builder().nIn(input.columns()).nOut(1).activation(Activation.IDENTITY).build(), "input").addLayer("identityinputactivation", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "input").addVertex("shiftvertex", new ShiftVertex(sf), "identityinputactivation").addLayer("identityshiftvertex", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "shiftvertex").setOutputs("identityshiftvertex", "denselayer").build();
.addLayer("denselayer",
new DenseLayer.Builder().nIn(input.columns()).nOut(1)
.activation(Activation.IDENTITY).build(),
"input")
/* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get
* Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
* at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341)
*/
.addLayer("identityinputactivation",
new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "input")
.addVertex("shiftvertex", new ShiftVertex(sf), "identityinputactivation")
.addLayer("identityshiftvertex",
new ActivationLayer.Builder().activation(Activation.IDENTITY).build(),
"shiftvertex")
.setOutputs("identityshiftvertex", "denselayer").build();
ComputationGraph cg = new ComputationGraph(cgc); ComputationGraph cg = new ComputationGraph(cgc);
cg.init(); cg.init();
// We can call outputSingle, because we only have a single output layer. It has nothing to do with minibatches. // We can call outputSingle, because we only have a single output layer. It has nothing to do with minibatches.
INDArray output = cg.output(true, input)[0]; INDArray output = cg.output(true, input)[0];
INDArray target = Nd4j.zeros(input.shape()); INDArray target = Nd4j.zeros(input.shape());
target.addi(input); target.addi(input);
target.addi(sf); target.addi(sf);
INDArray squared = output.sub(target); INDArray squared = output.sub(target);
double rms = squared.mul(squared).sumNumber().doubleValue(); double rms = squared.mul(squared).sumNumber().doubleValue();
Assert.assertEquals(0.0, rms, this.epsilon); Assertions.assertEquals(0.0, rms, this.epsilon);
} }
@Test @Test
public void testComprehensive() { @DisplayName("Test Comprehensive")
void testComprehensive() {
/* /*
* This function tests ShiftVertex more comprehensively. Specifically, it verifies that the loss function works as * This function tests ShiftVertex more comprehensively. Specifically, it verifies that the loss function works as
* expected on a ComputationGraph _with_ a ShiftVertex and it verifies that the derivatives produced by * expected on a ComputationGraph _with_ a ShiftVertex and it verifies that the derivatives produced by
@ -130,29 +113,12 @@ public class ShiftVertexTest extends BaseDL4JTest {
BaseActivationFunction a1 = new ActivationTanH(); BaseActivationFunction a1 = new ActivationTanH();
BaseActivationFunction a2 = new ActivationSigmoid(); BaseActivationFunction a2 = new ActivationSigmoid();
// Just first n primes / 10. // Just first n primes / 10.
INDArray input = Nd4j INDArray input = Nd4j.create(new double[][] { { 0.2, 0.3, 0.5 }, { 0.7, 1.1, 1.3 }, { 1.7, 1.9, 2.3 }, { 2.9, 3.1, 3.7 } });
.create(new double[][] {{0.2, 0.3, 0.5}, {0.7, 1.1, 1.3}, {1.7, 1.9, 2.3}, {2.9, 3.1, 3.7}});
double sf = 4.1; double sf = 4.1;
// Actually, given that I'm using a sigmoid on the output, // Actually, given that I'm using a sigmoid on the output,
// these should really be between 0 and 1 // these should really be between 0 and 1
INDArray target = Nd4j.create(new double[][] {{0.05, 0.10, 0.15, 0.20, 0.25}, {0.30, 0.35, 0.40, 0.45, 0.50}, INDArray target = Nd4j.create(new double[][] { { 0.05, 0.10, 0.15, 0.20, 0.25 }, { 0.30, 0.35, 0.40, 0.45, 0.50 }, { 0.55, 0.60, 0.65, 0.70, 0.75 }, { 0.80, 0.85, 0.90, 0.95, 0.99 } });
{0.55, 0.60, 0.65, 0.70, 0.75}, {0.80, 0.85, 0.90, 0.95, 0.99}}); ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).updater(new Sgd(0.01)).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input").addLayer("denselayer", new DenseLayer.Builder().nIn(input.columns()).nOut(input.columns()).activation(a1).build(), "input").addVertex("shiftvertex", new ShiftVertex(sf), "denselayer").addLayer("output", new OutputLayer.Builder().nIn(input.columns()).nOut(target.columns()).activation(a2).lossFunction(LossFunction.MSE).build(), "shiftvertex").setOutputs("output").build();
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
.dataType(DataType.DOUBLE)
.updater(new Sgd(0.01))
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
.addInputs("input")
.addLayer("denselayer",
new DenseLayer.Builder().nIn(input.columns()).nOut(input.columns())
.activation(a1).build(),
"input")
.addVertex("shiftvertex", new ShiftVertex(sf), "denselayer")
.addLayer("output",
new OutputLayer.Builder().nIn(input.columns()).nOut(target.columns())
.activation(a2).lossFunction(LossFunction.MSE).build(),
"shiftvertex")
.setOutputs("output").build();
ComputationGraph cg = new ComputationGraph(cgc); ComputationGraph cg = new ComputationGraph(cgc);
cg.init(); cg.init();
cg.setInput(0, input); cg.setInput(0, input);
@ -163,26 +129,23 @@ public class ShiftVertexTest extends BaseDL4JTest {
Gradient g = cg.gradient(); Gradient g = cg.gradient();
Map<String, INDArray> gradients = g.gradientForVariable(); Map<String, INDArray> gradients = g.gradientForVariable();
Map<String, INDArray> manual_gradients = new TreeMap<String, INDArray>(); Map<String, INDArray> manual_gradients = new TreeMap<String, INDArray>();
INDArray W = nullsafe(weights.get("denselayer_W")); INDArray W = nullsafe(weights.get("denselayer_W"));
INDArray b = nullsafe(weights.get("denselayer_b")); INDArray b = nullsafe(weights.get("denselayer_b"));
INDArray V = nullsafe(weights.get("output_W")); INDArray V = nullsafe(weights.get("output_W"));
INDArray c = nullsafe(weights.get("output_b")); INDArray c = nullsafe(weights.get("output_b"));
Map<String, INDArray> manual_weights = new TreeMap<String, INDArray>(); Map<String, INDArray> manual_weights = new TreeMap<String, INDArray>();
manual_weights.put("denselayer_W", W); manual_weights.put("denselayer_W", W);
manual_weights.put("denselayer_b", b); manual_weights.put("denselayer_b", b);
manual_weights.put("output_W", V); manual_weights.put("output_W", V);
manual_weights.put("output_b", c); manual_weights.put("output_b", c);
// First things first, let's calculate the score. // First things first, let's calculate the score.
long batchsz = input.shape()[0]; long batchsz = input.shape()[0];
INDArray z = input.castTo(W.dataType()).mmul(W).add(b.repmat(batchsz, 1)); INDArray z = input.castTo(W.dataType()).mmul(W).add(b.repmat(batchsz, 1));
INDArray a = a1.getActivation(z.dup(), true).add(sf); // activation modifies its input!! // activation modifies its input!!
INDArray a = a1.getActivation(z.dup(), true).add(sf);
INDArray q = a.mmul(V).add(c.repmat(batchsz, 1)); INDArray q = a.mmul(V).add(c.repmat(batchsz, 1));
INDArray o = nullsafe(a2.getActivation(q.dup(), true)); INDArray o = nullsafe(a2.getActivation(q.dup(), true));
double score_manual = sum_errors(o, target) / (o.columns() * o.rows()); double score_manual = sum_errors(o, target) / (o.columns() * o.rows());
/* /*
* So. We have * So. We have
* z5 = input1 * W15 + input2 * W25 + input3 * W35 + b5 * z5 = input1 * W15 + input2 * W25 + input3 * W35 + b5
@ -197,12 +160,15 @@ public class ShiftVertexTest extends BaseDL4JTest {
* dq1/dv11 = a1 dq2/dV12 = a1 dq3/dV13 = a1 ... * dq1/dv11 = a1 dq2/dV12 = a1 dq3/dV13 = a1 ...
* dq1/dv21 = a2 dq2... * dq1/dv21 = a2 dq2...
*/ */
INDArray dEdo = target.like(); //Nd4j.zeros(target.shape()); // Nd4j.zeros(target.shape());
dEdo.addi(o.castTo(dEdo.dataType())).subi(target).muli(2); // This should be of size batchsz x outputsz INDArray dEdo = target.like();
dEdo.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output. // This should be of size batchsz x outputsz
dEdo.addi(o.castTo(dEdo.dataType())).subi(target).muli(2);
// Why? Because the LossFunction divides by the _element size_ of the output.
dEdo.divi(target.shape()[1]);
Pair<INDArray, INDArray> derivs2 = a2.backprop(q, dEdo); Pair<INDArray, INDArray> derivs2 = a2.backprop(q, dEdo);
INDArray dEdq = derivs2.getFirst(); // This should be of size batchsz x outputsz (dE/do * do/dq) this _should_ be o * (1-o) * dE/do for Sigmoid. // This should be of size batchsz x outputsz (dE/do * do/dq) this _should_ be o * (1-o) * dE/do for Sigmoid.
INDArray dEdq = derivs2.getFirst();
// Should be o = q^3 do/dq = 3 q^2 for Cube. // Should be o = q^3 do/dq = 3 q^2 for Cube.
/* /*
INDArray dodq = q.mul(q).mul(3); INDArray dodq = q.mul(q).mul(3);
@ -213,26 +179,23 @@ public class ShiftVertexTest extends BaseDL4JTest {
System.err.println(tbv); System.err.println(tbv);
System.err.println(dEdq); System.err.println(dEdq);
*/ */
INDArray dqdc = Nd4j.ones(1, batchsz); INDArray dqdc = Nd4j.ones(1, batchsz);
INDArray dEdc = dqdc.mmul(dEdq); // This should be of size 1 x outputsz // This should be of size 1 x outputsz
INDArray dEdc = dqdc.mmul(dEdq);
INDArray dEdV = a.transpose().mmul(dEdq); INDArray dEdV = a.transpose().mmul(dEdq);
INDArray dEda = dEdq.mmul(V.transpose()); // This should be dEdo * dodq * dqda // This should be dEdo * dodq * dqda
INDArray dEda = dEdq.mmul(V.transpose());
Pair<INDArray, INDArray> derivs1 = a1.backprop(z, dEda); Pair<INDArray, INDArray> derivs1 = a1.backprop(z, dEda);
INDArray dEdz = derivs1.getFirst(); INDArray dEdz = derivs1.getFirst();
INDArray dzdb = Nd4j.ones(1, batchsz); INDArray dzdb = Nd4j.ones(1, batchsz);
INDArray dEdb = dzdb.mmul(dEdz); INDArray dEdb = dzdb.mmul(dEdz);
INDArray dEdW = input.transpose().mmul(dEdz); INDArray dEdW = input.transpose().mmul(dEdz);
manual_gradients.put("output_b", dEdc); manual_gradients.put("output_b", dEdc);
manual_gradients.put("output_W", dEdV); manual_gradients.put("output_W", dEdV);
manual_gradients.put("denselayer_b", dEdb); manual_gradients.put("denselayer_b", dEdb);
manual_gradients.put("denselayer_W", dEdW); manual_gradients.put("denselayer_W", dEdW);
double summse = Math.pow((score_manual - score_dl4j), 2); double summse = Math.pow((score_manual - score_dl4j), 2);
int denominator = 1; int denominator = 1;
for (Map.Entry<String, INDArray> mesi : gradients.entrySet()) { for (Map.Entry<String, INDArray> mesi : gradients.entrySet()) {
String name = mesi.getKey(); String name = mesi.getKey();
INDArray dl4j_gradient = nullsafe(mesi.getValue()); INDArray dl4j_gradient = nullsafe(mesi.getValue());
@ -241,9 +204,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
summse += se; summse += se;
denominator += dl4j_gradient.columns() * dl4j_gradient.rows(); denominator += dl4j_gradient.columns() * dl4j_gradient.rows();
} }
Assertions.assertEquals(0.0, summse / denominator, this.epsilon);
Assert.assertEquals(0.0, summse / denominator, this.epsilon);
} }
private static double sum_errors(INDArray a, INDArray b) { private static double sum_errors(INDArray a, INDArray b) {
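Both tests pin down the same contract: ShiftVertex adds a constant elementwise on the forward pass, and because d(x + sf)/dx = 1 it is transparent to backprop. A standalone sketch of that contract, inferred from the assertions above rather than from the DL4J implementation:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ShiftVertexContractSketch {
    public static void main(String[] args) {
        double sf = 4.1;
        INDArray in = Nd4j.create(new double[][] { { 0.2, 0.3 }, { 0.5, 0.7 } });
        INDArray out = in.add(sf);       // forward: out = in + sf, elementwise
        INDArray dLdOut = Nd4j.ones(2, 2);
        INDArray dLdIn = dLdOut.dup();   // backward: gradient passes through unchanged
        System.out.println(out + "\n" + dLdIn);
    }
}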
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -25,7 +24,7 @@ import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.activations.impl.ActivationSoftmax;
@ -34,45 +33,62 @@ import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.io.*; import java.io.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.junit.Assert.*; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/** /**
* @author Jeffrey Tang. * @author Jeffrey Tang.
*/ */
public class LayerBuilderTest extends BaseDL4JTest { @DisplayName("Layer Builder Test")
class LayerBuilderTest extends BaseDL4JTest {
final double DELTA = 1e-15; final double DELTA = 1e-15;
int numIn = 10; int numIn = 10;
int numOut = 5; int numOut = 5;
double drop = 0.3; double drop = 0.3;
IActivation act = new ActivationSoftmax(); IActivation act = new ActivationSoftmax();
PoolingType poolType = PoolingType.MAX; PoolingType poolType = PoolingType.MAX;
int[] kernelSize = new int[] {2, 2};
int[] stride = new int[] {2, 2}; int[] kernelSize = new int[] { 2, 2 };
int[] padding = new int[] {1, 1};
int[] stride = new int[] { 2, 2 };
int[] padding = new int[] { 1, 1 };
int k = 1; int k = 1;
Convolution.Type convType = Convolution.Type.VALID; Convolution.Type convType = Convolution.Type.VALID;
LossFunction loss = LossFunction.MCXENT; LossFunction loss = LossFunction.MCXENT;
WeightInit weight = WeightInit.XAVIER; WeightInit weight = WeightInit.XAVIER;
double corrupt = 0.4; double corrupt = 0.4;
double sparsity = 0.3; double sparsity = 0.3;
double corruptionLevel = 0.5; double corruptionLevel = 0.5;
double dropOut = 0.1; double dropOut = 0.1;
IUpdater updater = new AdaGrad(); IUpdater updater = new AdaGrad();
GradientNormalization gradNorm = GradientNormalization.ClipL2PerParamType; GradientNormalization gradNorm = GradientNormalization.ClipL2PerParamType;
double gradNormThreshold = 8; double gradNormThreshold = 8;
@Test @Test
public void testLayer() throws Exception { @DisplayName("Test Layer")
DenseLayer layer = new DenseLayer.Builder().activation(act).weightInit(weight).dropOut(dropOut) void testLayer() throws Exception {
.updater(updater).gradientNormalization(gradNorm) DenseLayer layer = new DenseLayer.Builder().activation(act).weightInit(weight).dropOut(dropOut).updater(updater).gradientNormalization(gradNorm).gradientNormalizationThreshold(gradNormThreshold).build();
.gradientNormalizationThreshold(gradNormThreshold).build();
checkSerialization(layer); checkSerialization(layer);
assertEquals(act, layer.getActivationFn()); assertEquals(act, layer.getActivationFn());
assertEquals(weight.getWeightInitFunction(), layer.getWeightInitFn()); assertEquals(weight.getWeightInitFunction(), layer.getWeightInitFn());
assertEquals(new Dropout(dropOut), layer.getIDropout()); assertEquals(new Dropout(dropOut), layer.getIDropout());
@ -82,34 +98,30 @@ public class LayerBuilderTest extends BaseDL4JTest {
} }
@Test @Test
public void testFeedForwardLayer() throws Exception { @DisplayName("Test Feed Forward Layer")
void testFeedForwardLayer() throws Exception {
DenseLayer ff = new DenseLayer.Builder().nIn(numIn).nOut(numOut).build(); DenseLayer ff = new DenseLayer.Builder().nIn(numIn).nOut(numOut).build();
checkSerialization(ff); checkSerialization(ff);
assertEquals(numIn, ff.getNIn()); assertEquals(numIn, ff.getNIn());
assertEquals(numOut, ff.getNOut()); assertEquals(numOut, ff.getNOut());
} }
@Test @Test
public void testConvolutionLayer() throws Exception { @DisplayName("Test Convolution Layer")
void testConvolutionLayer() throws Exception {
ConvolutionLayer conv = new ConvolutionLayer.Builder(kernelSize, stride, padding).build(); ConvolutionLayer conv = new ConvolutionLayer.Builder(kernelSize, stride, padding).build();
checkSerialization(conv); checkSerialization(conv);
// assertEquals(convType, conv.getConvolutionType());
// assertEquals(convType, conv.getConvolutionType());
assertArrayEquals(kernelSize, conv.getKernelSize()); assertArrayEquals(kernelSize, conv.getKernelSize());
assertArrayEquals(stride, conv.getStride()); assertArrayEquals(stride, conv.getStride());
assertArrayEquals(padding, conv.getPadding()); assertArrayEquals(padding, conv.getPadding());
} }
@Test @Test
public void testSubsamplingLayer() throws Exception { @DisplayName("Test Subsampling Layer")
SubsamplingLayer sample = void testSubsamplingLayer() throws Exception {
new SubsamplingLayer.Builder(poolType, stride).kernelSize(kernelSize).padding(padding).build(); SubsamplingLayer sample = new SubsamplingLayer.Builder(poolType, stride).kernelSize(kernelSize).padding(padding).build();
checkSerialization(sample); checkSerialization(sample);
assertArrayEquals(padding, sample.getPadding()); assertArrayEquals(padding, sample.getPadding());
assertArrayEquals(kernelSize, sample.getKernelSize()); assertArrayEquals(kernelSize, sample.getKernelSize());
assertEquals(poolType, sample.getPoolingType()); assertEquals(poolType, sample.getPoolingType());
@ -117,36 +129,33 @@ public class LayerBuilderTest extends BaseDL4JTest {
} }
@Test @Test
public void testOutputLayer() throws Exception { @DisplayName("Test Output Layer")
void testOutputLayer() throws Exception {
OutputLayer out = new OutputLayer.Builder(loss).build(); OutputLayer out = new OutputLayer.Builder(loss).build();
checkSerialization(out); checkSerialization(out);
} }
@Test @Test
public void testRnnOutputLayer() throws Exception { @DisplayName("Test Rnn Output Layer")
void testRnnOutputLayer() throws Exception {
RnnOutputLayer out = new RnnOutputLayer.Builder(loss).build(); RnnOutputLayer out = new RnnOutputLayer.Builder(loss).build();
checkSerialization(out); checkSerialization(out);
} }
@Test @Test
public void testAutoEncoder() throws Exception { @DisplayName("Test Auto Encoder")
void testAutoEncoder() throws Exception {
AutoEncoder enc = new AutoEncoder.Builder().corruptionLevel(corruptionLevel).sparsity(sparsity).build(); AutoEncoder enc = new AutoEncoder.Builder().corruptionLevel(corruptionLevel).sparsity(sparsity).build();
checkSerialization(enc); checkSerialization(enc);
assertEquals(corruptionLevel, enc.getCorruptionLevel(), DELTA); assertEquals(corruptionLevel, enc.getCorruptionLevel(), DELTA);
assertEquals(sparsity, enc.getSparsity(), DELTA); assertEquals(sparsity, enc.getSparsity(), DELTA);
} }
@Test @Test
public void testGravesLSTM() throws Exception { @DisplayName("Test Graves LSTM")
GravesLSTM glstm = new GravesLSTM.Builder().forgetGateBiasInit(1.5).activation(Activation.TANH).nIn(numIn) void testGravesLSTM() throws Exception {
.nOut(numOut).build(); GravesLSTM glstm = new GravesLSTM.Builder().forgetGateBiasInit(1.5).activation(Activation.TANH).nIn(numIn).nOut(numOut).build();
checkSerialization(glstm); checkSerialization(glstm);
assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0); assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0);
assertEquals(glstm.nIn, numIn); assertEquals(glstm.nIn, numIn);
assertEquals(glstm.nOut, numOut); assertEquals(glstm.nOut, numOut);
@ -154,12 +163,10 @@ public class LayerBuilderTest extends BaseDL4JTest {
} }
@Test @Test
public void testGravesBidirectionalLSTM() throws Exception { @DisplayName("Test Graves Bidirectional LSTM")
final GravesBidirectionalLSTM glstm = new GravesBidirectionalLSTM.Builder().forgetGateBiasInit(1.5) void testGravesBidirectionalLSTM() throws Exception {
.activation(Activation.TANH).nIn(numIn).nOut(numOut).build(); final GravesBidirectionalLSTM glstm = new GravesBidirectionalLSTM.Builder().forgetGateBiasInit(1.5).activation(Activation.TANH).nIn(numIn).nOut(numOut).build();
checkSerialization(glstm); checkSerialization(glstm);
assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0); assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0);
assertEquals(glstm.nIn, numIn); assertEquals(glstm.nIn, numIn);
assertEquals(glstm.nOut, numOut); assertEquals(glstm.nOut, numOut);
@ -167,21 +174,19 @@ public class LayerBuilderTest extends BaseDL4JTest {
} }
@Test @Test
public void testEmbeddingLayer() throws Exception { @DisplayName("Test Embedding Layer")
void testEmbeddingLayer() throws Exception {
EmbeddingLayer el = new EmbeddingLayer.Builder().nIn(10).nOut(5).build(); EmbeddingLayer el = new EmbeddingLayer.Builder().nIn(10).nOut(5).build();
checkSerialization(el); checkSerialization(el);
assertEquals(10, el.getNIn()); assertEquals(10, el.getNIn());
assertEquals(5, el.getNOut()); assertEquals(5, el.getNOut());
} }
@Test @Test
public void testBatchNormLayer() throws Exception { @DisplayName("Test Batch Norm Layer")
BatchNormalization bN = new BatchNormalization.Builder().nIn(numIn).nOut(numOut).gamma(2).beta(1).decay(0.5) void testBatchNormLayer() throws Exception {
.lockGammaBeta(true).build(); BatchNormalization bN = new BatchNormalization.Builder().nIn(numIn).nOut(numOut).gamma(2).beta(1).decay(0.5).lockGammaBeta(true).build();
checkSerialization(bN); checkSerialization(bN);
assertEquals(numIn, bN.nIn); assertEquals(numIn, bN.nIn);
assertEquals(numOut, bN.nOut); assertEquals(numOut, bN.nOut);
assertEquals(true, bN.isLockGammaBeta()); assertEquals(true, bN.isLockGammaBeta());
@ -191,42 +196,38 @@ public class LayerBuilderTest extends BaseDL4JTest {
} }
@Test @Test
public void testActivationLayer() throws Exception { @DisplayName("Test Activation Layer")
void testActivationLayer() throws Exception {
ActivationLayer activationLayer = new ActivationLayer.Builder().activation(act).build(); ActivationLayer activationLayer = new ActivationLayer.Builder().activation(act).build();
checkSerialization(activationLayer); checkSerialization(activationLayer);
assertEquals(act, activationLayer.activationFn); assertEquals(act, activationLayer.activationFn);
} }
private void checkSerialization(Layer layer) throws Exception { private void checkSerialization(Layer layer) throws Exception {
NeuralNetConfiguration confExpected = new NeuralNetConfiguration.Builder().layer(layer).build(); NeuralNetConfiguration confExpected = new NeuralNetConfiguration.Builder().layer(layer).build();
NeuralNetConfiguration confActual; NeuralNetConfiguration confActual;
// check Java serialization // check Java serialization
byte[] data; byte[] data;
try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutput out = new ObjectOutputStream(bos)) { try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutput out = new ObjectOutputStream(bos)) {
out.writeObject(confExpected); out.writeObject(confExpected);
data = bos.toByteArray(); data = bos.toByteArray();
} }
try (ByteArrayInputStream bis = new ByteArrayInputStream(data); ObjectInput in = new ObjectInputStream(bis)) { try (ByteArrayInputStream bis = new ByteArrayInputStream(data);
ObjectInput in = new ObjectInputStream(bis)) {
confActual = (NeuralNetConfiguration) in.readObject(); confActual = (NeuralNetConfiguration) in.readObject();
} }
assertEquals("unequal Java serialization", confExpected.getLayer(), confActual.getLayer()); assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal Java serialization");
// check JSON // check JSON
String json = confExpected.toJson(); String json = confExpected.toJson();
confActual = NeuralNetConfiguration.fromJson(json); confActual = NeuralNetConfiguration.fromJson(json);
assertEquals("unequal JSON serialization", confExpected.getLayer(), confActual.getLayer()); assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal JSON serialization");
// check YAML // check YAML
String yaml = confExpected.toYaml(); String yaml = confExpected.toYaml();
confActual = NeuralNetConfiguration.fromYaml(yaml); confActual = NeuralNetConfiguration.fromYaml(yaml);
assertEquals("unequal YAML serialization", confExpected.getLayer(), confActual.getLayer()); assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal YAML serialization");
// check the layer's use of callSuper on equals method // check the layer's use of callSuper on equals method
confActual.getLayer().setIDropout(new Dropout(new java.util.Random().nextDouble())); confActual.getLayer().setIDropout(new Dropout(new java.util.Random().nextDouble()));
assertNotEquals("broken equals method (missing callSuper?)", confExpected.getLayer(), confActual.getLayer()); assertNotEquals(confExpected.getLayer(), confActual.getLayer(), "broken equals method (missing callSuper?)");
} }
} }
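checkSerialization also illustrates the one non-mechanical part of this migration: JUnit 5 moves the optional failure message from the first argument to the last. Reduced to a sketch, with expected and actual as placeholders:

import static org.junit.jupiter.api.Assertions.assertEquals;

class AssertionMessageMigrationSketch {
    void check(Object expected, Object actual) {
        // JUnit 4: org.junit.Assert.assertEquals("unequal JSON serialization", expected, actual);
        // JUnit 5 moves the message to the trailing parameter:
        assertEquals(expected, actual, "unequal JSON serialization");
    }
}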
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -30,7 +29,7 @@ import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.AdaDelta; import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
@ -38,89 +37,170 @@ import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.schedule.MapSchedule; import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType; import org.nd4j.linalg.schedule.ScheduleType;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals; /*
import static org.junit.Assert.assertTrue; @Test
public void testLearningRatePolicyExponential() {
double lr = 2;
double lrDecayRate = 5;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr)
.updater(Updater.SGD)
.learningRateDecayPolicy(LearningRatePolicy.Exponential).lrPolicyDecayRate(lrDecayRate).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
public class LayerConfigTest extends BaseDL4JTest { assertEquals(LearningRatePolicy.Exponential, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Exponential, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
}
@Test @Test
public void testLayerName() { public void testLearningRatePolicyInverse() {
double lr = 2;
double lrDecayRate = 5;
double power = 3;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(lrDecayRate)
.lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(LearningRatePolicy.Inverse, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Inverse, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
assertEquals(power, conf.getConf(0).getLrPolicyPower(), 0.0);
assertEquals(power, conf.getConf(1).getLrPolicyPower(), 0.0);
}
@Test
public void testLearningRatePolicySteps() {
double lr = 2;
double lrDecayRate = 5;
double steps = 4;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
.learningRateDecayPolicy(LearningRatePolicy.Step).lrPolicyDecayRate(lrDecayRate)
.lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(LearningRatePolicy.Step, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Step, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
assertEquals(steps, conf.getConf(0).getLrPolicySteps(), 0.0);
assertEquals(steps, conf.getConf(1).getLrPolicySteps(), 0.0);
}
@Test
public void testLearningRatePolicyPoly() {
double lr = 2;
double lrDecayRate = 5;
double power = 3;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
.learningRateDecayPolicy(LearningRatePolicy.Poly).lrPolicyDecayRate(lrDecayRate)
.lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(LearningRatePolicy.Poly, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Poly, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
assertEquals(power, conf.getConf(0).getLrPolicyPower(), 0.0);
assertEquals(power, conf.getConf(1).getLrPolicyPower(), 0.0);
}
@Test
public void testLearningRatePolicySigmoid() {
double lr = 2;
double lrDecayRate = 5;
double steps = 4;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
.learningRateDecayPolicy(LearningRatePolicy.Sigmoid).lrPolicyDecayRate(lrDecayRate)
.lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(LearningRatePolicy.Sigmoid, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Sigmoid, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
assertEquals(steps, conf.getConf(0).getLrPolicySteps(), 0.0);
assertEquals(steps, conf.getConf(1).getLrPolicySteps(), 0.0);
}
*/
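The block above appears to be commented out because the builder methods it exercises (iterations, learningRate, learningRateDecayPolicy and friends) are gone from the current API; learning-rate decay is now expressed through ISchedule-based updaters, which matches the MapSchedule and ScheduleType imports this file already carries. A hedged sketch of a rough modern equivalent, with illustrative values only:

import java.util.HashMap;
import java.util.Map;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

class ScheduleSketch {
    static Sgd decayingSgd() {
        Map<Integer, Double> lr = new HashMap<>();
        lr.put(0, 0.01);      // iterations [0, 1000) use 0.01
        lr.put(1000, 0.001);  // iteration 1000 onward uses 0.001
        return new Sgd(new MapSchedule(ScheduleType.ITERATION, lr));
    }
}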
@DisplayName("Layer Config Test")
class LayerConfigTest extends BaseDL4JTest {
@Test
@DisplayName("Test Layer Name")
void testLayerName() {
String name1 = "genisys";
String name2 = "bill";
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).name(name1).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).name(name2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(name1, conf.getConf(0).getLayer().getLayerName());
assertEquals(name2, conf.getConf(1).getLayer().getLayerName());
}
@Test
@DisplayName("Test Activation Layerwise Override")
void testActivationLayerwiseOverride() {
// Without layerwise override:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString(), "relu");
assertEquals(((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString(), "relu");
// With
conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).activation(Activation.TANH).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString(), "relu");
assertEquals(((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString(), "tanh");
}
@Test
@DisplayName("Test Weight Bias Init Layerwise Override")
void testWeightBiasInitLayerwiseOverride() {
// Without layerwise override:
final Distribution defaultDistribution = new NormalDistribution(0, 1.0);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dist(defaultDistribution).biasInit(1).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn());
assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(1, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
// With:
final Distribution overriddenDistribution = new UniformDistribution(0, 1);
conf = new NeuralNetConfiguration.Builder().dist(defaultDistribution).biasInit(1).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).dist(overriddenDistribution).biasInit(0).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn());
assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(0, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
}
@ -176,101 +256,65 @@ public class LayerConfigTest extends BaseDL4JTest {
assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0);
assertEquals(0.8, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0);
}*/
@Test
@DisplayName("Test Dropout Layerwise Override")
void testDropoutLayerwiseOverride() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new Dropout(1.0), conf.getConf(0).getLayer().getIDropout());
assertEquals(new Dropout(1.0), conf.getConf(1).getLayer().getIDropout());
conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).dropOut(2.0).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new Dropout(1.0), conf.getConf(0).getLayer().getIDropout());
assertEquals(new Dropout(2.0), conf.getConf(1).getLayer().getIDropout());
}
@Test
@DisplayName("Test Momentum Layerwise Override")
void testMomentumLayerwiseOverride() {
Map<Integer, Double> testMomentumAfter = new HashMap<>();
testMomentumAfter.put(0, 0.1);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(0.1, ((Nesterovs) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0);
assertEquals(0.1, ((Nesterovs) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0);
Map<Integer, Double> testMomentumAfter2 = new HashMap<>();
testMomentumAfter2.put(0, 0.2);
conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter2))).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(0.1, ((Nesterovs) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0);
assertEquals(0.2, ((Nesterovs) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0);
}
@Test
@DisplayName("Test Updater Rho Rms Decay Layerwise Override")
void testUpdaterRhoRmsDecayLayerwiseOverride() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new AdaDelta(0.5, 0.9)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.01, 0.9)).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta);
assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
assertEquals(0.01, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.5, AdaDelta.DEFAULT_ADADELTA_EPSILON)).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp);
assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
assertEquals(1.0, ((RmsProp) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0);
assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
}
@Test
@DisplayName("Test Updater Adam Params Layerwise Override")
void testUpdaterAdamParamsLayerwiseOverride() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(1.0, 0.5, 0.5, 1e-8)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Adam(1.0, 0.6, 0.7, 1e-8)).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
assertEquals(0.6, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
@ -278,45 +322,25 @@ public class LayerConfigTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Gradient Normalization Layerwise Override")
void testGradientNormalizationLayerwiseOverride() {
// Learning rate without layerwise override:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization());
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization());
assertEquals(10, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalizationThreshold(), 0.0);
assertEquals(10, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalizationThreshold(), 0.0);
// With:
conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).gradientNormalization(GradientNormalization.None).gradientNormalizationThreshold(2.5).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization());
assertEquals(GradientNormalization.None, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization());
assertEquals(10, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalizationThreshold(), 0.0);
assertEquals(2.5, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalizationThreshold(), 0.0);
}
/*
@Test
public void testLearningRatePolicyExponential() {
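A note on the hunk above: the migration adds a @DisplayName annotation to every test class and method, derived from the camel-case identifier. A minimal sketch of the convention as it would apply to a fresh test class (hypothetical names, for illustration only, not part of this commit):

    import org.junit.jupiter.api.DisplayName;
    import org.junit.jupiter.api.Test;
    import static org.junit.jupiter.api.Assertions.assertEquals;

    // Hypothetical example: JUnit 5 reports the annotated names
    // ("Layer Name Test > Test Default Name") instead of the raw identifiers.
    @DisplayName("Layer Name Test")
    class LayerNameTest {

        @Test
        @DisplayName("Test Default Name")
        void testDefaultName() {
            assertEquals("genisys", "genisys");
        }
    }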


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -35,8 +34,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
@ -44,107 +43,89 @@ import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Layer Config Validation Test")
class LayerConfigValidationTest extends BaseDL4JTest {
@Test
@DisplayName("Test Drop Connect")
void testDropConnect() {
// Warning thrown only since some layers may not have l1 or l2
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).weightNoise(new DropConnect(0.5)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test
@DisplayName("Test L 1 L 2 Not Set")
void testL1L2NotSet() {
// Warning thrown only since some layers may not have l1 or l2
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test
@Disabled // Old assumption: throw exception on l1 but no regularization. Current design: warn, not exception
@DisplayName("Test Reg Not Set L 1 Global")
void testRegNotSetL1Global() {
assertThrows(IllegalStateException.class, () -> {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).l1(0.5).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
});
}
@Test
@Disabled // Old assumption: throw exception on l1 but no regularization. Current design: warn, not exception
@DisplayName("Test Reg Not Set L 2 Local")
void testRegNotSetL2Local() {
assertThrows(IllegalStateException.class, () -> {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
});
}
@Test
@DisplayName("Test Weight Init Dist Not Set")
void testWeightInitDistNotSet() {
// Warning thrown only since global dist can be set with a different weight init locally
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).dist(new GaussianDistribution(1e-3, 2)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test
@DisplayName("Test Nesterovs Not Set Global")
void testNesterovsNotSetGlobal() {
// Warnings only thrown
Map<Integer, Double> testMomentumAfter = new HashMap<>();
testMomentumAfter.put(0, 0.1);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test
@DisplayName("Test Comp Graph Null Layer")
void testCompGraphNullLayer() {
ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.01)).seed(42).miniBatch(false).l1(0.2).l2(0.2).updater(Updater.RMSPROP).graphBuilder().addInputs("in").addLayer("L" + 1, new GravesLSTM.Builder().nIn(20).updater(Updater.RMSPROP).nOut(10).weightInit(WeightInit.XAVIER).dropOut(0.4).l1(0.3).activation(Activation.SIGMOID).build(), "in").addLayer("output", new RnnOutputLayer.Builder().nIn(20).nOut(10).activation(Activation.SOFTMAX).weightInit(WeightInit.RELU_UNIFORM).build(), "L" + 1).setOutputs("output");
ComputationGraphConfiguration conf = gb.build();
ComputationGraph cg = new ComputationGraph(conf);
cg.init();
}
@Test
@DisplayName("Test Predefined Config Values")
void testPredefinedConfigValues() {
double expectedMomentum = 0.9;
double expectedAdamMeanDecay = 0.9;
double expectedAdamVarDecay = 0.999;
@ -152,59 +133,38 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
Distribution expectedDist = new NormalDistribution(0, 1);
double expectedL1 = 0.0;
double expectedL2 = 0.0;
// Nesterovs Updater
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(0.9)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(0.3, 0.4)).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
BaseLayer layerConf = (BaseLayer) net.getLayer(0).conf().getLayer();
assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3);
assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
BaseLayer layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer();
assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3);
// Adam Updater
conf = new NeuralNetConfiguration.Builder().updater(new Adam(0.3)).weightInit(new WeightInitDistribution(expectedDist)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).l1(0.3).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
layerConf = (BaseLayer) net.getLayer(0).conf().getLayer();
assertEquals(0.3, TestUtils.getL1(layerConf), 1e-3);
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer();
assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn());
assertNull(TestUtils.getL1Reg(layerConf1.getRegularization()));
assertNull(TestUtils.getL2Reg(layerConf1.getRegularization()));
// RMSProp Updater
conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(0.3)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(0.3, 0.4, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
layerConf = (BaseLayer) net.getLayer(0).conf().getLayer();
assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3);
assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
assertNull(TestUtils.getL2Reg(layerConf.getRegularization()));
layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer();
assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3);
}
}
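The two @Disabled tests above show the one structural rewrite in this migration: JUnit 4's @Test(expected = IllegalStateException.class) has no direct JUnit 5 counterpart, so the expectation becomes an explicit assertThrows around the body. A minimal sketch of the pattern (hypothetical class and method names; assertThrows is the standard org.junit.jupiter.api.Assertions API):

    import org.junit.jupiter.api.Test;
    import static org.junit.jupiter.api.Assertions.assertThrows;

    class ExpectedExceptionSketch {

        // JUnit 4: @Test(expected = IllegalStateException.class) on the method.
        // JUnit 5: wrap only the code expected to throw, which avoids false
        // positives when an earlier statement throws the same exception type.
        @Test
        void initFailsOnInvalidConfig() {
            assertThrows(IllegalStateException.class, () -> {
                throw new IllegalStateException("invalid config");
            });
        }
    }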


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.preprocessor;
import org.deeplearning4j.BaseDL4JTest;
@ -28,7 +27,7 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -36,29 +35,33 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
 */
@DisplayName("Cnn Processor Test")
class CNNProcessorTest extends BaseDL4JTest {
private static int rows = 28;
private static int cols = 28;
private static INDArray in2D = Nd4j.create(DataType.FLOAT, 1, 784);
private static INDArray in3D = Nd4j.create(DataType.FLOAT, 20, 784, 7);
private static INDArray in4D = Nd4j.create(DataType.FLOAT, 20, 1, 28, 28);
@Test
@DisplayName("Test Feed Forward To Cnn Pre Processor")
void testFeedForwardToCnnPreProcessor() {
FeedForwardToCnnPreProcessor convProcessor = new FeedForwardToCnnPreProcessor(rows, cols, 1);
INDArray check2to4 = convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
int val2to4 = check2to4.shape().length;
assertTrue(val2to4 == 4);
assertEquals(Nd4j.create(DataType.FLOAT, 1, 1, 28, 28), check2to4);
INDArray check4to4 = convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces());
int val4to4 = check4to4.shape().length;
assertTrue(val4to4 == 4);
@ -66,42 +69,41 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Feed Forward To Cnn Pre Processor 2")
void testFeedForwardToCnnPreProcessor2() {
int[] nRows = { 1, 5, 20 };
int[] nCols = { 1, 5, 20 };
int[] nDepth = { 1, 3 };
int[] nMiniBatchSize = { 1, 5 };
for (int rows : nRows) {
for (int cols : nCols) {
for (int d : nDepth) {
FeedForwardToCnnPreProcessor convProcessor = new FeedForwardToCnnPreProcessor(rows, cols, d);
for (int miniBatch : nMiniBatchSize) {
long[] ffShape = new long[] { miniBatch, rows * cols * d };
INDArray rand = Nd4j.rand(ffShape);
INDArray ffInput_c = Nd4j.create(DataType.FLOAT, ffShape, 'c');
INDArray ffInput_f = Nd4j.create(DataType.FLOAT, ffShape, 'f');
ffInput_c.assign(rand);
ffInput_f.assign(rand);
assertEquals(ffInput_c, ffInput_f);
// Test forward pass:
INDArray convAct_c = convProcessor.preProcess(ffInput_c, -1, LayerWorkspaceMgr.noWorkspaces());
INDArray convAct_f = convProcessor.preProcess(ffInput_f, -1, LayerWorkspaceMgr.noWorkspaces());
long[] convShape = { miniBatch, d, rows, cols };
assertArrayEquals(convShape, convAct_c.shape());
assertArrayEquals(convShape, convAct_f.shape());
assertEquals(convAct_c, convAct_f);
// Check values:
// CNN reshaping (for each example) takes a 1d vector and converts it to 3d
// (4d total, for minibatch data)
// 1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc
for (int ex = 0; ex < miniBatch; ex++) {
for (int r = 0; r < rows; r++) {
for (int c = 0; c < cols; c++) {
for (int depth = 0; depth < d; depth++) {
// pos in vector
int origPosition = depth * (rows * cols) + r * cols + c;
double vecValue = ffInput_c.getDouble(ex, origPosition);
double convValue = convAct_c.getDouble(ex, depth, r, c);
assertEquals(vecValue, convValue, 0.0);
@ -109,9 +111,8 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
}
}
// Test backward pass:
// Idea is that backward pass should do opposite to forward pass
INDArray epsilon4_c = Nd4j.create(DataType.FLOAT, convShape, 'c');
INDArray epsilon4_f = Nd4j.create(DataType.FLOAT, convShape, 'f');
epsilon4_c.assign(convAct_c);
@ -126,12 +127,11 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
}
@Test
@DisplayName("Test Feed Forward To Cnn Pre Processor Backprop")
void testFeedForwardToCnnPreProcessorBackprop() {
FeedForwardToCnnPreProcessor convProcessor = new FeedForwardToCnnPreProcessor(rows, cols, 1);
convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
INDArray check2to2 = convProcessor.backprop(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
int val2to2 = check2to2.shape().length;
assertTrue(val2to2 == 2);
@ -139,14 +139,13 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Cnn To Feed Forward Processor")
void testCnnToFeedForwardProcessor() {
CnnToFeedForwardPreProcessor convProcessor = new CnnToFeedForwardPreProcessor(rows, cols, 1);
INDArray check2to4 = convProcessor.backprop(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
int val2to4 = check2to4.shape().length;
assertTrue(val2to4 == 4);
assertEquals(Nd4j.create(DataType.FLOAT, 1, 1, 28, 28), check2to4);
INDArray check4to4 = convProcessor.backprop(in4D, -1, LayerWorkspaceMgr.noWorkspaces());
int val4to4 = check4to4.shape().length;
assertTrue(val4to4 == 4);
@ -154,15 +153,14 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Cnn To Feed Forward Pre Processor Backprop")
void testCnnToFeedForwardPreProcessorBackprop() {
CnnToFeedForwardPreProcessor convProcessor = new CnnToFeedForwardPreProcessor(rows, cols, 1);
convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces());
INDArray check2to2 = convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
int val2to2 = check2to2.shape().length;
assertTrue(val2to2 == 2);
assertEquals(Nd4j.create(DataType.FLOAT, 1, 784), check2to2);
INDArray check4to2 = convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces());
int val4to2 = check4to2.shape().length;
assertTrue(val4to2 == 2);
@ -170,42 +168,41 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Cnn To Feed Forward Pre Processor 2")
void testCnnToFeedForwardPreProcessor2() {
int[] nRows = { 1, 5, 20 };
int[] nCols = { 1, 5, 20 };
int[] nDepth = { 1, 3 };
int[] nMiniBatchSize = { 1, 5 };
for (int rows : nRows) {
for (int cols : nCols) {
for (int d : nDepth) {
CnnToFeedForwardPreProcessor convProcessor = new CnnToFeedForwardPreProcessor(rows, cols, d);
for (int miniBatch : nMiniBatchSize) {
long[] convActShape = new long[] { miniBatch, d, rows, cols };
INDArray rand = Nd4j.rand(convActShape);
INDArray convInput_c = Nd4j.create(DataType.FLOAT, convActShape, 'c');
INDArray convInput_f = Nd4j.create(DataType.FLOAT, convActShape, 'f');
convInput_c.assign(rand);
convInput_f.assign(rand);
assertEquals(convInput_c, convInput_f);
// Test forward pass:
INDArray ffAct_c = convProcessor.preProcess(convInput_c, -1, LayerWorkspaceMgr.noWorkspaces());
INDArray ffAct_f = convProcessor.preProcess(convInput_f, -1, LayerWorkspaceMgr.noWorkspaces());
long[] ffActShape = { miniBatch, d * rows * cols };
assertArrayEquals(ffActShape, ffAct_c.shape());
assertArrayEquals(ffActShape, ffAct_f.shape());
assertEquals(ffAct_c, ffAct_f);
// Check values:
// CNN reshaping (for each example) takes a 1d vector and converts it to 3d
// (4d total, for minibatch data)
// 1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc
for (int ex = 0; ex < miniBatch; ex++) {
for (int r = 0; r < rows; r++) {
for (int c = 0; c < cols; c++) {
for (int depth = 0; depth < d; depth++) {
// pos in vector after reshape
int vectorPosition = depth * (rows * cols) + r * cols + c;
double vecValue = ffAct_c.getDouble(ex, vectorPosition);
double convValue = convInput_c.getDouble(ex, depth, r, c);
assertEquals(convValue, vecValue, 0.0);
@ -213,9 +210,8 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
}
}
// Test backward pass:
// Idea is that backward pass should do opposite to forward pass
INDArray epsilon2_c = Nd4j.create(DataType.FLOAT, ffActShape, 'c');
INDArray epsilon2_f = Nd4j.create(DataType.FLOAT, ffActShape, 'f');
epsilon2_c.assign(ffAct_c);
@ -231,79 +227,32 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Invalid Input Shape")
void testInvalidInputShape() {
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).miniBatch(true).cacheMode(CacheMode.DEVICE).updater(new Nesterovs(0.9)).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
int[] kernelArray = new int[] { 3, 3 };
int[] strideArray = new int[] { 1, 1 };
int[] zeroPaddingArray = new int[] { 0, 0 };
int processWidth = 4;
// Building the DL4J network
NeuralNetConfiguration.ListBuilder listBuilder = builder.list();
listBuilder = listBuilder.layer(0, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn1").convolutionMode(ConvolutionMode.Strict).nIn(// 2 input channels
2).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).biasInit(1e-2).build());
listBuilder = listBuilder.layer(1, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn2").convolutionMode(ConvolutionMode.Strict).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).biasInit(1e-2).build());
listBuilder = listBuilder.layer(2, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn3").convolutionMode(ConvolutionMode.Strict).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).build());
listBuilder = listBuilder.layer(3, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn4").convolutionMode(ConvolutionMode.Strict).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).build());
listBuilder = listBuilder.layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).name("output").nOut(1).activation(Activation.TANH).build());
MultiLayerConfiguration conf = listBuilder.setInputType(InputType.convolutional(20, 10, 2)).build();
// For some reason, this model works
MultiLayerNetwork niceModel = new MultiLayerNetwork(conf);
niceModel.init();
// Valid
niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 20, 10));
try {
niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 10, 20));
fail("Expected exception");
} catch (IllegalStateException e) {
// OK
}
}
}
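Note that testInvalidInputShape keeps the JUnit 4-era try/catch-and-fail idiom even after the migration. Under Jupiter the same check could be collapsed to assertThrows, which also hands back the exception for further inspection; a hedged sketch against the niceModel built above (not what the commit does):

    // Sketch only: an equivalent Jupiter-style form of the try/catch block.
    IllegalStateException e = assertThrows(IllegalStateException.class,
            () -> niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 10, 20)));
    // e.getMessage() could then be checked for the shape-mismatch details.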


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.preprocessor;
import org.deeplearning4j.BaseDL4JTest;
@ -27,44 +26,33 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.custom.MyCustomPreprocessor;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass;
import org.nd4j.shade.jackson.databind.jsontype.NamedType;
import java.util.Collection;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Custom Preprocessor Test")
class CustomPreprocessorTest extends BaseDL4JTest {
@Test @Test
public void testCustomPreprocessor() { @DisplayName("Test Custom Preprocessor")
//Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works... void testCustomPreprocessor() {
MultiLayerConfiguration conf = // Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works...
new NeuralNetConfiguration.Builder().list() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10).activation(Activation.SOFTMAX).nOut(10).build()).inputPreProcessor(0, new MyCustomPreprocessor()).build();
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10)
.activation(Activation.SOFTMAX).nOut(10).build())
.inputPreProcessor(0, new MyCustomPreprocessor())
.build();
String json = conf.toJson(); String json = conf.toJson();
String yaml = conf.toYaml(); String yaml = conf.toYaml();
// System.out.println(json);
// System.out.println(json);
MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, confFromJson); assertEquals(conf, confFromJson);
MultiLayerConfiguration confFromYaml = MultiLayerConfiguration.fromYaml(yaml); MultiLayerConfiguration confFromYaml = MultiLayerConfiguration.fromYaml(yaml);
assertEquals(conf, confFromYaml); assertEquals(conf, confFromYaml);
assertTrue(confFromJson.getInputPreProcess(0) instanceof MyCustomPreprocessor); assertTrue(confFromJson.getInputPreProcess(0) instanceof MyCustomPreprocessor);
} }
} }
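Every file in this commit follows the same mechanical recipe, so it is worth spelling out once. A hypothetical before/after test (class and method names illustrative) showing each substitution the migration applies:

    // JUnit 4 -> JUnit 5 substitutions used throughout this commit:
    //   org.junit.Test        -> org.junit.jupiter.api.Test
    //   org.junit.Before      -> org.junit.jupiter.api.BeforeEach
    //   org.junit.Ignore      -> org.junit.jupiter.api.Disabled
    //   org.junit.Assert.*    -> org.junit.jupiter.api.Assertions.*
    //   public class/method   -> package-private (Jupiter does not require public)
    //   @DisplayName added    -> readable names derived from class/method names
    import org.junit.jupiter.api.BeforeEach;
    import org.junit.jupiter.api.DisplayName;
    import org.junit.jupiter.api.Test;
    import static org.junit.jupiter.api.Assertions.assertEquals;

    @DisplayName("Migration Example Test")
    class MigrationExampleTest {

        private int value;

        @BeforeEach // was @Before in JUnit 4
        void setUp() {
            value = 42;
        }

        @Test
        @DisplayName("Test Value Is Initialised")
        void testValueIsInitialised() {
            assertEquals(42, value); // was org.junit.Assert.assertEquals
        }
    }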


@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.BaseDL4JTest;
@@ -35,7 +34,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationRationalTanh;
@@ -46,31 +45,27 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

/**
 */
@DisplayName("Activation Layer Test")
class ActivationLayerTest extends BaseDL4JTest {

    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    @Test
    @DisplayName("Test Input Types")
    void testInputTypes() {
        org.deeplearning4j.nn.conf.layers.ActivationLayer l = new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU).build();
        InputType in1 = InputType.feedForward(20);
        InputType in2 = InputType.convolutional(28, 28, 1);
        assertEquals(in1, l.getOutputType(0, in1));
        assertEquals(in2, l.getOutputType(0, in2));
        assertNull(l.getPreProcessorForInputType(in1));
@@ -78,252 +73,132 @@ public class ActivationLayerTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Test Dense Activation Layer")
    void testDenseActivationLayer() throws Exception {
        DataSetIterator iter = new MnistDataSetIterator(2, 2);
        DataSet next = iter.next();
        // Run without separate activation layer
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.fit(next);
        // Run with separate activation layer
        MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
        MultiLayerNetwork network2 = new MultiLayerNetwork(conf2);
        network2.init();
        network2.fit(next);
        // check parameters
        assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W"));
        assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W"));
        assertEquals(network.getLayer(0).getParam("b"), network2.getLayer(0).getParam("b"));
        assertEquals(network.getLayer(1).getParam("b"), network2.getLayer(2).getParam("b"));
        // check activations
        network.init();
        network.setInput(next.getFeatures());
        List<INDArray> activations = network.feedForward(true);
        network2.init();
        network2.setInput(next.getFeatures());
        List<INDArray> activations2 = network2.feedForward(true);
        assertEquals(activations.get(1).reshape(activations2.get(2).shape()), activations2.get(2));
        assertEquals(activations.get(2), activations2.get(3));
    }

    @Test
    @DisplayName("Test Auto Encoder Activation Layer")
    void testAutoEncoderActivationLayer() throws Exception {
        int minibatch = 3;
        int nIn = 5;
        int layerSize = 5;
        int nOut = 3;
        INDArray next = Nd4j.rand(new int[] { minibatch, nIn });
        INDArray labels = Nd4j.zeros(minibatch, nOut);
        for (int i = 0; i < minibatch; i++) {
            labels.putScalar(i, i % nOut, 1.0);
        }
        // Run without separate activation layer
        Nd4j.getRandom().setSeed(12345);
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0).activation(Activation.SIGMOID).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build()).build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        // Labels are necessary for this test: the layer activation function affects pretraining results, otherwise
        network.fit(next, labels);
        // Run with separate activation layer
        Nd4j.getRandom().setSeed(12345);
        MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0).activation(Activation.IDENTITY).build()).layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.SIGMOID).build()).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build()).build();
        MultiLayerNetwork network2 = new MultiLayerNetwork(conf2);
        network2.init();
        network2.fit(next, labels);
        // check parameters
        assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W"));
        assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W"));
        assertEquals(network.getLayer(0).getParam("b"), network2.getLayer(0).getParam("b"));
        assertEquals(network.getLayer(1).getParam("b"), network2.getLayer(2).getParam("b"));
        // check activations
        network.init();
        network.setInput(next);
        List<INDArray> activations = network.feedForward(true);
        network2.init();
        network2.setInput(next);
        List<INDArray> activations2 = network2.feedForward(true);
        assertEquals(activations.get(1).reshape(activations2.get(2).shape()), activations2.get(2));
        assertEquals(activations.get(2), activations2.get(3));
    }

    @Test
    @DisplayName("Test CNN Activation Layer")
    void testCNNActivationLayer() throws Exception {
        DataSetIterator iter = new MnistDataSetIterator(2, 2);
        DataSet next = iter.next();
        // Run without separate activation layer
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.fit(next);
        // Run with separate activation layer
        MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
        MultiLayerNetwork network2 = new MultiLayerNetwork(conf2);
        network2.init();
        network2.fit(next);
        // check parameters
        assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W"));
        assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W"));
        assertEquals(network.getLayer(0).getParam("b"), network2.getLayer(0).getParam("b"));
        // check activations
        network.init();
        network.setInput(next.getFeatures());
        List<INDArray> activations = network.feedForward(true);
        network2.init();
        network2.setInput(next.getFeatures());
        List<INDArray> activations2 = network2.feedForward(true);
        assertEquals(activations.get(1).reshape(activations2.get(2).shape()), activations2.get(2));
        assertEquals(activations.get(2), activations2.get(3));
    }

    @Test
    @DisplayName("Test Activation Inheritance")
    void testActivationInheritance() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).weightInit(WeightInit.XAVIER).activation(Activation.RATIONALTANH).list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(new ActivationLayer()).layer(new ActivationLayer.Builder().build()).layer(new ActivationLayer.Builder().activation(Activation.ELU).build()).layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        assertNotNull(((ActivationLayer) network.getLayer(1).conf().getLayer()).getActivationFn());
        assertTrue(((DenseLayer) network.getLayer(0).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
        assertTrue(((ActivationLayer) network.getLayer(1).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
        assertTrue(((ActivationLayer) network.getLayer(2).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
        assertTrue(((ActivationLayer) network.getLayer(3).conf().getLayer()).getActivationFn() instanceof ActivationELU);
        assertTrue(((OutputLayer) network.getLayer(4).conf().getLayer()).getActivationFn() instanceof ActivationSoftmax);
    }

    @Test
    @DisplayName("Test Activation Inheritance CG")
    void testActivationInheritanceCG() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).weightInit(WeightInit.XAVIER).activation(Activation.RATIONALTANH).graphBuilder().addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1", new ActivationLayer(), "0").addLayer("2", new ActivationLayer.Builder().build(), "1").addLayer("3", new ActivationLayer.Builder().activation(Activation.ELU).build(), "2").addLayer("4", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "3").setOutputs("4").build();
        ComputationGraph network = new ComputationGraph(conf);
        network.init();
        assertNotNull(((ActivationLayer) network.getLayer("1").conf().getLayer()).getActivationFn());
        assertTrue(((DenseLayer) network.getLayer("0").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
        assertTrue(((ActivationLayer) network.getLayer("1").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
        assertTrue(((ActivationLayer) network.getLayer("2").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
        assertTrue(((ActivationLayer) network.getLayer("3").conf().getLayer()).getActivationFn() instanceof ActivationELU);
        assertTrue(((OutputLayer) network.getLayer("4").conf().getLayer()).getActivationFn() instanceof ActivationSoftmax);
    }
}
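Every test in this class leans on the same equivalence: a layer with a built-in activation must match the same layer set to Activation.IDENTITY followed by a standalone ActivationLayer. A minimal sketch of that pairing (sizes illustrative, imports as in the file above):

    // Fused: activation baked into the dense layer
    MultiLayerConfiguration fused = new NeuralNetConfiguration.Builder().seed(123).list()
            .layer(new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.RELU).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(2).build())
            .build();
    // Split: identity dense layer plus an explicit ActivationLayer; same parameters, same forward pass
    MultiLayerConfiguration split = new NeuralNetConfiguration.Builder().seed(123).list()
            .layer(new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.IDENTITY).build())
            .layer(new ActivationLayer.Builder().activation(Activation.RELU).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(2).build())
            .build();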


@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.BaseDL4JTest;
@@ -31,49 +30,30 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@DisplayName("Auto Encoder Test")
class AutoEncoderTest extends BaseDL4JTest {

    @Test
    @DisplayName("Sanity Check Issue 5662")
    void sanityCheckIssue5662() {
        int mergeSize = 50;
        int encdecSize = 25;
        int in1Size = 20;
        int in2Size = 15;
        int hiddenSize = 10;
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in1", "in2").addLayer("1", new DenseLayer.Builder().nOut(mergeSize).build(), "in1").addLayer("2", new DenseLayer.Builder().nOut(mergeSize).build(), "in2").addVertex("merge", new MergeVertex(), "1", "2").addLayer("e", new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(), "merge").addLayer("hidden", new AutoEncoder.Builder().nOut(hiddenSize).build(), "e").addLayer("decoder", new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(), "hidden").addLayer("L4", new DenseLayer.Builder().nOut(mergeSize).build(), "decoder").addLayer("out1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in1Size).build(), "L4").addLayer("out2", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in2Size).build(), "L4").setOutputs("out1", "out2").setInputTypes(InputType.feedForward(in1Size), InputType.feedForward(in2Size)).build();
        ComputationGraph net = new ComputationGraph(conf);
        net.init();
        MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { Nd4j.create(1, in1Size), Nd4j.create(1, in2Size) }, new INDArray[] { Nd4j.create(1, in1Size), Nd4j.create(1, in2Size) });
        net.summary(InputType.feedForward(in1Size), InputType.feedForward(in2Size));
        net.fit(new SingletonMultiDataSetIterator(mds));
    }
}
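The fit call above works because MultiDataSet pairs feature and label arrays positionally with the graph's inputs and outputs; a minimal sketch with the same sizes (20 and 15):

    // Two graph inputs -> two feature arrays; two output layers -> two label arrays.
    // Each array's column count must match the corresponding input size / nOut.
    MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(
            new INDArray[] { Nd4j.create(1, 20), Nd4j.create(1, 15) },   // features for "in1", "in2"
            new INDArray[] { Nd4j.create(1, 20), Nd4j.create(1, 15) });  // labels for "out1", "out2"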


@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.layers;

import lombok.val;
@@ -29,46 +28,47 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@DisplayName("Base Layer Test")
class BaseLayerTest extends BaseDL4JTest {

    protected INDArray weight = Nd4j.create(new double[] { 0.10, -0.20, -0.15, 0.05 }, new int[] { 2, 2 });
    protected INDArray bias = Nd4j.create(new double[] { 0.5, 0.5 }, new int[] { 1, 2 });
    protected Map<String, INDArray> paramTable;

    @BeforeEach
    void doBefore() {
        paramTable = new HashMap<>();
        paramTable.put("W", weight);
        paramTable.put("b", bias);
    }

    @Test
    @DisplayName("Test Set Existing Params Convolution Single Layer")
    void testSetExistingParamsConvolutionSingleLayer() {
        Layer layer = configureSingleLayer();
        assertNotEquals(paramTable, layer.paramTable());
        layer.setParamTable(paramTable);
        assertEquals(paramTable, layer.paramTable());
    }

    @Test
    @DisplayName("Test Set Existing Params Dense Multi Layer")
    void testSetExistingParamsDenseMultiLayer() {
        MultiLayerNetwork net = configureMultiLayer();
        for (Layer layer : net.getLayers()) {
            assertNotEquals(paramTable, layer.paramTable());
            layer.setParamTable(paramTable);
@@ -76,31 +76,21 @@ public class BaseLayerTest extends BaseDL4JTest {
        }
    }

    public Layer configureSingleLayer() {
        int nIn = 2;
        int nOut = 2;
        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new ConvolutionLayer.Builder().nIn(nIn).nOut(nOut).build()).build();
        val numParams = conf.getLayer().initializer().numParams(conf);
        INDArray params = Nd4j.create(1, numParams);
        return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType());
    }

    public MultiLayerNetwork configureMultiLayer() {
        int nIn = 2;
        int nOut = 2;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()).layer(1, new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()).build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }
}
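One pitfall to watch in assertion migrations like the ones above: JUnit 4's Assert takes an optional failure message as the first argument, while Jupiter's Assertions takes it last. A sketch (the message string is illustrative; these tests pass no message):

    // JUnit 4: assertEquals("param tables should match", paramTable, layer.paramTable());
    // JUnit 5: the message moves to the final argument
    assertEquals(paramTable, layer.paramTable(), "param tables should match");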


@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.BaseDL4JTest;
@@ -28,77 +27,58 @@ import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@DisplayName("Cache Mode Test")
class CacheModeTest extends BaseDL4JTest {

    @Test
    @DisplayName("Test Conv Cache Mode Simple")
    void testConvCacheModeSimple() {
        MultiLayerConfiguration conf1 = getConf(CacheMode.NONE);
        MultiLayerConfiguration conf2 = getConf(CacheMode.DEVICE);
        MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
        net1.init();
        MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
        net2.init();
        INDArray in = Nd4j.rand(3, 28 * 28);
        INDArray labels = TestUtils.randomOneHot(3, 10);
        INDArray out1 = net1.output(in);
        INDArray out2 = net2.output(in);
        assertEquals(out1, out2);
        assertEquals(net1.params(), net2.params());
        net1.fit(in, labels);
        net2.fit(in, labels);
        assertEquals(net1.params(), net2.params());
    }

    private static MultiLayerConfiguration getConf(CacheMode cacheMode) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).inferenceWorkspaceMode(WorkspaceMode.ENABLED).trainingWorkspaceMode(WorkspaceMode.ENABLED).seed(12345).cacheMode(cacheMode).list().layer(new ConvolutionLayer.Builder().nOut(3).build()).layer(new ConvolutionLayer.Builder().nOut(3).build()).layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
        return conf;
    }

    @Test
    @DisplayName("Test LSTM Cache Mode Simple")
    void testLSTMCacheModeSimple() {
        for (boolean graves : new boolean[] { true, false }) {
            MultiLayerConfiguration conf1 = getConfLSTM(CacheMode.NONE, graves);
            MultiLayerConfiguration conf2 = getConfLSTM(CacheMode.DEVICE, graves);
            MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
            net1.init();
            MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
            net2.init();
            INDArray in = Nd4j.rand(new int[] { 3, 3, 10 });
            INDArray labels = TestUtils.randomOneHotTimeSeries(3, 10, 10);
            INDArray out1 = net1.output(in);
            INDArray out2 = net2.output(in);
            assertEquals(out1, out2);
            assertEquals(net1.params(), net2.params());
            net1.fit(in, labels);
            net2.fit(in, labels);
@@ -106,68 +86,33 @@ public class CacheModeTest extends BaseDL4JTest {
        }
    }

    private static MultiLayerConfiguration getConfLSTM(CacheMode cacheMode, boolean graves) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).inferenceWorkspaceMode(WorkspaceMode.ENABLED).trainingWorkspaceMode(WorkspaceMode.ENABLED).seed(12345).cacheMode(cacheMode).list().layer(graves ? new GravesLSTM.Builder().nIn(3).nOut(3).build() : new LSTM.Builder().nIn(3).nOut(3).build()).layer(graves ? new GravesLSTM.Builder().nIn(3).nOut(3).build() : new LSTM.Builder().nIn(3).nOut(3).build()).layer(new RnnOutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()).build();
        return conf;
    }

    @Test
    @DisplayName("Test Conv Cache Mode Simple CG")
    void testConvCacheModeSimpleCG() {
        ComputationGraphConfiguration conf1 = getConfCG(CacheMode.NONE);
        ComputationGraphConfiguration conf2 = getConfCG(CacheMode.DEVICE);
        ComputationGraph net1 = new ComputationGraph(conf1);
        net1.init();
        ComputationGraph net2 = new ComputationGraph(conf2);
        net2.init();
        INDArray in = Nd4j.rand(3, 28 * 28);
        INDArray labels = TestUtils.randomOneHot(3, 10);
        INDArray out1 = net1.outputSingle(in);
        INDArray out2 = net2.outputSingle(in);
        assertEquals(out1, out2);
        assertEquals(net1.params(), net2.params());
        net1.fit(new DataSet(in, labels));
        net2.fit(new DataSet(in, labels));
        assertEquals(net1.params(), net2.params());
    }

    private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode) {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).inferenceWorkspaceMode(WorkspaceMode.ENABLED).trainingWorkspaceMode(WorkspaceMode.ENABLED).seed(12345).cacheMode(cacheMode).graphBuilder().addInputs("in").layer("0", new ConvolutionLayer.Builder().nOut(3).build(), "in").layer("1", new ConvolutionLayer.Builder().nOut(3).build(), "0").layer("2", new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build(), "1").setOutputs("2").setInputTypes(InputType.convolutionalFlat(28, 28, 1)).build();
        return conf;
    }
}
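The hand-rolled boolean loop in testLSTMCacheModeSimple still works under Jupiter, but the same coverage could be written declaratively with JUnit 5 parameterized tests. A sketch, assuming the junit-jupiter-params artifact is on the test classpath:

    import org.junit.jupiter.params.ParameterizedTest;
    import org.junit.jupiter.params.provider.ValueSource;

    @ParameterizedTest
    @ValueSource(booleans = { true, false })
    @DisplayName("Test LSTM Cache Mode Simple")
    void testLSTMCacheModeSimple(boolean graves) {
        MultiLayerConfiguration conf1 = getConfLSTM(CacheMode.NONE, graves);
        MultiLayerConfiguration conf2 = getConfLSTM(CacheMode.DEVICE, graves);
        // ... remainder identical to the loop body above, run once per boolean value
    }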


@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.BaseDL4JTest;
@@ -34,8 +33,8 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@@ -44,73 +43,40 @@ import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

@DisplayName("Center Loss Output Layer Test")
class CenterLossOutputLayerTest extends BaseDL4JTest {

    private ComputationGraph getGraph(int numLabels, double lambda) {
        Nd4j.getRandom().setSeed(12345);
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).dist(new NormalDistribution(0, 1)).updater(new NoOp()).graphBuilder().addInputs("input1").addLayer("l1", new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.RELU).build(), "input1").addLayer("lossLayer", new CenterLossOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(numLabels).lambda(lambda).activation(Activation.SOFTMAX).build(), "l1").setOutputs("lossLayer").build();
        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();
        return graph;
    }

    public ComputationGraph getCNNMnistConfig() {
        // Number of input channels
        int nChannels = 1;
        // The number of possible outcomes
        int outputNum = 10;
        ComputationGraphConfiguration conf = // Training iterations as above
        new NeuralNetConfiguration.Builder().seed(12345).l2(0.0005).weightInit(WeightInit.XAVIER).updater(new Nesterovs(0.01, 0.9)).graphBuilder().addInputs("input").setInputTypes(InputType.convolutionalFlat(28, 28, 1)).addLayer("0", new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build(), "input").addLayer("1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build(), "0").addLayer("2", new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build(), "1").addLayer("3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build(), "2").addLayer("4", new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build(), "3").addLayer("output", new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder(LossFunction.MCXENT).nOut(outputNum).activation(Activation.SOFTMAX).build(), "4").setOutputs("output").build();
        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();
        return graph;
    }

    @Test
    @DisplayName("Test Lambda Conf")
    void testLambdaConf() {
        double[] lambdas = new double[] { 0.1, 0.01 };
        double[] results = new double[2];
        int numClasses = 2;
        INDArray input = Nd4j.rand(150, 4);
        INDArray labels = Nd4j.zeros(150, numClasses);
        Random r = new Random(12345);
@@ -118,7 +84,6 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {
            labels.putScalar(i, r.nextInt(numClasses), 1.0);
        }
        ComputationGraph graph;
        for (int i = 0; i < lambdas.length; i++) {
            graph = getGraph(numClasses, lambdas[i]);
            graph.setInput(0, input);
@@ -126,27 +91,23 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {
            graph.computeGradientAndScore();
            results[i] = graph.score();
        }
        assertNotEquals(results[0], results[1]);
    }

    @Test
    @Disabled
    @DisplayName("Test MNIST Config")
    void testMNISTConfig() throws Exception {
        // Test batch size
        int batchSize = 64;
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        ComputationGraph net = getCNNMnistConfig();
        net.init();
        net.setListeners(new ScoreIterationListener(1));
        for (int i = 0; i < 50; i++) {
            net.fit(mnistTrain.next());
            Thread.sleep(1000);
        }
        Thread.sleep(100000);
    }
}
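The JUnit 4 annotation here was @Ignore with the reason left in a trailing comment ("Should be run manually"); Jupiter's @Disabled can take that reason directly so it stays visible in test reports. A small sketch:

    @Disabled("Should be run manually") // reason string surfaces in test reports
    @Test
    @DisplayName("Test MNIST Config")
    void testMNISTConfig() throws Exception {
        // body unchanged from the test above
    }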


@@ -17,7 +17,6 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.BaseDL4JTest;
@@ -36,7 +35,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -44,30 +43,30 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;

/**
 */
@DisplayName("Dropout Layer Test")
class DropoutLayerTest extends BaseDL4JTest {

    @Override
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    @Test
    @DisplayName("Test Input Types")
    void testInputTypes() {
        DropoutLayer config = new DropoutLayer.Builder(0.5).build();
        InputType in1 = InputType.feedForward(20);
        InputType in2 = InputType.convolutional(28, 28, 1);
        assertEquals(in1, config.getOutputType(0, in1));
        assertEquals(in2, config.getOutputType(0, in2));
        assertNull(config.getPreProcessorForInputType(in1));
@@ -75,58 +74,30 @@ public class DropoutLayerTest extends BaseDL4JTest {
    }

    @Test
    @DisplayName("Test Dropout Layer Without Training")
    void testDropoutLayerWithoutTraining() throws Exception {
        MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(3648).list().layer(0, new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1).dropOut(0.25).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).dropOut(0.25).nOut(4).build()).setInputType(InputType.convolutionalFlat(2, 2, 1)).build();
        MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated);
        netIntegrated.init();
        netIntegrated.getLayer(0).setParam("W", Nd4j.eye(1));
        netIntegrated.getLayer(0).setParam("b", Nd4j.zeros(1, 1));
        netIntegrated.getLayer(1).setParam("W", Nd4j.eye(4));
        netIntegrated.getLayer(1).setParam("b", Nd4j.zeros(4, 1));
        MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(3648).list().layer(0, new DropoutLayer.Builder(0.25).build()).layer(1, new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(2, new DropoutLayer.Builder(0.25).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(4).build()).setInputType(InputType.convolutionalFlat(2, 2, 1)).build();
        MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
        netSeparate.init();
        netSeparate.getLayer(1).setParam("W", Nd4j.eye(1));
        netSeparate.getLayer(1).setParam("b", Nd4j.zeros(1, 1));
        netSeparate.getLayer(3).setParam("W", Nd4j.eye(4));
        netSeparate.getLayer(3).setParam("b", Nd4j.zeros(4, 1));
        // Disable input modification for this test:
        for (Layer l : netIntegrated.getLayers()) {
            l.allowInputModification(false);
        }
        for (Layer l : netSeparate.getLayers()) {
            l.allowInputModification(false);
        }
        INDArray in = Nd4j.arange(1, 5).reshape(1, 4);
        Nd4j.getRandom().setSeed(12345);
        List<INDArray> actTrainIntegrated = netIntegrated.feedForward(in.dup(), true);
        Nd4j.getRandom().setSeed(12345);
@@ -135,15 +106,10 @@ public class DropoutLayerTest extends BaseDL4JTest {
        List<INDArray> actTestIntegrated = netIntegrated.feedForward(in.dup(), false);
        Nd4j.getRandom().setSeed(12345);
        List<INDArray> actTestSeparate = netSeparate.feedForward(in.dup(), false);
        // Check masks:
        INDArray maskIntegrated = ((Dropout) netIntegrated.getLayer(0).conf().getLayer().getIDropout()).getMask();
        INDArray maskSeparate = ((Dropout) netSeparate.getLayer(0).conf().getLayer().getIDropout()).getMask();
assertEquals(maskIntegrated, maskSeparate); assertEquals(maskIntegrated, maskSeparate);
assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(2)); assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(2));
assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(4)); assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(4));
assertEquals(actTestIntegrated.get(1), actTestSeparate.get(2)); assertEquals(actTestIntegrated.get(1), actTestSeparate.get(2));
@ -151,68 +117,41 @@ public class DropoutLayerTest extends BaseDL4JTest {
} }
@Test @Test
public void testDropoutLayerWithDenseMnist() throws Exception { @DisplayName("Test Dropout Layer With Dense Mnist")
void testDropoutLayerWithDenseMnist() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSetIterator iter = new MnistDataSetIterator(2, 2);
DataSet next = iter.next(); DataSet next = iter.next();
// Run without separate activation layer // Run without separate activation layer
MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder() MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.25).nIn(10).nOut(10).build()).build();
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10)
.activation(Activation.RELU).weightInit(
WeightInit.XAVIER)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.25)
.nIn(10).nOut(10).build())
.build();
MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated); MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated);
netIntegrated.init(); netIntegrated.init();
netIntegrated.fit(next); netIntegrated.fit(next);
// Run with separate activation layer // Run with separate activation layer
MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder() MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new DropoutLayer.Builder(0.25).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new DropoutLayer.Builder(0.25).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10)
.build())
.build();
MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate); MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
netSeparate.init(); netSeparate.init();
netSeparate.fit(next); netSeparate.fit(next);
// Disable input modification for this test:
//Disable input modification for this test: for (Layer l : netIntegrated.getLayers()) {
for(Layer l : netIntegrated.getLayers()){
l.allowInputModification(false); l.allowInputModification(false);
} }
for(Layer l : netSeparate.getLayers()){ for (Layer l : netSeparate.getLayers()) {
l.allowInputModification(false); l.allowInputModification(false);
} }
// check parameters // check parameters
assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W")); assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W"));
assertEquals(netIntegrated.getLayer(0).getParam("b"), netSeparate.getLayer(0).getParam("b")); assertEquals(netIntegrated.getLayer(0).getParam("b"), netSeparate.getLayer(0).getParam("b"));
assertEquals(netIntegrated.getLayer(1).getParam("W"), netSeparate.getLayer(2).getParam("W")); assertEquals(netIntegrated.getLayer(1).getParam("W"), netSeparate.getLayer(2).getParam("W"));
assertEquals(netIntegrated.getLayer(1).getParam("b"), netSeparate.getLayer(2).getParam("b")); assertEquals(netIntegrated.getLayer(1).getParam("b"), netSeparate.getLayer(2).getParam("b"));
// check activations // check activations
netIntegrated.setInput(next.getFeatures()); netIntegrated.setInput(next.getFeatures());
netSeparate.setInput(next.getFeatures()); netSeparate.setInput(next.getFeatures());
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<INDArray> actTrainIntegrated = netIntegrated.feedForward(true); List<INDArray> actTrainIntegrated = netIntegrated.feedForward(true);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<INDArray> actTrainSeparate = netSeparate.feedForward(true); List<INDArray> actTrainSeparate = netSeparate.feedForward(true);
assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1)); assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1));
assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3)); assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3));
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<INDArray> actTestIntegrated = netIntegrated.feedForward(false); List<INDArray> actTestIntegrated = netIntegrated.feedForward(false);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -222,77 +161,49 @@ public class DropoutLayerTest extends BaseDL4JTest {
} }
@Test @Test
public void testDropoutLayerWithConvMnist() throws Exception { @DisplayName("Test Dropout Layer With Conv Mnist")
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); //Set to double datatype - MKL-DNN not used for CPU (otherwise different strides due to Dl4J impl permutes) void testDropoutLayerWithConvMnist() throws Exception {
// Set to double datatype - MKL-DNN not used for CPU (otherwise different strides due to Dl4J impl permutes)
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSetIterator iter = new MnistDataSetIterator(2, 2);
DataSet next = iter.next(); DataSet next = iter.next();
// Run without separate activation layer // Run without separate activation layer
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(123) MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.5).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
.list().layer(0,
new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20)
.activation(Activation.TANH).weightInit(WeightInit.XAVIER)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.5)
.nOut(10).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
// Run with separate activation layer // Run with separate activation layer
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
// Manually configure preprocessors
//Manually configure preprocessors // This is necessary, otherwise CnnToFeedForwardPreprocessor will be in different locations
//This is necessary, otherwise CnnToFeedForwardPreprocessor will be in different locations // i.e., dropout on 4d activations in the latter, and dropout on 2d activations in the former
//i.e., dropout on 4d activations in the latter, and dropout on 2d activations in the former
Map<Integer, InputPreProcessor> preProcessorMap = new HashMap<>(); Map<Integer, InputPreProcessor> preProcessorMap = new HashMap<>();
preProcessorMap.put(1, new CnnToFeedForwardPreProcessor(13, 13, 20)); preProcessorMap.put(1, new CnnToFeedForwardPreProcessor(13, 13, 20));
MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new DropoutLayer.Builder(0.5).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).inputPreProcessors(preProcessorMap).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().seed(123).list()
.layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20)
.activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())
.layer(1, new DropoutLayer.Builder(0.5).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build())
.inputPreProcessors(preProcessorMap)
.setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated); MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated);
netIntegrated.init(); netIntegrated.init();
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate); MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
netSeparate.init(); netSeparate.init();
assertEquals(netIntegrated.params(), netSeparate.params()); assertEquals(netIntegrated.params(), netSeparate.params());
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
netIntegrated.fit(next); netIntegrated.fit(next);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
netSeparate.fit(next); netSeparate.fit(next);
assertEquals(netIntegrated.params(), netSeparate.params()); assertEquals(netIntegrated.params(), netSeparate.params());
// check parameters // check parameters
assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W")); assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W"));
assertEquals(netIntegrated.getLayer(0).getParam("b"), netSeparate.getLayer(0).getParam("b")); assertEquals(netIntegrated.getLayer(0).getParam("b"), netSeparate.getLayer(0).getParam("b"));
assertEquals(netIntegrated.getLayer(1).getParam("W"), netSeparate.getLayer(2).getParam("W")); assertEquals(netIntegrated.getLayer(1).getParam("W"), netSeparate.getLayer(2).getParam("W"));
assertEquals(netIntegrated.getLayer(1).getParam("b"), netSeparate.getLayer(2).getParam("b")); assertEquals(netIntegrated.getLayer(1).getParam("b"), netSeparate.getLayer(2).getParam("b"));
// check activations // check activations
netIntegrated.setInput(next.getFeatures().dup()); netIntegrated.setInput(next.getFeatures().dup());
netSeparate.setInput(next.getFeatures().dup()); netSeparate.setInput(next.getFeatures().dup());
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<INDArray> actTrainIntegrated = netIntegrated.feedForward(true); List<INDArray> actTrainIntegrated = netIntegrated.feedForward(true);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<INDArray> actTrainSeparate = netSeparate.feedForward(true); List<INDArray> actTrainSeparate = netSeparate.feedForward(true);
assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1)); assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1));
assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3)); assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3));
netIntegrated.setInput(next.getFeatures().dup()); netIntegrated.setInput(next.getFeatures().dup());
netSeparate.setInput(next.getFeatures().dup()); netSeparate.setInput(next.getFeatures().dup());
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

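The diff above applies the same mechanical JUnit 4 to JUnit 5 migration used throughout this commit. A minimal before/after sketch of the pattern (the ExampleTest class below is hypothetical, for illustration only, not part of this commit):

// Before (JUnit 4): public class, public test methods, org.junit imports
import org.junit.Test;
import static org.junit.Assert.assertEquals;

public class ExampleTest {
    @Test
    public void testAddition() {
        assertEquals(4, 2 + 2);
    }
}

// After (JUnit 5): Jupiter imports, package-private visibility, and
// @DisplayName annotations generated for the class and each test method
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;

@DisplayName("Example Test")
class ExampleTest {

    @Test
    @DisplayName("Test Addition")
    void testAddition() {
        assertEquals(4, 2 + 2);
    }
}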
View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.layers; package org.deeplearning4j.nn.layers;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -31,116 +30,69 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertEquals; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
public class FrozenLayerTest extends BaseDL4JTest { @DisplayName("Frozen Layer Test")
class FrozenLayerTest extends BaseDL4JTest {
/* /*
A model with a few frozen layers == A model with a few frozen layers ==
Model with non frozen layers set with the output of the forward pass of the frozen layers Model with non frozen layers set with the output of the forward pass of the frozen layers
*/ */
@Test @Test
public void testFrozen() { @DisplayName("Test Frozen")
void testFrozen() {
DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY);
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1))
.activation(Activation.IDENTITY);
FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build(); FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build();
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build());
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
.layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build())
.layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build())
.build());
modelToFineTune.init(); modelToFineTune.init();
List<INDArray> ff = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false); List<INDArray> ff = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false);
INDArray asFrozenFeatures = ff.get(2); INDArray asFrozenFeatures = ff.get(2);
MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(finetune).setFeatureExtractor(1).build();
MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(finetune) INDArray paramsLastTwoLayers = Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params());
.setFeatureExtractor(1).build(); MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build(), paramsLastTwoLayers);
// assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf()); //Equal, other than names
INDArray paramsLastTwoLayers = // assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf()); //Equal, other than names
Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); // Check: forward pass
MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build())
.build(), paramsLastTwoLayers);
// assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf()); //Equal, other than names
// assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf()); //Equal, other than names
//Check: forward pass
INDArray outNow = modelNow.output(randomData.getFeatures()); INDArray outNow = modelNow.output(randomData.getFeatures());
INDArray outNotFrozen = notFrozen.output(asFrozenFeatures); INDArray outNotFrozen = notFrozen.output(asFrozenFeatures);
assertEquals(outNow, outNotFrozen); assertEquals(outNow, outNotFrozen);
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
modelNow.fit(randomData); modelNow.fit(randomData);
} }
INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), notFrozen.params());
INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(),
notFrozen.params());
INDArray act = modelNow.params(); INDArray act = modelNow.params();
assertEquals(expected, act); assertEquals(expected, act);
} }
@Test @Test
public void cloneMLNFrozen() { @DisplayName("Clone MLN Frozen")
void cloneMLNFrozen() {
DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY);
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build());
.activation(Activation.IDENTITY);
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
.layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build())
.layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build())
.build());
modelToFineTune.init(); modelToFineTune.init();
INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2); INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2);
MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).setFeatureExtractor(1).build(); MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).setFeatureExtractor(1).build();
MultiLayerNetwork clonedModel = modelNow.clone(); MultiLayerNetwork clonedModel = modelNow.clone();
// Check json
//Check json
assertEquals(modelNow.getLayerWiseConfigurations().toJson(), clonedModel.getLayerWiseConfigurations().toJson()); assertEquals(modelNow.getLayerWiseConfigurations().toJson(), clonedModel.getLayerWiseConfigurations().toJson());
// Check params
//Check params
assertEquals(modelNow.params(), clonedModel.params()); assertEquals(modelNow.params(), clonedModel.params());
MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build(), Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()));
MultiLayerNetwork notFrozen = new MultiLayerNetwork(
overallConf.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build())
.build(),
Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()));
int i = 0; int i = 0;
while (i < 5) { while (i < 5) {
notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
@ -148,112 +100,49 @@ public class FrozenLayerTest extends BaseDL4JTest {
clonedModel.fit(randomData); clonedModel.fit(randomData);
i++; i++;
} }
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), notFrozen.params());
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(),
modelToFineTune.getLayer(1).params(), notFrozen.params());
assertEquals(expectedParams, modelNow.params()); assertEquals(expectedParams, modelNow.params());
assertEquals(expectedParams, clonedModel.params()); assertEquals(expectedParams, clonedModel.params());
} }
@Test @Test
public void testFrozenCompGraph() { @DisplayName("Test Frozen Comp Graph")
void testFrozenCompGraph() {
DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY);
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2").setOutputs("layer3").build());
.activation(Activation.IDENTITY);
ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In")
.addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0")
.addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1")
.addLayer("layer3",
new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build(),
"layer2")
.setOutputs("layer3").build());
modelToFineTune.init(); modelToFineTune.init();
INDArray asFrozenFeatures = modelToFineTune.feedForward(randomData.getFeatures(), false).get("layer1"); INDArray asFrozenFeatures = modelToFineTune.feedForward(randomData.getFeatures(), false).get("layer1");
ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build();
ComputationGraph modelNow = ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In").addLayer("layer1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0").setOutputs("layer1").build());
new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build();
ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In")
.addLayer("layer1",
new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build(),
"layer0")
.setOutputs("layer1").build());
notFrozen.init(); notFrozen.init();
notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), modelToFineTune.getLayer("layer3").params()));
modelToFineTune.getLayer("layer3").params()));
int i = 0; int i = 0;
while (i < 5) { while (i < 5) {
notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
modelNow.fit(randomData); modelNow.fit(randomData);
i++; i++;
} }
assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), modelToFineTune.getLayer("layer1").params(), notFrozen.params()), modelNow.params());
assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(),
modelToFineTune.getLayer("layer1").params(), notFrozen.params()), modelNow.params());
} }
@Test @Test
public void cloneCompGraphFrozen() { @DisplayName("Clone Comp Graph Frozen")
void cloneCompGraphFrozen() {
DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY);
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2").setOutputs("layer3").build());
.activation(Activation.IDENTITY);
ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In")
.addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0")
.addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1")
.addLayer("layer3",
new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build(),
"layer2")
.setOutputs("layer3").build());
modelToFineTune.init(); modelToFineTune.init();
INDArray asFrozenFeatures = modelToFineTune.feedForward(randomData.getFeatures(), false).get("layer1"); INDArray asFrozenFeatures = modelToFineTune.feedForward(randomData.getFeatures(), false).get("layer1");
ComputationGraph modelNow = ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build();
new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build();
ComputationGraph clonedModel = modelNow.clone(); ComputationGraph clonedModel = modelNow.clone();
// Check json
//Check json
assertEquals(clonedModel.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); assertEquals(clonedModel.getConfiguration().toJson(), modelNow.getConfiguration().toJson());
// Check params
//Check params
assertEquals(modelNow.params(), clonedModel.params()); assertEquals(modelNow.params(), clonedModel.params());
ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In").addLayer("layer1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0").setOutputs("layer1").build());
ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In")
.addLayer("layer1",
new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build(),
"layer0")
.setOutputs("layer1").build());
notFrozen.init(); notFrozen.init();
notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), modelToFineTune.getLayer("layer3").params()));
modelToFineTune.getLayer("layer3").params()));
int i = 0; int i = 0;
while (i < 5) { while (i < 5) {
notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
@ -261,117 +150,54 @@ public class FrozenLayerTest extends BaseDL4JTest {
clonedModel.fit(randomData); clonedModel.fit(randomData);
i++; i++;
} }
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), modelToFineTune.getLayer("layer1").params(), notFrozen.params());
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(),
modelToFineTune.getLayer("layer1").params(), notFrozen.params());
assertEquals(expectedParams, modelNow.params()); assertEquals(expectedParams, modelNow.params());
assertEquals(expectedParams, clonedModel.params()); assertEquals(expectedParams, clonedModel.params());
} }
@Test @Test
public void testFrozenLayerInstantiation() { @DisplayName("Test Frozen Layer Instantiation")
//We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if void testFrozenLayerInstantiation() {
// We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
// they were initialized via the builder // they were initialized via the builder
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list() MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
.weightInit(WeightInit.XAVIER).build())
.layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10)
.activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()))
.layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(
new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build()))
.layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build())
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init(); net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
assertEquals(net1.params(), net2.params()); assertEquals(net1.params(), net2.params());
String json = conf2.toJson(); String json = conf2.toJson();
MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
assertEquals(conf2, fromJson); assertEquals(conf2, fromJson);
MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson); MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson);
net3.init(); net3.init();
INDArray input = Nd4j.rand(10, 10); INDArray input = Nd4j.rand(10, 10);
INDArray out2 = net2.output(input); INDArray out2 = net2.output(input);
INDArray out3 = net3.output(input); INDArray out3 = net3.output(input);
assertEquals(out2, out3); assertEquals(out2, out3);
} }
@Test @Test
public void testFrozenLayerInstantiationCompGraph() { @DisplayName("Test Frozen Layer Instantiation Comp Graph")
void testFrozenLayerInstantiationCompGraph() {
//We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if // We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
// they were initialized via the builder // they were initialized via the builder
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "in").addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "0").addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build();
.addInputs("in") ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder().layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).build(), "in").addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder().layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).build(), "0").addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build();
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build(), "in")
.addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build(), "0")
.addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build(),
"1")
.setOutputs("2").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
.addInputs("in")
.addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.build(), "in")
.addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.build(), "0")
.addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build(),
"1")
.setOutputs("2").build();
ComputationGraph net1 = new ComputationGraph(conf1); ComputationGraph net1 = new ComputationGraph(conf1);
net1.init(); net1.init();
ComputationGraph net2 = new ComputationGraph(conf2); ComputationGraph net2 = new ComputationGraph(conf2);
net2.init(); net2.init();
assertEquals(net1.params(), net2.params()); assertEquals(net1.params(), net2.params());
String json = conf2.toJson(); String json = conf2.toJson();
ComputationGraphConfiguration fromJson = ComputationGraphConfiguration.fromJson(json); ComputationGraphConfiguration fromJson = ComputationGraphConfiguration.fromJson(json);
assertEquals(conf2, fromJson); assertEquals(conf2, fromJson);
ComputationGraph net3 = new ComputationGraph(fromJson); ComputationGraph net3 = new ComputationGraph(fromJson);
net3.init(); net3.init();
INDArray input = Nd4j.rand(10, 10); INDArray input = Nd4j.rand(10, 10);
INDArray out2 = net2.outputSingle(input); INDArray out2 = net2.outputSingle(input);
INDArray out3 = net3.outputSingle(input); INDArray out3 = net3.outputSingle(input);
assertEquals(out2, out3); assertEquals(out2, out3);
} }
} }
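For reference, the behavior these FrozenLayerTest methods pin down: TransferLearning.Builder(...).setFeatureExtractor(n) wraps layers 0 through n as FrozenLayer instances, so subsequent fit() calls update only the layers above n. A minimal sketch using the same API as the tests above (network shape and hyperparameters are illustrative, not taken from this commit):

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class FrozenLayerSketch {

    public static void main(String[] args) {
        // A small three-layer network, initialized with random parameters.
        MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .seed(12345)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(2).nOut(3).build())
                .build());
        net.init();

        // Layers 0 and 1 become FrozenLayer wrappers; fit() now updates only layer 2.
        MultiLayerNetwork modelNow = new TransferLearning.Builder(net)
                .setFeatureExtractor(1)
                .build();
        System.out.println(modelNow.summary());
    }
}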

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.layers; package org.deeplearning4j.nn.layers;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -34,363 +33,194 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertNotNull; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
public class FrozenLayerWithBackpropTest extends BaseDL4JTest { @DisplayName("Frozen Layer With Backprop Test")
class FrozenLayerWithBackpropTest extends BaseDL4JTest {
@Test @Test
public void testFrozenWithBackpropLayerInstantiation() { @DisplayName("Test Frozen With Backprop Layer Instantiation")
//We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if void testFrozenWithBackpropLayerInstantiation() {
// We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
// they were initialized via the builder // they were initialized via the builder
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list() MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
.weightInit(WeightInit.XAVIER).build())
.layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(2, new OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10)
.activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()))
.layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build()))
.layer(2, new OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build())
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init(); net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
assertEquals(net1.params(), net2.params()); assertEquals(net1.params(), net2.params());
String json = conf2.toJson(); String json = conf2.toJson();
MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
assertEquals(conf2, fromJson); assertEquals(conf2, fromJson);
MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson); MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson);
net3.init(); net3.init();
INDArray input = Nd4j.rand(10, 10); INDArray input = Nd4j.rand(10, 10);
INDArray out2 = net2.output(input); INDArray out2 = net2.output(input);
INDArray out3 = net3.output(input); INDArray out3 = net3.output(input);
assertEquals(out2, out3); assertEquals(out2, out3);
} }
@Test @Test
public void testFrozenLayerInstantiationCompGraph() { @DisplayName("Test Frozen Layer Instantiation Comp Graph")
void testFrozenLayerInstantiationCompGraph() {
//We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if // We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
// they were initialized via the builder // they were initialized via the builder
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "in").addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "0").addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build();
.addInputs("in") ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()), "in").addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()), "0").addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build();
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build(), "in")
.addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build(), "0")
.addLayer("2", new OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build(),
"1")
.setOutputs("2").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
.addInputs("in")
.addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build()), "in")
.addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build()), "0")
.addLayer("2", new OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build(),
"1")
.setOutputs("2").build();
ComputationGraph net1 = new ComputationGraph(conf1); ComputationGraph net1 = new ComputationGraph(conf1);
net1.init(); net1.init();
ComputationGraph net2 = new ComputationGraph(conf2); ComputationGraph net2 = new ComputationGraph(conf2);
net2.init(); net2.init();
assertEquals(net1.params(), net2.params()); assertEquals(net1.params(), net2.params());
String json = conf2.toJson(); String json = conf2.toJson();
ComputationGraphConfiguration fromJson = ComputationGraphConfiguration.fromJson(json); ComputationGraphConfiguration fromJson = ComputationGraphConfiguration.fromJson(json);
assertEquals(conf2, fromJson); assertEquals(conf2, fromJson);
ComputationGraph net3 = new ComputationGraph(fromJson); ComputationGraph net3 = new ComputationGraph(fromJson);
net3.init(); net3.init();
INDArray input = Nd4j.rand(10, 10); INDArray input = Nd4j.rand(10, 10);
INDArray out2 = net2.outputSingle(input); INDArray out2 = net2.outputSingle(input);
INDArray out3 = net3.outputSingle(input); INDArray out3 = net3.outputSingle(input);
assertEquals(out2, out3); assertEquals(out2, out3);
} }
@Test @Test
public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() { @DisplayName("Test Multi Layer Network Frozen Layer Params After Backprop")
void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Sgd(2)).list().layer(new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build())).layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build())).layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())).build();
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(2))
.list()
.layer(new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(3).nOut(4).build()))
.layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(4).nOut(2).build()))
.layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
.build();
MultiLayerNetwork network = new MultiLayerNetwork(conf1); MultiLayerNetwork network = new MultiLayerNetwork(conf1);
network.init(); network.init();
INDArray unfrozenLayerParams = network.getLayer(0).params().dup(); INDArray unfrozenLayerParams = network.getLayer(0).params().dup();
INDArray frozenLayerParams1 = network.getLayer(1).params().dup(); INDArray frozenLayerParams1 = network.getLayer(1).params().dup();
INDArray frozenLayerParams2 = network.getLayer(2).params().dup(); INDArray frozenLayerParams2 = network.getLayer(2).params().dup();
INDArray frozenOutputLayerParams = network.getLayer(3).params().dup(); INDArray frozenOutputLayerParams = network.getLayer(3).params().dup();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
network.fit(randomData); network.fit(randomData);
} }
assertNotEquals(unfrozenLayerParams, network.getLayer(0).params()); assertNotEquals(unfrozenLayerParams, network.getLayer(0).params());
assertEquals(frozenLayerParams1, network.getLayer(1).params()); assertEquals(frozenLayerParams1, network.getLayer(1).params());
assertEquals(frozenLayerParams2, network.getLayer(2).params()); assertEquals(frozenLayerParams2, network.getLayer(2).params());
assertEquals(frozenOutputLayerParams, network.getLayer(3).params()); assertEquals(frozenOutputLayerParams, network.getLayer(3).params());
} }
@Test @Test
public void testComputationGraphFrozenLayerParamsAfterBackprop() { @DisplayName("Test Computation Graph Frozen Layer Params After Backprop")
void testComputationGraphFrozenLayerParamsAfterBackprop() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
String frozenBranchName = "B1-"; String frozenBranchName = "B1-";
String unfrozenBranchName = "B2-"; String unfrozenBranchName = "B2-";
String initialLayer = "initial"; String initialLayer = "initial";
String frozenBranchUnfrozenLayer0 = frozenBranchName + "0"; String frozenBranchUnfrozenLayer0 = frozenBranchName + "0";
String frozenBranchFrozenLayer1 = frozenBranchName + "1"; String frozenBranchFrozenLayer1 = frozenBranchName + "1";
String frozenBranchFrozenLayer2 = frozenBranchName + "2"; String frozenBranchFrozenLayer2 = frozenBranchName + "2";
String frozenBranchOutput = frozenBranchName + "Output"; String frozenBranchOutput = frozenBranchName + "Output";
String unfrozenLayer0 = unfrozenBranchName + "0"; String unfrozenLayer0 = unfrozenBranchName + "0";
String unfrozenLayer1 = unfrozenBranchName + "1"; String unfrozenLayer1 = unfrozenBranchName + "1";
String unfrozenBranch2 = unfrozenBranchName + "Output"; String unfrozenBranch2 = unfrozenBranchName + "Output";
ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder().updater(new Sgd(2.0)).seed(12345).graphBuilder().addInputs("input").addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input").addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer).addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()), frozenBranchUnfrozenLayer0).addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()), frozenBranchFrozenLayer1).addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer).addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0).addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1).addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2).addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()), "merge").setOutputs(frozenBranchOutput).build();
ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder()
.updater(new Sgd(2.0))
.seed(12345)
.graphBuilder()
.addInputs("input")
.addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
.addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer)
.addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
.addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
.addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
.addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
.addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
.addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
.setOutputs(frozenBranchOutput)
.build();
ComputationGraph computationGraph = new ComputationGraph(computationGraphConf); ComputationGraph computationGraph = new ComputationGraph(computationGraphConf);
computationGraph.init(); computationGraph.init();
INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).params().dup(); INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).params().dup();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
computationGraph.fit(randomData); computationGraph.fit(randomData);
} }
assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).params()); assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).params()); assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).params());
assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).params()); assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).params());
assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).params()); assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).params());
} }
/** /**
 * A frozen layer should produce the same results as a layer whose Sgd updater has the learning rate set to 0 * A frozen layer should produce the same results as a layer whose Sgd updater has the learning rate set to 0
*/ */
@Test @Test
public void testFrozenLayerVsSgd() { @DisplayName("Test Frozen Layer Vs Sgd")
void testFrozenLayerVsSgd() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Sgd(2)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build()).layer(2, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build()).build();
MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder() MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Sgd(2)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build())).layer(2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build())).layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())).build();
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(2))
.list()
.layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build())
.layer(2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build())
.layer(3,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build())
.build();
MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(2))
.list()
.layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()))
.layer(2,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()))
.layer(3,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
.build();
MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen); MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
frozenNetwork.init(); frozenNetwork.init();
INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).params().dup(); INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).params().dup();
INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).params().dup(); INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).params().dup();
INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).params().dup(); INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).params().dup();
INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).params().dup(); INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).params().dup();
MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd); MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd);
sgdNetwork.init(); sgdNetwork.init();
INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).params().dup(); INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).params().dup();
INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).params().dup(); INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).params().dup();
INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).params().dup(); INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).params().dup();
INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).params().dup(); INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).params().dup();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
frozenNetwork.fit(randomData); frozenNetwork.fit(randomData);
} }
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
sgdNetwork.fit(randomData); sgdNetwork.fit(randomData);
} }
assertEquals(frozenNetwork.getLayer(0).params(), sgdNetwork.getLayer(0).params()); assertEquals(frozenNetwork.getLayer(0).params(), sgdNetwork.getLayer(0).params());
assertEquals(frozenNetwork.getLayer(1).params(), sgdNetwork.getLayer(1).params()); assertEquals(frozenNetwork.getLayer(1).params(), sgdNetwork.getLayer(1).params());
assertEquals(frozenNetwork.getLayer(2).params(), sgdNetwork.getLayer(2).params()); assertEquals(frozenNetwork.getLayer(2).params(), sgdNetwork.getLayer(2).params());
assertEquals(frozenNetwork.getLayer(3).params(), sgdNetwork.getLayer(3).params()); assertEquals(frozenNetwork.getLayer(3).params(), sgdNetwork.getLayer(3).params());
} }
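The equivalence this test exercises follows directly from the SGD update rule. In standard notation (not taken from the source):

    w_next = w - eta * dL/dw      // one SGD step
    eta = 0  =>  w_next = w       // parameters never move

A layer with updater Sgd(0.0) is therefore numerically indistinguishable from a frozen layer, and since gradients still flow through it to layer 0 in both networks, even the unfrozen layer's parameters end up identical after training.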
@Test @Test
public void testComputationGraphVsSgd() { @DisplayName("Test Computation Graph Vs Sgd")
void testComputationGraphVsSgd() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
String frozenBranchName = "B1-"; String frozenBranchName = "B1-";
String unfrozenBranchName = "B2-"; String unfrozenBranchName = "B2-";
String initialLayer = "initial"; String initialLayer = "initial";
String frozenBranchUnfrozenLayer0 = frozenBranchName + "0"; String frozenBranchUnfrozenLayer0 = frozenBranchName + "0";
String frozenBranchFrozenLayer1 = frozenBranchName + "1"; String frozenBranchFrozenLayer1 = frozenBranchName + "1";
String frozenBranchFrozenLayer2 = frozenBranchName + "2"; String frozenBranchFrozenLayer2 = frozenBranchName + "2";
String frozenBranchOutput = frozenBranchName + "Output"; String frozenBranchOutput = frozenBranchName + "Output";
String unfrozenLayer0 = unfrozenBranchName + "0"; String unfrozenLayer0 = unfrozenBranchName + "0";
String unfrozenLayer1 = unfrozenBranchName + "1"; String unfrozenLayer1 = unfrozenBranchName + "1";
String unfrozenBranch2 = unfrozenBranchName + "Output"; String unfrozenBranch2 = unfrozenBranchName + "Output";
ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder().updater(new Sgd(2.0)).seed(12345).graphBuilder().addInputs("input").addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input").addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer).addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()), frozenBranchUnfrozenLayer0).addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()), frozenBranchFrozenLayer1).addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer).addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0).addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1).addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2).addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()), "merge").setOutputs(frozenBranchOutput).build();
ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration computationSgdGraphConf = new NeuralNetConfiguration.Builder().updater(new Sgd(2.0)).seed(12345).graphBuilder().addInputs("input").addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input").addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer).addLayer(frozenBranchFrozenLayer1, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(), frozenBranchUnfrozenLayer0).addLayer(frozenBranchFrozenLayer2, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(), frozenBranchFrozenLayer1).addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer).addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0).addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1).addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2).addLayer(frozenBranchOutput, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(), "merge").setOutputs(frozenBranchOutput).build();
.updater(new Sgd(2.0))
.seed(12345)
.graphBuilder()
.addInputs("input")
.addLayer(initialLayer,new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
.addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
.addLayer(frozenBranchFrozenLayer1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
.addLayer(frozenBranchFrozenLayer2,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
.addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
.addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
.addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
.addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
.setOutputs(frozenBranchOutput)
.build();
ComputationGraphConfiguration computationSgdGraphConf = new NeuralNetConfiguration.Builder()
.updater(new Sgd(2.0))
.seed(12345)
.graphBuilder()
.addInputs("input")
.addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
.addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer)
.addLayer(frozenBranchFrozenLayer1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(),frozenBranchUnfrozenLayer0)
.addLayer(frozenBranchFrozenLayer2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(),frozenBranchFrozenLayer1)
.addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
.addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
.addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
.addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(),"merge")
.setOutputs(frozenBranchOutput)
.build();
ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf); ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf);
frozenComputationGraph.init(); frozenComputationGraph.init();
INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).params().dup(); INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).params().dup();
ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf); ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf);
sgdComputationGraph.init(); sgdComputationGraph.init();
INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).params().dup(); INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).params().dup();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
frozenComputationGraph.fit(randomData); frozenComputationGraph.fit(randomData);
} }
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
sgdComputationGraph.fit(randomData); sgdComputationGraph.fit(randomData);
} }
assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params());
assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params());
assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).params(), sgdComputationGraph.getLayer(frozenBranchOutput).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).params(), sgdComputationGraph.getLayer(frozenBranchOutput).params());
} }
} }
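Condensed to its essentials, the freezing pattern used throughout this file looks like the sketch below. This is a minimal standalone version assuming the same DL4J and ND4J APIs that appear in this commit; the class name, layer sizes, and the Sgd(0.1) learning rate are arbitrary choices for illustration:

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.junit.jupiter.api.Test;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.learning.config.Sgd;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    import static org.junit.jupiter.api.Assertions.assertEquals;

    class FrozenLayerSketch {              // hypothetical class name

        @Test
        void frozenParamsDoNotMove() {
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .updater(new Sgd(0.1))
                    .list()
                    .layer(new DenseLayer.Builder().nIn(4).nOut(4).build())
                    .layer(new FrozenLayerWithBackprop(
                            new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                                    .activation(Activation.TANH).nIn(4).nOut(1).build()))
                    .build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            // dup() snapshots the values; params() alone would return a live view
            INDArray frozenBefore = net.getLayer(1).params().dup();
            net.fit(new DataSet(Nd4j.rand(8, 4), Nd4j.rand(8, 1)));
            assertEquals(frozenBefore, net.getLayer(1).params());
        }
    }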
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.layers; package org.deeplearning4j.nn.layers;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -36,7 +35,7 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -46,123 +45,88 @@ import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.util.Collections; import java.util.Collections;
import java.util.Random; import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
import static org.junit.Assert.*; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
public class OutputLayerTest extends BaseDL4JTest { @DisplayName("Output Layer Test")
class OutputLayerTest extends BaseDL4JTest {
@Test @Test
public void testSetParams() { @DisplayName("Test Set Params")
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() void testSetParams() {
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).updater(new Sgd(1e-1)).layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.ZERO).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).build();
.updater(new Sgd(1e-1))
.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3)
.weightInit(WeightInit.ZERO).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
long numParams = conf.getLayer().initializer().numParams(conf); long numParams = conf.getLayer().initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams); INDArray params = Nd4j.create(1, numParams);
OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf, OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf, Collections.<TrainingListener>singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType());
Collections.<TrainingListener>singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType());
params = l.params(); params = l.params();
l.setParams(params); l.setParams(params);
assertEquals(params, l.params()); assertEquals(params, l.params());
} }
@Test @Test
public void testOutputLayersRnnForwardPass() { @DisplayName("Test Output Layers Rnn Forward Pass")
//Test output layer with RNNs void testOutputLayersRnnForwardPass() {
//Expect all outputs etc. to be 2d // Test output layer with RNNs
// Expect all outputs etc. to be 2d
int nIn = 2; int nIn = 2;
int nOut = 5; int nOut = 5;
int layerSize = 4; int layerSize = 4;
int timeSeriesLength = 6; int timeSeriesLength = 6;
int miniBatchSize = 3; int miniBatchSize = 3;
Random r = new Random(12345L); Random r = new Random(12345L);
INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength); INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) { for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < nIn; j++) { for (int j = 0; j < nIn; j++) {
for (int k = 0; k < timeSeriesLength; k++) { for (int k = 0; k < timeSeriesLength; k++) {
input.putScalar(new int[] {i, j, k}, r.nextDouble() - 0.5); input.putScalar(new int[] { i, j, k }, r.nextDouble() - 0.5);
} }
} }
} }
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).inputPreProcessor(1, new RnnToFeedForwardPreProcessor()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list()
.layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.dist(new NormalDistribution(0, 1)).activation(Activation.TANH)
.updater(new NoOp()).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.dist(new NormalDistribution(0, 1))
.updater(new NoOp()).build())
.inputPreProcessor(1, new RnnToFeedForwardPreProcessor()).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf); MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init(); mln.init();
INDArray out2d = mln.feedForward(input).get(2); INDArray out2d = mln.feedForward(input).get(2);
assertArrayEquals(out2d.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); assertArrayEquals(out2d.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
INDArray out = mln.output(input); INDArray out = mln.output(input);
assertArrayEquals(out.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); assertArrayEquals(out.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
INDArray preout = mln.output(input); INDArray preout = mln.output(input);
assertArrayEquals(preout.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); assertArrayEquals(preout.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
// As above, but for RnnOutputLayer. Expect all activations etc. to be 3d
//As above, but for RnnOutputLayer. Expect all activations etc. to be 3d MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).build();
MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list()
.layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.dist(new NormalDistribution(0, 1)).activation(Activation.TANH)
.updater(new NoOp()).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.dist(new NormalDistribution(0, 1))
.updater(new NoOp()).build())
.build();
MultiLayerNetwork mlnRnn = new MultiLayerNetwork(confRnn); MultiLayerNetwork mlnRnn = new MultiLayerNetwork(confRnn);
mlnRnn.init(); mlnRnn.init();
INDArray out3d = mlnRnn.feedForward(input).get(2); INDArray out3d = mlnRnn.feedForward(input).get(2);
assertArrayEquals(out3d.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); assertArrayEquals(out3d.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
INDArray outRnn = mlnRnn.output(input); INDArray outRnn = mlnRnn.output(input);
assertArrayEquals(outRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); assertArrayEquals(outRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
INDArray preoutRnn = mlnRnn.output(input); INDArray preoutRnn = mlnRnn.output(input);
assertArrayEquals(preoutRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); assertArrayEquals(preoutRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
} }
@Test @Test
public void testRnnOutputLayerIncEdgeCases() { @DisplayName("Test Rnn Output Layer Inc Edge Cases")
//Basic test + test edge cases: timeSeriesLength==1, miniBatchSize==1, both void testRnnOutputLayerIncEdgeCases() {
int[] tsLength = {5, 1, 5, 1}; // Basic test + test edge cases: timeSeriesLength==1, miniBatchSize==1, both
int[] miniBatch = {7, 7, 1, 1}; int[] tsLength = { 5, 1, 5, 1 };
int[] miniBatch = { 7, 7, 1, 1 };
int nIn = 3; int nIn = 3;
int nOut = 6; int nOut = 6;
int layerSize = 4; int layerSize = 4;
FeedForwardToRnnPreProcessor proc = new FeedForwardToRnnPreProcessor(); FeedForwardToRnnPreProcessor proc = new FeedForwardToRnnPreProcessor();
for (int t = 0; t < tsLength.length; t++) { for (int t = 0; t < tsLength.length; t++) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int timeSeriesLength = tsLength[t]; int timeSeriesLength = tsLength[t];
int miniBatchSize = miniBatch[t]; int miniBatchSize = miniBatch[t];
Random r = new Random(12345L); Random r = new Random(12345L);
INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength); INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) { for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < nIn; j++) { for (int j = 0; j < nIn; j++) {
for (int k = 0; k < timeSeriesLength; k++) { for (int k = 0; k < timeSeriesLength; k++) {
input.putScalar(new int[] {i, j, k}, r.nextDouble() - 0.5); input.putScalar(new int[] { i, j, k }, r.nextDouble() - 0.5);
} }
} }
} }
@ -170,406 +134,200 @@ public class OutputLayerTest extends BaseDL4JTest {
for (int i = 0; i < miniBatchSize; i++) { for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < timeSeriesLength; j++) { for (int j = 0; j < timeSeriesLength; j++) {
int idx = r.nextInt(nOut); int idx = r.nextInt(nOut);
labels3d.putScalar(new int[] {i, idx, j}, 1.0f); labels3d.putScalar(new int[] { i, idx, j }, 1.0f);
} }
} }
INDArray labels2d = proc.backprop(labels3d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); INDArray labels2d = proc.backprop(labels3d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces());
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).inputPreProcessor(1, new RnnToFeedForwardPreProcessor()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list()
.layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.dist(new NormalDistribution(0, 1))
.activation(Activation.TANH).updater(new NoOp()).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.dist(new NormalDistribution(0, 1))
.updater(new NoOp()).build())
.inputPreProcessor(1, new RnnToFeedForwardPreProcessor())
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf); MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init(); mln.init();
INDArray out2d = mln.feedForward(input).get(2); INDArray out2d = mln.feedForward(input).get(2);
INDArray out3d = proc.preProcess(out2d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); INDArray out3d = proc.preProcess(out2d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces());
MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).build();
MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list()
.layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.dist(new NormalDistribution(0, 1))
.activation(Activation.TANH).updater(new NoOp()).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.dist(new NormalDistribution(0, 1))
.updater(new NoOp()).build())
.build();
MultiLayerNetwork mlnRnn = new MultiLayerNetwork(confRnn); MultiLayerNetwork mlnRnn = new MultiLayerNetwork(confRnn);
mlnRnn.init(); mlnRnn.init();
INDArray outRnn = mlnRnn.feedForward(input).get(2); INDArray outRnn = mlnRnn.feedForward(input).get(2);
mln.setLabels(labels2d); mln.setLabels(labels2d);
mlnRnn.setLabels(labels3d); mlnRnn.setLabels(labels3d);
mln.computeGradientAndScore(); mln.computeGradientAndScore();
mlnRnn.computeGradientAndScore(); mlnRnn.computeGradientAndScore();
// score is the average over all examples.
//score is the average over all examples. // However: the OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping)
//However: the OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping) // RnnOutputLayer has miniBatch examples
//RnnOutputLayer has miniBatch examples // Hence: expect a difference in scores by a factor of timeSeriesLength
//Hence: expect a difference in scores by a factor of timeSeriesLength
double score = mln.score() * timeSeriesLength; double score = mln.score() * timeSeriesLength;
double scoreRNN = mlnRnn.score(); double scoreRNN = mlnRnn.score();
assertFalse(Double.isNaN(score)); assertFalse(Double.isNaN(score));
assertFalse(Double.isNaN(scoreRNN)); assertFalse(Double.isNaN(scoreRNN));
double relError = Math.abs(score - scoreRNN) / (Math.abs(score) + Math.abs(scoreRNN)); double relError = Math.abs(score - scoreRNN) / (Math.abs(score) + Math.abs(scoreRNN));
System.out.println(relError); System.out.println(relError);
assertTrue(relError < 1e-6); assertTrue(relError < 1e-6);
// Check labels and inputs for output layer:
//Check labels and inputs for output layer:
OutputLayer ol = (OutputLayer) mln.getOutputLayer(); OutputLayer ol = (OutputLayer) mln.getOutputLayer();
assertArrayEquals(ol.getInput().shape(), new long[] {miniBatchSize * timeSeriesLength, layerSize}); assertArrayEquals(ol.getInput().shape(), new long[] { miniBatchSize * timeSeriesLength, layerSize });
assertArrayEquals(ol.getLabels().shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); assertArrayEquals(ol.getLabels().shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
RnnOutputLayer rnnol = (RnnOutputLayer) mlnRnn.getOutputLayer(); RnnOutputLayer rnnol = (RnnOutputLayer) mlnRnn.getOutputLayer();
//assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength}); // assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength});
//Input may be set by BaseLayer methods. Thus input may end up as reshaped 2d version instead of original 3d version. // Input may be set by BaseLayer methods. Thus input may end up as reshaped 2d version instead of original 3d version.
//Not ideal, but everything else works. // Not ideal, but everything else works.
assertArrayEquals(rnnol.getLabels().shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); assertArrayEquals(rnnol.getLabels().shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
// Check shapes of output for both:
//Check shapes of output for both: assertArrayEquals(out2d.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
assertArrayEquals(out2d.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut});
INDArray out = mln.output(input); INDArray out = mln.output(input);
assertArrayEquals(out.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); assertArrayEquals(out.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
INDArray preout = mln.output(input); INDArray preout = mln.output(input);
assertArrayEquals(preout.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); assertArrayEquals(preout.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
INDArray outFFRnn = mlnRnn.feedForward(input).get(2); INDArray outFFRnn = mlnRnn.feedForward(input).get(2);
assertArrayEquals(outFFRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); assertArrayEquals(outFFRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
INDArray outRnn2 = mlnRnn.output(input); INDArray outRnn2 = mlnRnn.output(input);
assertArrayEquals(outRnn2.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); assertArrayEquals(outRnn2.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
INDArray preoutRnn = mlnRnn.output(input); INDArray preoutRnn = mlnRnn.output(input);
assertArrayEquals(preoutRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); assertArrayEquals(preoutRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
} }
} }
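The factor-of-timeSeriesLength comparison above is plain bookkeeping. If L is the summed loss over a minibatch of m sequences of length T, the feed-forward OutputLayer sees m*T reshaped rows (courtesy of RnnToFeedForwardPreProcessor) and reports L/(m*T), whereas RnnOutputLayer averages over the m sequences and reports L/m; hence score_rnn = T * score_ff, which is exactly what multiplying mln.score() by timeSeriesLength compensates for before the relative-error check.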
@Test @Test
public void testCompareRnnOutputRnnLoss(){ @DisplayName("Test Compare Rnn Output Rnn Loss")
void testCompareRnnOutputRnnLoss() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int timeSeriesLength = 4; int timeSeriesLength = 4;
int nIn = 5; int nIn = 5;
int layerSize = 6; int layerSize = 6;
int nOut = 6; int nOut = 6;
int miniBatchSize = 3; int miniBatchSize = 3;
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new DenseLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.IDENTITY).build()).layer(new RnnLossLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).build()).build();
MultiLayerConfiguration conf1 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.list()
.layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new DenseLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.IDENTITY).build())
.layer(new RnnLossLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.build())
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf1); MultiLayerNetwork mln = new MultiLayerNetwork(conf1);
mln.init(); mln.init();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build()).build();
MultiLayerConfiguration conf2 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.list()
.layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nIn(layerSize).nOut(nOut)
.build())
.build();
MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2); MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2);
mln2.init(); mln2.init();
mln2.setParams(mln.params()); mln2.setParams(mln.params());
INDArray in = Nd4j.rand(new int[] { miniBatchSize, nIn, timeSeriesLength });
INDArray in = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength});
INDArray out1 = mln.output(in); INDArray out1 = mln.output(in);
INDArray out2 = mln2.output(in); INDArray out2 = mln2.output(in);
assertEquals(out1, out2); assertEquals(out1, out2);
Random r = new Random(12345); Random r = new Random(12345);
INDArray labels = Nd4j.create(miniBatchSize, nOut, timeSeriesLength); INDArray labels = Nd4j.create(miniBatchSize, nOut, timeSeriesLength);
for( int i=0; i<miniBatchSize; i++ ){ for (int i = 0; i < miniBatchSize; i++) {
for( int j=0; j<timeSeriesLength; j++ ){ for (int j = 0; j < timeSeriesLength; j++) {
labels.putScalar(i, r.nextInt(nOut), j, 1.0); labels.putScalar(i, r.nextInt(nOut), j, 1.0);
} }
} }
mln.setInput(in); mln.setInput(in);
mln.setLabels(labels); mln.setLabels(labels);
mln2.setInput(in); mln2.setInput(in);
mln2.setLabels(labels); mln2.setLabels(labels);
mln.computeGradientAndScore(); mln.computeGradientAndScore();
mln2.computeGradientAndScore(); mln2.computeGradientAndScore();
assertEquals(mln.gradient().gradient(), mln2.gradient().gradient()); assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
assertEquals(mln.score(), mln2.score(), 1e-6); assertEquals(mln.score(), mln2.score(), 1e-6);
TestUtils.testModelSerialization(mln); TestUtils.testModelSerialization(mln);
} }
@Test @Test
public void testCnnLossLayer(){ @DisplayName("Test Cnn Loss Layer")
void testCnnLossLayer() {
for(WorkspaceMode ws : WorkspaceMode.values()) { for (WorkspaceMode ws : WorkspaceMode.values()) {
log.info("*** Testing workspace: " + ws); log.info("*** Testing workspace: " + ws);
for (Activation a : new Activation[] { Activation.TANH, Activation.SELU }) {
for (Activation a : new Activation[]{Activation.TANH, Activation.SELU}) { // Check that (A+identity) is equal to (identity+A), for activation A
//Check that (A+identity) is equal to (identity+A), for activation A // i.e., should get same output and weight gradients for both
//i.e., should get same output and weight gradients for both MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).inferenceWorkspaceMode(ws).trainingWorkspaceMode(ws).list().layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY).kernelSize(2, 2).stride(1, 1).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new CnnLossLayer.Builder(LossFunction.MSE).activation(a).build()).build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).inferenceWorkspaceMode(ws).trainingWorkspaceMode(ws).list().layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(a).kernelSize(2, 2).stride(1, 1).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new CnnLossLayer.Builder(LossFunction.MSE).activation(Activation.IDENTITY).build()).build();
MultiLayerConfiguration conf1 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.inferenceWorkspaceMode(ws)
.trainingWorkspaceMode(ws)
.list()
.layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY)
.kernelSize(2, 2).stride(1, 1)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new CnnLossLayer.Builder(LossFunction.MSE)
.activation(a)
.build())
.build();
MultiLayerConfiguration conf2 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.inferenceWorkspaceMode(ws)
.trainingWorkspaceMode(ws)
.list()
.layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(a)
.kernelSize(2, 2).stride(1, 1)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new CnnLossLayer.Builder(LossFunction.MSE)
.activation(Activation.IDENTITY)
.build())
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf1); MultiLayerNetwork mln = new MultiLayerNetwork(conf1);
mln.init(); mln.init();
MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2); MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2);
mln2.init(); mln2.init();
mln2.setParams(mln.params()); mln2.setParams(mln.params());
INDArray in = Nd4j.rand(new int[] { 3, 3, 5, 5 });
INDArray in = Nd4j.rand(new int[]{3, 3, 5, 5});
INDArray out1 = mln.output(in); INDArray out1 = mln.output(in);
INDArray out2 = mln2.output(in); INDArray out2 = mln2.output(in);
assertEquals(out1, out2); assertEquals(out1, out2);
INDArray labels = Nd4j.rand(out1.shape()); INDArray labels = Nd4j.rand(out1.shape());
mln.setInput(in); mln.setInput(in);
mln.setLabels(labels); mln.setLabels(labels);
mln2.setInput(in); mln2.setInput(in);
mln2.setLabels(labels); mln2.setLabels(labels);
mln.computeGradientAndScore(); mln.computeGradientAndScore();
mln2.computeGradientAndScore(); mln2.computeGradientAndScore();
assertEquals(mln.score(), mln2.score(), 1e-6); assertEquals(mln.score(), mln2.score(), 1e-6);
assertEquals(mln.gradient().gradient(), mln2.gradient().gradient()); assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
// Also check computeScoreForExamples
//Also check computeScoreForExamples INDArray in2a = Nd4j.rand(new int[] { 1, 3, 5, 5 });
INDArray in2a = Nd4j.rand(new int[]{1, 3, 5, 5}); INDArray labels2a = Nd4j.rand(new int[] { 1, 4, 5, 5 });
INDArray labels2a = Nd4j.rand(new int[]{1, 4, 5, 5});
INDArray in2 = Nd4j.concat(0, in2a, in2a); INDArray in2 = Nd4j.concat(0, in2a, in2a);
INDArray labels2 = Nd4j.concat(0, labels2a, labels2a); INDArray labels2 = Nd4j.concat(0, labels2a, labels2a);
INDArray s = mln.scoreExamples(new DataSet(in2, labels2), false); INDArray s = mln.scoreExamples(new DataSet(in2, labels2), false);
assertArrayEquals(new long[]{2, 1}, s.shape()); assertArrayEquals(new long[] { 2, 1 }, s.shape());
assertEquals(s.getDouble(0), s.getDouble(1), 1e-6); assertEquals(s.getDouble(0), s.getDouble(1), 1e-6);
TestUtils.testModelSerialization(mln); TestUtils.testModelSerialization(mln);
} }
} }
} }
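The (A+identity) versus (identity+A) trick works because the composite function is identical either way: with conv weights W the score is MSE(a(conv_W(x)), labels) whether the activation a lives in the ConvolutionLayer or in the CnnLossLayer, so after mln2.setParams(mln.params()) both the outputs and the backpropagated weight gradients must match exactly.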
@Test @Test
public void testCnnLossLayerCompGraph(){ @DisplayName("Test Cnn Loss Layer Comp Graph")
void testCnnLossLayerCompGraph() {
for(WorkspaceMode ws : WorkspaceMode.values()) { for (WorkspaceMode ws : WorkspaceMode.values()) {
log.info("*** Testing workspace: " + ws); log.info("*** Testing workspace: " + ws);
for (Activation a : new Activation[] { Activation.TANH, Activation.SELU }) {
for (Activation a : new Activation[]{Activation.TANH, Activation.SELU}) { // Check that (A+identity) is equal to (identity+A), for activation A
//Check that (A+identity) is equal to (identity+A), for activation A // i.e., should get same output and weight gradients for both
//i.e., should get same output and weight gradients for both ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).inferenceWorkspaceMode(ws).trainingWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY).kernelSize(2, 2).stride(1, 1).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build(), "in").addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE).activation(a).build(), "0").setOutputs("1").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).inferenceWorkspaceMode(ws).trainingWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(a).kernelSize(2, 2).stride(1, 1).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build(), "in").addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE).activation(Activation.IDENTITY).build(), "0").setOutputs("1").build();
ComputationGraphConfiguration conf1 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.inferenceWorkspaceMode(ws)
.trainingWorkspaceMode(ws)
.graphBuilder()
.addInputs("in")
.addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY)
.kernelSize(2, 2).stride(1, 1)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build(), "in")
.addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE)
.activation(a)
.build(), "0")
.setOutputs("1")
.build();
ComputationGraphConfiguration conf2 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.inferenceWorkspaceMode(ws)
.trainingWorkspaceMode(ws)
.graphBuilder()
.addInputs("in")
.addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(a)
.kernelSize(2, 2).stride(1, 1)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build(), "in")
.addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE)
.activation(Activation.IDENTITY)
.build(), "0")
.setOutputs("1")
.build();
ComputationGraph graph = new ComputationGraph(conf1); ComputationGraph graph = new ComputationGraph(conf1);
graph.init(); graph.init();
ComputationGraph graph2 = new ComputationGraph(conf2); ComputationGraph graph2 = new ComputationGraph(conf2);
graph2.init(); graph2.init();
graph2.setParams(graph.params()); graph2.setParams(graph.params());
INDArray in = Nd4j.rand(new int[] { 3, 3, 5, 5 });
INDArray in = Nd4j.rand(new int[]{3, 3, 5, 5});
INDArray out1 = graph.outputSingle(in); INDArray out1 = graph.outputSingle(in);
INDArray out2 = graph2.outputSingle(in); INDArray out2 = graph2.outputSingle(in);
assertEquals(out1, out2); assertEquals(out1, out2);
INDArray labels = Nd4j.rand(out1.shape()); INDArray labels = Nd4j.rand(out1.shape());
graph.setInput(0, in);
graph.setInput(0,in);
graph.setLabels(labels); graph.setLabels(labels);
graph2.setInput(0, in);
graph2.setInput(0,in);
graph2.setLabels(labels); graph2.setLabels(labels);
graph.computeGradientAndScore(); graph.computeGradientAndScore();
graph2.computeGradientAndScore(); graph2.computeGradientAndScore();
assertEquals(graph.score(), graph2.score(), 1e-6); assertEquals(graph.score(), graph2.score(), 1e-6);
assertEquals(graph.gradient().gradient(), graph2.gradient().gradient()); assertEquals(graph.gradient().gradient(), graph2.gradient().gradient());
// Also check computeScoreForExamples
//Also check computeScoreForExamples INDArray in2a = Nd4j.rand(new int[] { 1, 3, 5, 5 });
INDArray in2a = Nd4j.rand(new int[]{1, 3, 5, 5}); INDArray labels2a = Nd4j.rand(new int[] { 1, 4, 5, 5 });
INDArray labels2a = Nd4j.rand(new int[]{1, 4, 5, 5});
INDArray in2 = Nd4j.concat(0, in2a, in2a); INDArray in2 = Nd4j.concat(0, in2a, in2a);
INDArray labels2 = Nd4j.concat(0, labels2a, labels2a); INDArray labels2 = Nd4j.concat(0, labels2a, labels2a);
INDArray s = graph.scoreExamples(new DataSet(in2, labels2), false); INDArray s = graph.scoreExamples(new DataSet(in2, labels2), false);
assertArrayEquals(new long[]{2, 1}, s.shape()); assertArrayEquals(new long[] { 2, 1 }, s.shape());
assertEquals(s.getDouble(0), s.getDouble(1), 1e-6); assertEquals(s.getDouble(0), s.getDouble(1), 1e-6);
TestUtils.testModelSerialization(graph); TestUtils.testModelSerialization(graph);
} }
} }
} }
@Test @Test
public void testCnnOutputLayerSoftmax(){ @DisplayName("Test Cnn Output Layer Softmax")
//Check that softmax is applied channels-wise void testCnnOutputLayerSoftmax() {
// Check that softmax is applied channels-wise
MultiLayerConfiguration conf = MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).list().layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new CnnLossLayer.Builder(LossFunction.MSE).activation(Activation.SOFTMAX).build()).build();
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.list()
.layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new CnnLossLayer.Builder(LossFunction.MSE)
.activation(Activation.SOFTMAX)
.build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
INDArray in = Nd4j.rand(new int[] { 2, 3, 4, 5 });
INDArray in = Nd4j.rand(new int[]{2,3,4,5});
INDArray out = net.output(in); INDArray out = net.output(in);
double min = out.minNumber().doubleValue(); double min = out.minNumber().doubleValue();
double max = out.maxNumber().doubleValue(); double max = out.maxNumber().doubleValue();
assertTrue(min >= 0 && max <= 1.0); assertTrue(min >= 0 && max <= 1.0);
INDArray sum = out.sum(1); INDArray sum = out.sum(1);
assertEquals(Nd4j.ones(DataType.FLOAT,2,4,5), sum); assertEquals(Nd4j.ones(DataType.FLOAT, 2, 4, 5), sum);
} }
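Channels-wise softmax means the normalization runs over dimension 1 (the four channels), independently per example and spatial position: sum_c out[n, c, h, w] = 1 for every (n, h, w). That is why out.sum(1) must equal an all-ones array of shape [2, 4, 5] and every entry must lie in [0, 1].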
@Test @Test
public void testOutputLayerDefaults(){ @DisplayName("Test Output Layer Defaults")
void testOutputLayerDefaults() {
new NeuralNetConfiguration.Builder().list() new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(10).nOut(10).build()).build();
.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(10).nOut(10).build()) new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.LossLayer.Builder().build()).build();
.build(); new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.CnnLossLayer.Builder().build()).build();
new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder().build()).build();
new NeuralNetConfiguration.Builder().list()
.layer(new org.deeplearning4j.nn.conf.layers.LossLayer.Builder().build())
.build();
new NeuralNetConfiguration.Builder().list()
.layer(new org.deeplearning4j.nn.conf.layers.CnnLossLayer.Builder().build())
.build();
new NeuralNetConfiguration.Builder().list()
.layer(new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder().build())
.build();
} }
} }
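The mechanical migration applied to every test class in this commit follows the same recipe; a minimal before/after sketch (FooTest is a hypothetical class, not one from the repository):

    // Before (JUnit 4):
    //   import org.junit.Test;
    //   import static org.junit.Assert.assertEquals;
    //
    //   public class FooTest {
    //       @Test
    //       public void testBar() { assertEquals(2, 1 + 1); }
    //   }

    // After (JUnit 5 / Jupiter): new imports, @DisplayName annotations, and
    // package-private visibility, since Jupiter no longer requires public
    // test classes or methods.
    import org.junit.jupiter.api.DisplayName;
    import org.junit.jupiter.api.Test;

    import static org.junit.jupiter.api.Assertions.assertEquals;

    @DisplayName("Foo Test")
    class FooTest {

        @Test
        @DisplayName("Test Bar")
        void testBar() {
            assertEquals(2, 1 + 1);
        }
    }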
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0 * * SPDX-License-Identifier: Apache-2.0
* ***************************************************************************** * *****************************************************************************
*/ */
package org.deeplearning4j.nn.layers; package org.deeplearning4j.nn.layers;
 import org.deeplearning4j.BaseDL4JTest;
@@ -26,47 +25,41 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.common.primitives.Pair;
 import java.util.Arrays;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;

-public class RepeatVectorTest extends BaseDL4JTest {
+@DisplayName("Repeat Vector Test")
+class RepeatVectorTest extends BaseDL4JTest {

     private int REPEAT = 4;

     private Layer getRepeatVectorLayer() {
-        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123)
-                        .dataType(DataType.DOUBLE)
-                        .layer(new RepeatVector.Builder(REPEAT).build()).build();
-        return conf.getLayer().instantiate(conf, null, 0,
-                        null, false, DataType.DOUBLE);
+        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).dataType(DataType.DOUBLE).layer(new RepeatVector.Builder(REPEAT).build()).build();
+        return conf.getLayer().instantiate(conf, null, 0, null, false, DataType.DOUBLE);
     }

     @Test
-    public void testRepeatVector() {
-        double[] arr = new double[] {1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.};
-        INDArray expectedOut = Nd4j.create(arr, new long[] {1, 3, REPEAT}, 'f');
-        INDArray input = Nd4j.create(new double[] {1., 2., 3.}, new long[] {1, 3});
+    @DisplayName("Test Repeat Vector")
+    void testRepeatVector() {
+        double[] arr = new double[] { 1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3. };
+        INDArray expectedOut = Nd4j.create(arr, new long[] { 1, 3, REPEAT }, 'f');
+        INDArray input = Nd4j.create(new double[] { 1., 2., 3. }, new long[] { 1, 3 });
         Layer layer = getRepeatVectorLayer();
         INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
         assertTrue(Arrays.equals(expectedOut.shape(), output.shape()));
         assertEquals(expectedOut, output);
-        INDArray epsilon = Nd4j.ones(1,3,4);
+        INDArray epsilon = Nd4j.ones(1, 3, 4);
         Pair<Gradient, INDArray> out = layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
         INDArray outEpsilon = out.getSecond();
-        INDArray expectedEpsilon = Nd4j.create(new double[] {4., 4., 4.}, new long[] {1, 3});
+        INDArray expectedEpsilon = Nd4j.create(new double[] { 4., 4., 4. }, new long[] { 1, 3 });
         assertEquals(expectedEpsilon, outEpsilon);
     }
 }
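The hunk above is the migration pattern this commit applies across all 259 files: org.junit.Test becomes org.junit.jupiter.api.Test, org.junit.Assert becomes org.junit.jupiter.api.Assertions, each class and test method gains a @DisplayName, and the public modifiers are dropped, since JUnit 5 discovers package-private test classes and methods. A minimal before/after sketch (class and method names here are hypothetical, not from this commit):

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

// JUnit 5 discovers package-private classes and methods, so the
// migration drops "public" and layers on @DisplayName metadata.
@DisplayName("Example Test")
class ExampleTest { // hypothetical name, for illustration only

    @Test
    @DisplayName("Test Addition")
    void testAddition() {
        // JUnit 4 equivalent: import org.junit.Test;
        //                     import static org.junit.Assert.assertEquals;
        assertEquals(4, 2 + 2);
    }
}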
@@ -17,7 +17,6 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
-
 package org.deeplearning4j.nn.layers;
 
 import org.deeplearning4j.BaseDL4JTest;
@@ -25,45 +24,41 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.AutoEncoder;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
-import static org.junit.Assert.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;

 /**
  */
-public class SeedTest extends BaseDL4JTest {
+@DisplayName("Seed Test")
+class SeedTest extends BaseDL4JTest {

     private DataSetIterator irisIter = new IrisDataSetIterator(50, 50);
     private DataSet data = irisIter.next();

     @Test
-    public void testAutoEncoderSeed() {
-        AutoEncoder layerType = new AutoEncoder.Builder().nIn(4).nOut(3).corruptionLevel(0.0)
-                        .activation(Activation.SIGMOID).build();
-        NeuralNetConfiguration conf =
-                        new NeuralNetConfiguration.Builder().layer(layerType).seed(123).build();
+    @DisplayName("Test Auto Encoder Seed")
+    void testAutoEncoderSeed() {
+        AutoEncoder layerType = new AutoEncoder.Builder().nIn(4).nOut(3).corruptionLevel(0.0).activation(Activation.SIGMOID).build();
+        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(layerType).seed(123).build();
         long numParams = conf.getLayer().initializer().numParams(conf);
         INDArray params = Nd4j.create(1, numParams);
         Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType());
         layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
         layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
         double score = layer.score();
         INDArray parameters = layer.params();
         layer.setParams(parameters);
         layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
         double score2 = layer.score();
         assertEquals(parameters, layer.params());
         assertEquals(score, score2, 1e-4);
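SeedTest's assertEquals(score, score2, 1e-4) keeps the (expected, actual, delta) positional form, which is identical in JUnit 4 and 5; only assertions that carried a failure message need reordering, because Jupiter moves the message from the first parameter to the last. A sketch of the messaged variant (values are hypothetical):

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

class ScoreComparisonExample { // hypothetical name, for illustration only

    @Test
    void scoresMatchWithinTolerance() {
        double score = 0.42001;
        double score2 = 0.42003;
        // JUnit 4: assertEquals("scores differ", score, score2, 1e-4);
        // Jupiter keeps (expected, actual, delta) and moves the message last:
        assertEquals(score, score2, 1e-4, "scores differ");
    }
}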
@@ -17,11 +17,9 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
-
 package org.deeplearning4j.nn.layers.capsule;
-
-import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import java.io.IOException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
@@ -35,64 +33,44 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.LossLayer;
 import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.junit.Ignore;
-import org.junit.Test;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
 import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.linalg.activations.impl.ActivationSoftmax;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;

-@Ignore("AB - ignored due to excessive runtime. Keep for manual debugging when required")
-public class CapsNetMNISTTest extends BaseDL4JTest {
+@Disabled("AB - ignored due to excessive runtime. Keep for manual debugging when required")
+@DisplayName("Caps Net MNIST Test")
+class CapsNetMNISTTest extends BaseDL4JTest {

     @Override
-    public DataType getDataType(){
+    public DataType getDataType() {
         return DataType.FLOAT;
     }

     @Test
-    public void testCapsNetOnMNIST(){
-        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
-                .seed(123)
-                .updater(new Adam())
-                .list()
-                .layer(new ConvolutionLayer.Builder()
-                        .nOut(16)
-                        .kernelSize(9, 9)
-                        .stride(3, 3)
-                        .build())
-                .layer(new PrimaryCapsules.Builder(8, 8)
-                        .kernelSize(7, 7)
-                        .stride(2, 2)
-                        .build())
-                .layer(new CapsuleLayer.Builder(10, 16, 3).build())
-                .layer(new CapsuleStrengthLayer.Builder().build())
-                .layer(new ActivationLayer.Builder(new ActivationSoftmax()).build())
-                .layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build())
-                .setInputType(InputType.convolutionalFlat(28, 28, 1))
-                .build();
+    @DisplayName("Test Caps Net On MNIST")
+    void testCapsNetOnMNIST() {
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).updater(new Adam()).list().layer(new ConvolutionLayer.Builder().nOut(16).kernelSize(9, 9).stride(3, 3).build()).layer(new PrimaryCapsules.Builder(8, 8).kernelSize(7, 7).stride(2, 2).build()).layer(new CapsuleLayer.Builder(10, 16, 3).build()).layer(new CapsuleStrengthLayer.Builder().build()).layer(new ActivationLayer.Builder(new ActivationSoftmax()).build()).layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
         MultiLayerNetwork model = new MultiLayerNetwork(conf);
         model.init();
         int rngSeed = 12345;
         try {
             MnistDataSetIterator mnistTrain = new MnistDataSetIterator(64, true, rngSeed);
             MnistDataSetIterator mnistTest = new MnistDataSetIterator(64, false, rngSeed);
             for (int i = 0; i < 2; i++) {
                 model.fit(mnistTrain);
             }
             Evaluation eval = model.evaluate(mnistTest);
-            assertTrue("Accuracy not over 95%", eval.accuracy() > 0.95);
-            assertTrue("Precision not over 95%", eval.precision() > 0.95);
-            assertTrue("Recall not over 95%", eval.recall() > 0.95);
-            assertTrue("F1-score not over 95%", eval.f1() > 0.95);
-        } catch (IOException e){
+            assertTrue(eval.accuracy() > 0.95, "Accuracy not over 95%");
+            assertTrue(eval.precision() > 0.95, "Precision not over 95%");
+            assertTrue(eval.recall() > 0.95, "Recall not over 95%");
+            assertTrue(eval.f1() > 0.95, "F1-score not over 95%");
+        } catch (IOException e) {
             System.out.println("Could not load MNIST.");
         }
     }
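This file exercises the two mappings that a mechanical find-and-replace gets wrong most easily: @Ignore(reason) maps to @Disabled(reason), and the message-first assertTrue overload flips its arguments. A compact sketch (names and values hypothetical):

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertTrue;

// @Ignore(reason) becomes @Disabled(reason); assertion messages move
// from the first argument to the last.
@Disabled("excessive runtime; run manually")
class EvaluationGateExample { // hypothetical name, for illustration only

    @Test
    void accuracyGate() {
        double accuracy = 0.97; // stand-in for eval.accuracy()
        // JUnit 4: assertTrue("Accuracy not over 95%", accuracy > 0.95);
        assertTrue(accuracy > 0.95, "Accuracy not over 95%");
    }
}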
@@ -17,84 +17,71 @@
  * * SPDX-License-Identifier: Apache-2.0
  * *****************************************************************************
  */
-
 package org.deeplearning4j.nn.layers.capsule;
-
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.CapsuleLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;

-public class CapsuleLayerTest extends BaseDL4JTest {
+@DisplayName("Capsule Layer Test")
+class CapsuleLayerTest extends BaseDL4JTest {

     @Override
-    public DataType getDataType(){
+    public DataType getDataType() {
         return DataType.FLOAT;
     }

     @Test
-    public void testOutputType(){
+    @DisplayName("Test Output Type")
+    void testOutputType() {
         CapsuleLayer layer = new CapsuleLayer.Builder(10, 16, 5).build();
         InputType in1 = InputType.recurrent(5, 8);
         assertEquals(InputType.recurrent(10, 16), layer.getOutputType(0, in1));
     }

     @Test
-    public void testInputType(){
+    @DisplayName("Test Input Type")
+    void testInputType() {
         CapsuleLayer layer = new CapsuleLayer.Builder(10, 16, 5).build();
         InputType in1 = InputType.recurrent(5, 8);
         layer.setNIn(in1, true);
         assertEquals(5, layer.getInputCapsules());
         assertEquals(8, layer.getInputCapsuleDimensions());
     }

     @Test
-    public void testConfig(){
+    @DisplayName("Test Config")
+    void testConfig() {
         CapsuleLayer layer1 = new CapsuleLayer.Builder(10, 16, 5).build();
         assertEquals(10, layer1.getCapsules());
         assertEquals(16, layer1.getCapsuleDimensions());
         assertEquals(5, layer1.getRoutings());
         assertFalse(layer1.isHasBias());
         CapsuleLayer layer2 = new CapsuleLayer.Builder(10, 16, 5).hasBias(true).build();
         assertTrue(layer2.isHasBias());
     }

     @Test
-    public void testLayer(){
-        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
-                .seed(123)
-                .list()
-                .layer(new CapsuleLayer.Builder(10, 16, 3).build())
-                .setInputType(InputType.recurrent(10, 8))
-                .build();
+    @DisplayName("Test Layer")
+    void testLayer() {
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).list().layer(new CapsuleLayer.Builder(10, 16, 3).build()).setInputType(InputType.recurrent(10, 8)).build();
         MultiLayerNetwork model = new MultiLayerNetwork(conf);
         model.init();
         INDArray emptyFeatures = Nd4j.zeros(64, 10, 8);
         long[] shape = model.output(emptyFeatures).shape();
-        assertArrayEquals(new long[]{64, 10, 16}, shape);
+        assertArrayEquals(new long[] { 64, 10, 16 }, shape);
     }
 }
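Note that every migrated file also gains an org.junit.jupiter.api.extension.ExtendWith import that nothing in the file uses, a telltale of the automated rewrite; a later cleanup pass could drop it. One JUnit 4 idiom that does need hand attention, although none of the files above happens to use it, is @Test(expected = ...), which has no direct Jupiter equivalent; the replacement is Assertions.assertThrows. A sketch (the class name and validation logic are hypothetical):

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertThrows;

class ExpectedExceptionExample { // hypothetical name, for illustration only

    @Test
    void rejectsNegativeCapsuleCount() {
        // JUnit 4: @Test(expected = IllegalArgumentException.class)
        // Jupiter scopes the expectation to the exact statement instead:
        assertThrows(IllegalArgumentException.class, () -> {
            throw new IllegalArgumentException("capsules must be positive");
        });
    }
}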
Some files were not shown because too many files have changed in this diff.