diff --git a/contrib/codegen-tools/codegen/pom.xml b/contrib/codegen-tools/codegen/pom.xml
index cbd00a825..5f367d8e4 100644
--- a/contrib/codegen-tools/codegen/pom.xml
+++ b/contrib/codegen-tools/codegen/pom.xml
@@ -15,7 +15,7 @@
         1.7
         1.18.8
         1.1.7
-        4.12
+        5.8.0-M1
         5.4.2
         1.8
         3.1.1
diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java b/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java
index f41bd93a6..cbe4e265c 100644
--- a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java
+++ b/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java
@@ -17,13 +17,14 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
-
 package org.nd4j.codegen.ir;
 
-public class SerializationTest {
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
 
-    public static void main(String...args) {
+@DisplayName("Serialization Test")
+class SerializationTest {
+    public static void main(String... args) {
 
     }
-
 }
diff --git a/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java b/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java
index 5d8e12885..7eeef5717 100644
--- a/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java
+++ b/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java
@@ -17,29 +17,23 @@
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
-
 package org.nd4j.codegen.dsl;
 
 import org.apache.commons.lang3.StringUtils;
 import org.junit.jupiter.api.Test;
 import org.nd4j.codegen.impl.java.DocsGenerator;
-
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
 
-public class DocsGeneratorTest {
+@DisplayName("Docs Generator Test")
+class DocsGeneratorTest {
 
     @Test
-    public void testJDtoMDAdapter() {
-        String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" +
-                " eye:\n" +
-                " [ 1, 0]\n" +
-                " [ 0, 1]\n" +
-                " [ 0, 0]}";
-        String expected = "{ INDArray eye = eye(3,2)\n" +
-                " eye:\n" +
-                " [ 1, 0]\n" +
-                " [ 0, 1]\n" +
-                " [ 0, 0]}";
+    @DisplayName("Test J Dto MD Adapter")
+    void testJDtoMDAdapter() {
+        String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
+        String expected = "{ INDArray eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
         DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original);
         String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString();
         assertEquals(out, expected);
diff --git a/datavec/datavec-api/pom.xml b/datavec/datavec-api/pom.xml
index d7fcf5a47..fc091c5dd 100644
--- a/datavec/datavec-api/pom.xml
+++ b/datavec/datavec-api/pom.xml
@@ -34,6 +34,14 @@
         datavec-api
 
 
+
+            org.junit.jupiter
+            junit-jupiter-api
+
+
+            org.junit.vintage
+            junit-vintage-engine
+
 
             org.apache.commons
             commons-lang3
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java
index 7aef92158..5ce4cb254 100644
---
a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.apache.commons.io.FileUtils; @@ -27,46 +26,37 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; - import java.io.File; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; +@DisplayName("Csv Line Sequence Record Reader Test") +class CSVLineSequenceRecordReaderTest extends BaseND4JTest { -public class CSVLineSequenceRecordReaderTest extends BaseND4JTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Test - public void test() throws Exception { - - File f = testDir.newFolder(); + @DisplayName("Test") + void test(@TempDir Path testDir) throws Exception { + File f = testDir.toFile(); File source = new File(f, "temp.csv"); String str = "a,b,c\n1,2,3,4"; FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8); - SequenceRecordReader rr = new CSVLineSequenceRecordReader(); rr.initialize(new FileSplit(source)); - - List> exp0 = Arrays.asList( - Collections.singletonList(new Text("a")), - Collections.singletonList(new Text("b")), - Collections.singletonList(new Text("c"))); - - List> exp1 = Arrays.asList( - Collections.singletonList(new Text("1")), - Collections.singletonList(new Text("2")), - Collections.singletonList(new Text("3")), - Collections.singletonList(new Text("4"))); - - for( int i=0; i<3; i++ ) { + List> exp0 = Arrays.asList(Collections.singletonList(new Text("a")), Collections.singletonList(new Text("b")), Collections.singletonList(new Text("c"))); + List> exp1 = Arrays.asList(Collections.singletonList(new Text("1")), Collections.singletonList(new Text("2")), Collections.singletonList(new Text("3")), Collections.singletonList(new Text("4"))); + for (int i = 0; i < 3; i++) { int count = 0; while (rr.hasNext()) { List> next = rr.sequenceRecord(); @@ -76,9 +66,7 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest { assertEquals(exp1, next); } } - assertEquals(2, count); - rr.reset(); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java index f78676627..f108a4438 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package 
org.datavec.api.records.reader.impl; import org.apache.commons.io.FileUtils; @@ -27,32 +26,34 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; - import java.io.File; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +@DisplayName("Csv Multi Sequence Record Reader Test") +class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { -public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Test - public void testConcatMode() throws Exception { - for( int i=0; i<3; i++ ) { - + @DisplayName("Test Concat Mode") + void testConcatMode() throws Exception { + for (int i = 0; i < 3; i++) { String seqSep; String seqSepRegex; - switch (i){ + switch(i) { case 0: seqSep = ""; seqSepRegex = "^$"; @@ -68,31 +69,23 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { default: throw new RuntimeException(); } - String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C"; - File f = testDir.newFile(); + File f = testDir.toFile(); FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8); - SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT); seqRR.initialize(new FileSplit(f)); - - List> exp0 = new ArrayList<>(); for (String s : "a,b,c,1,2,3,4,x,y".split(",")) { exp0.add(Collections.singletonList(new Text(s))); } - List> exp1 = new ArrayList<>(); for (String s : "A,B,C".split(",")) { exp1.add(Collections.singletonList(new Text(s))); } - assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord()); assertFalse(seqRR.hasNext()); - seqRR.reset(); - assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord()); assertFalse(seqRR.hasNext()); @@ -100,13 +93,12 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { } @Test - public void testEqualLength() throws Exception { - - for( int i=0; i<3; i++ ) { - + @DisplayName("Test Equal Length") + void testEqualLength() throws Exception { + for (int i = 0; i < 3; i++) { String seqSep; String seqSepRegex; - switch (i) { + switch(i) { case 0: seqSep = ""; seqSepRegex = "^$"; @@ -122,27 +114,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { default: throw new RuntimeException(); } - String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC"; - File f = testDir.newFile(); + File f = testDir.toFile(); FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8); - SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH); seqRR.initialize(new FileSplit(f)); - - - List> exp0 = Arrays.asList( - Arrays.asList(new Text("a"), new Text("1"), new Text("x")), - Arrays.asList(new Text("b"), new Text("2"), new Text("y"))); - + List> exp0 
= Arrays.asList(Arrays.asList(new Text("a"), new Text("1"), new Text("x")), Arrays.asList(new Text("b"), new Text("2"), new Text("y"))); List> exp1 = Collections.singletonList(Arrays.asList(new Text("A"), new Text("B"), new Text("C"))); - assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord()); assertFalse(seqRR.hasNext()); - seqRR.reset(); - assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord()); assertFalse(seqRR.hasNext()); @@ -150,13 +132,12 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { } @Test - public void testPadding() throws Exception { - - for( int i=0; i<3; i++ ) { - + @DisplayName("Test Padding") + void testPadding() throws Exception { + for (int i = 0; i < 3; i++) { String seqSep; String seqSepRegex; - switch (i) { + switch(i) { case 0: seqSep = ""; seqSepRegex = "^$"; @@ -172,27 +153,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { default: throw new RuntimeException(); } - String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC"; - File f = testDir.newFile(); + File f = testDir.toFile(); FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8); - SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD")); seqRR.initialize(new FileSplit(f)); - - - List> exp0 = Arrays.asList( - Arrays.asList(new Text("a"), new Text("1"), new Text("x")), - Arrays.asList(new Text("b"), new Text("PAD"), new Text("PAD"))); - + List> exp0 = Arrays.asList(Arrays.asList(new Text("a"), new Text("1"), new Text("x")), Arrays.asList(new Text("b"), new Text("PAD"), new Text("PAD"))); List> exp1 = Collections.singletonList(Arrays.asList(new Text("A"), new Text("B"), new Text("C"))); - assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord()); assertFalse(seqRR.hasNext()); - seqRR.reset(); - assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord()); assertFalse(seqRR.hasNext()); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java index 80c75c830..184462c8c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.datavec.api.records.SequenceRecord; @@ -27,61 +26,53 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; - import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest { +@DisplayName("Csvn Lines Sequence Record Reader Test") +class 
CSVNLinesSequenceRecordReaderTest extends BaseND4JTest { @Test - public void testCSVNLinesSequenceRecordReader() throws Exception { + @DisplayName("Test CSVN Lines Sequence Record Reader") + void testCSVNLinesSequenceRecordReader() throws Exception { int nLinesPerSequence = 10; - SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence); seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - CSVRecordReader rr = new CSVRecordReader(); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - int count = 0; while (seqRR.hasNext()) { List> next = seqRR.sequenceRecord(); - List> expected = new ArrayList<>(); for (int i = 0; i < nLinesPerSequence; i++) { expected.add(rr.next()); } - assertEquals(10, next.size()); assertEquals(expected, next); - count++; } - assertEquals(150 / nLinesPerSequence, count); } @Test - public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception { + @DisplayName("Test CSV Nlines Sequence Record Reader Meta Data") + void testCSVNlinesSequenceRecordReaderMetaData() throws Exception { int nLinesPerSequence = 10; - SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence); seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - CSVRecordReader rr = new CSVRecordReader(); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - List>> out = new ArrayList<>(); while (seqRR.hasNext()) { List> next = seqRR.sequenceRecord(); out.add(next); } - seqRR.reset(); List>> out2 = new ArrayList<>(); List out3 = new ArrayList<>(); @@ -92,11 +83,8 @@ public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest { meta.add(seq.getMetaData()); out3.add(seq); } - assertEquals(out, out2); - List out4 = seqRR.loadSequenceFromMetaData(meta); assertEquals(out3, out4); } - } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java index 85f20f3ad..c7e840c42 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.apache.commons.io.FileUtils; @@ -34,10 +33,10 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; - import java.io.File; import java.io.IOException; import java.nio.file.Files; @@ -47,41 +46,44 @@ import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; + + +@DisplayName("Csv Record Reader Test") +class CSVRecordReaderTest extends BaseND4JTest { -public class CSVRecordReaderTest extends BaseND4JTest { @Test - public void testNext() throws Exception { + @DisplayName("Test Next") + void testNext() throws Exception { CSVRecordReader reader = new 
CSVRecordReader(); reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1")); while (reader.hasNext()) { List vals = reader.next(); List arr = new ArrayList<>(vals); - - assertEquals("Entry count", 23, vals.size()); + assertEquals(23, vals.size(), "Entry count"); Text lastEntry = (Text) arr.get(arr.size() - 1); - assertEquals("Last entry garbage", 1, lastEntry.getLength()); + assertEquals(1, lastEntry.getLength(), "Last entry garbage"); } } @Test - public void testEmptyEntries() throws Exception { + @DisplayName("Test Empty Entries") + void testEmptyEntries() throws Exception { CSVRecordReader reader = new CSVRecordReader(); reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,")); while (reader.hasNext()) { List vals = reader.next(); - assertEquals("Entry count", 23, vals.size()); + assertEquals(23, vals.size(), "Entry count"); } } @Test - public void testReset() throws Exception { + @DisplayName("Test Reset") + void testReset() throws Exception { CSVRecordReader rr = new CSVRecordReader(0, ','); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - int nResets = 5; for (int i = 0; i < nResets; i++) { - int lineCount = 0; while (rr.hasNext()) { List line = rr.next(); @@ -95,7 +97,8 @@ public class CSVRecordReaderTest extends BaseND4JTest { } @Test - public void testResetWithSkipLines() throws Exception { + @DisplayName("Test Reset With Skip Lines") + void testResetWithSkipLines() throws Exception { CSVRecordReader rr = new CSVRecordReader(10, ','); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); int lineCount = 0; @@ -114,7 +117,8 @@ public class CSVRecordReaderTest extends BaseND4JTest { } @Test - public void testWrite() throws Exception { + @DisplayName("Test Write") + void testWrite() throws Exception { List> list = new ArrayList<>(); StringBuilder sb = new StringBuilder(); for (int i = 0; i < 10; i++) { @@ -130,81 +134,72 @@ public class CSVRecordReaderTest extends BaseND4JTest { } list.add(temp); } - String expected = sb.toString(); - Path p = Files.createTempFile("csvwritetest", "csv"); p.toFile().deleteOnExit(); - FileRecordWriter writer = new CSVRecordWriter(); FileSplit fileSplit = new FileSplit(p.toFile()); - writer.initialize(fileSplit,new NumberOfRecordsPartitioner()); + writer.initialize(fileSplit, new NumberOfRecordsPartitioner()); for (List c : list) { writer.write(c); } writer.close(); - - //Read file back in; compare + // Read file back in; compare String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name()); - - // System.out.println(expected); - // System.out.println("----------"); - // System.out.println(fileContents); - + // System.out.println(expected); + // System.out.println("----------"); + // System.out.println(fileContents); assertEquals(expected, fileContents); } @Test - public void testTabsAsSplit1() throws Exception { - + @DisplayName("Test Tabs As Split 1") + void testTabsAsSplit1() throws Exception { CSVRecordReader reader = new CSVRecordReader(0, '\t'); reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile())); while (reader.hasNext()) { List list = new ArrayList<>(reader.next()); - assertEquals(2, list.size()); } } @Test - public void testPipesAsSplit() throws Exception { - + @DisplayName("Test Pipes As Split") + void testPipesAsSplit() throws Exception { CSVRecordReader reader = new CSVRecordReader(0, '|'); reader.initialize(new FileSplit(new 
ClassPathResource("datavec-api/issue414.csv").getFile())); int lineidx = 0; List sixthColumn = Arrays.asList(13, 95, 15, 25); while (reader.hasNext()) { List list = new ArrayList<>(reader.next()); - assertEquals(10, list.size()); - assertEquals((long)sixthColumn.get(lineidx), list.get(5).toInt()); + assertEquals((long) sixthColumn.get(lineidx), list.get(5).toInt()); lineidx++; } } - @Test - public void testWithQuotes() throws Exception { + @DisplayName("Test With Quotes") + void testWithQuotes() throws Exception { CSVRecordReader reader = new CSVRecordReader(0, ',', '\"'); reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\"")); while (reader.hasNext()) { List vals = reader.next(); - assertEquals("Entry count", 6, vals.size()); - assertEquals("1", vals.get(0).toString()); - assertEquals("0", vals.get(1).toString()); - assertEquals("3", vals.get(2).toString()); - assertEquals("Braund, Mr. Owen Harris", vals.get(3).toString()); - assertEquals("male", vals.get(4).toString()); - assertEquals("\"", vals.get(5).toString()); + assertEquals(6, vals.size(), "Entry count"); + assertEquals(vals.get(0).toString(), "1"); + assertEquals(vals.get(1).toString(), "0"); + assertEquals(vals.get(2).toString(), "3"); + assertEquals(vals.get(3).toString(), "Braund, Mr. Owen Harris"); + assertEquals(vals.get(4).toString(), "male"); + assertEquals(vals.get(5).toString(), "\""); } } - @Test - public void testMeta() throws Exception { + @DisplayName("Test Meta") + void testMeta() throws Exception { CSVRecordReader rr = new CSVRecordReader(0, ','); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - int lineCount = 0; List metaList = new ArrayList<>(); List> writables = new ArrayList<>(); @@ -213,30 +208,25 @@ public class CSVRecordReaderTest extends BaseND4JTest { assertEquals(5, r.getRecord().size()); lineCount++; RecordMetaData meta = r.getMetaData(); - // System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI()); - + // System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI()); metaList.add(meta); writables.add(r.getRecord()); } assertFalse(rr.hasNext()); assertEquals(150, lineCount); rr.reset(); - - System.out.println("\n\n\n--------------------------------"); List contents = rr.loadFromMetaData(metaList); assertEquals(150, contents.size()); - // for(Record r : contents ){ - // System.out.println(r); - // } - + // for(Record r : contents ){ + // System.out.println(r); + // } List meta2 = new ArrayList<>(); meta2.add(metaList.get(100)); meta2.add(metaList.get(90)); meta2.add(metaList.get(80)); meta2.add(metaList.get(70)); meta2.add(metaList.get(60)); - List contents2 = rr.loadFromMetaData(meta2); assertEquals(writables.get(100), contents2.get(0).getRecord()); assertEquals(writables.get(90), contents2.get(1).getRecord()); @@ -246,50 +236,49 @@ public class CSVRecordReaderTest extends BaseND4JTest { } @Test - public void testRegex() throws Exception { - CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] {null, "(.+) (.+) (.+)"}); + @DisplayName("Test Regex") + void testRegex() throws Exception { + CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] { null, "(.+) (.+) (.+)" }); reader.initialize(new StringSplit("normal,1.2.3.4 space separator")); while (reader.hasNext()) { List vals = reader.next(); - assertEquals("Entry count", 4, vals.size()); - assertEquals("normal", vals.get(0).toString()); - assertEquals("1.2.3.4", vals.get(1).toString()); - 
assertEquals("space", vals.get(2).toString()); - assertEquals("separator", vals.get(3).toString()); + assertEquals(4, vals.size(), "Entry count"); + assertEquals(vals.get(0).toString(), "normal"); + assertEquals(vals.get(1).toString(), "1.2.3.4"); + assertEquals(vals.get(2).toString(), "space"); + assertEquals(vals.get(3).toString(), "separator"); } } - @Test(expected = NoSuchElementException.class) - public void testCsvSkipAllLines() throws IOException, InterruptedException { - final int numLines = 4; - final List lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), - (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three")); - String header = ",one,two,three"; - List lines = new ArrayList<>(); - for (int i = 0; i < numLines; i++) - lines.add(Integer.toString(i) + header); - File tempFile = File.createTempFile("csvSkipLines", ".csv"); - FileUtils.writeLines(tempFile, lines); - - CSVRecordReader rr = new CSVRecordReader(numLines, ','); - rr.initialize(new FileSplit(tempFile)); - rr.reset(); - assertTrue(!rr.hasNext()); - rr.next(); + @Test + @DisplayName("Test Csv Skip All Lines") + void testCsvSkipAllLines() { + assertThrows(NoSuchElementException.class, () -> { + final int numLines = 4; + final List lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three")); + String header = ",one,two,three"; + List lines = new ArrayList<>(); + for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header); + File tempFile = File.createTempFile("csvSkipLines", ".csv"); + FileUtils.writeLines(tempFile, lines); + CSVRecordReader rr = new CSVRecordReader(numLines, ','); + rr.initialize(new FileSplit(tempFile)); + rr.reset(); + assertTrue(!rr.hasNext()); + rr.next(); + }); } @Test - public void testCsvSkipAllButOneLine() throws IOException, InterruptedException { + @DisplayName("Test Csv Skip All But One Line") + void testCsvSkipAllButOneLine() throws IOException, InterruptedException { final int numLines = 4; - final List lineList = Arrays.asList(new Text(Integer.toString(numLines - 1)), - new Text("one"), new Text("two"), new Text("three")); + final List lineList = Arrays.asList(new Text(Integer.toString(numLines - 1)), new Text("one"), new Text("two"), new Text("three")); String header = ",one,two,three"; List lines = new ArrayList<>(); - for (int i = 0; i < numLines; i++) - lines.add(Integer.toString(i) + header); + for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header); File tempFile = File.createTempFile("csvSkipLines", ".csv"); FileUtils.writeLines(tempFile, lines); - CSVRecordReader rr = new CSVRecordReader(numLines - 1, ','); rr.initialize(new FileSplit(tempFile)); rr.reset(); @@ -297,50 +286,45 @@ public class CSVRecordReaderTest extends BaseND4JTest { assertEquals(rr.next(), lineList); } - @Test - public void testStreamReset() throws Exception { + @DisplayName("Test Stream Reset") + void testStreamReset() throws Exception { CSVRecordReader rr = new CSVRecordReader(0, ','); rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream())); - int count = 0; - while(rr.hasNext()){ + while (rr.hasNext()) { assertNotNull(rr.next()); count++; } assertEquals(150, count); - assertFalse(rr.resetSupported()); - - try{ + try { rr.reset(); fail("Expected exception"); - } catch (Exception e){ + } catch (Exception e) { String msg = e.getMessage(); String msg2 = e.getCause().getMessage(); - assertTrue(msg, 
msg.contains("Error during LineRecordReader reset")); - assertTrue(msg2, msg2.contains("Reset not supported from streams")); -// e.printStackTrace(); + assertTrue(msg.contains("Error during LineRecordReader reset"),msg); + assertTrue(msg2.contains("Reset not supported from streams"),msg2); + // e.printStackTrace(); } } @Test - public void testUsefulExceptionNoInit(){ - + @DisplayName("Test Useful Exception No Init") + void testUsefulExceptionNoInit() { CSVRecordReader rr = new CSVRecordReader(0, ','); - - try{ + try { rr.hasNext(); fail("Expected exception"); - } catch (Exception e){ - assertTrue(e.getMessage(), e.getMessage().contains("initialized")); + } catch (Exception e) { + assertTrue( e.getMessage().contains("initialized"),e.getMessage()); } - - try{ + try { rr.next(); fail("Expected exception"); - } catch (Exception e){ - assertTrue(e.getMessage(), e.getMessage().contains("initialized")); + } catch (Exception e) { + assertTrue(e.getMessage().contains("initialized"),e.getMessage()); } } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java index 70a774165..e022746e0 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.datavec.api.records.SequenceRecord; @@ -28,11 +27,10 @@ import org.datavec.api.split.InputSplit; import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.writable.Writable; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; - import java.io.File; import java.io.InputStream; import java.io.OutputStream; @@ -41,25 +39,27 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; +@DisplayName("Csv Sequence Record Reader Test") +class CSVSequenceRecordReaderTest extends BaseND4JTest { -public class CSVSequenceRecordReaderTest extends BaseND4JTest { - - @Rule - public TemporaryFolder tempDir = new TemporaryFolder(); + @TempDir + public Path tempDir; @Test - public void test() throws Exception { - + @DisplayName("Test") + void test() throws Exception { CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ","); seqReader.initialize(new TestInputSplit()); - int sequenceCount = 0; while (seqReader.hasNext()) { List> sequence = seqReader.sequenceRecord(); - assertEquals(4, sequence.size()); //4 lines, plus 1 header line - + // 4 lines, plus 1 header line + assertEquals(4, sequence.size()); Iterator> timeStepIter = sequence.iterator(); int lineCount = 0; while (timeStepIter.hasNext()) { @@ -80,19 +80,18 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest { } @Test - public void testReset() throws Exception { + @DisplayName("Test 
Reset") + void testReset() throws Exception { CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ","); seqReader.initialize(new TestInputSplit()); - int nTests = 5; for (int i = 0; i < nTests; i++) { seqReader.reset(); - int sequenceCount = 0; while (seqReader.hasNext()) { List> sequence = seqReader.sequenceRecord(); - assertEquals(4, sequence.size()); //4 lines, plus 1 header line - + // 4 lines, plus 1 header line + assertEquals(4, sequence.size()); Iterator> timeStepIter = sequence.iterator(); int lineCount = 0; while (timeStepIter.hasNext()) { @@ -107,15 +106,15 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest { } @Test - public void testMetaData() throws Exception { + @DisplayName("Test Meta Data") + void testMetaData() throws Exception { CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ","); seqReader.initialize(new TestInputSplit()); - List>> l = new ArrayList<>(); while (seqReader.hasNext()) { List> sequence = seqReader.sequenceRecord(); - assertEquals(4, sequence.size()); //4 lines, plus 1 header line - + // 4 lines, plus 1 header line + assertEquals(4, sequence.size()); Iterator> timeStepIter = sequence.iterator(); int lineCount = 0; while (timeStepIter.hasNext()) { @@ -123,10 +122,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest { lineCount++; } assertEquals(4, lineCount); - l.add(sequence); } - List l2 = new ArrayList<>(); List meta = new ArrayList<>(); seqReader.reset(); @@ -136,7 +133,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest { meta.add(sr.getMetaData()); } assertEquals(3, l2.size()); - List fromMeta = seqReader.loadSequenceFromMetaData(meta); for (int i = 0; i < 3; i++) { assertEquals(l.get(i), l2.get(i).getSequenceRecord()); @@ -144,8 +140,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest { } } - private static class - TestInputSplit implements InputSplit { + @DisplayName("Test Input Split") + private static class TestInputSplit implements InputSplit { @Override public boolean canWriteToLocation(URI location) { @@ -164,7 +160,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest { @Override public void updateSplitLocations(boolean reset) { - } @Override @@ -174,7 +169,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest { @Override public void bootStrapForWrite() { - } @Override @@ -222,38 +216,30 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest { @Override public void reset() { - //No op + // No op } @Override public boolean resetSupported() { return true; } - - - - } - @Test - public void testCsvSeqAndNumberedFileSplit() throws Exception { - File baseDir = tempDir.newFolder(); - //Simple sanity check unit test + @DisplayName("Test Csv Seq And Numbered File Split") + void testCsvSeqAndNumberedFileSplit(@TempDir Path tempDir) throws Exception { + File baseDir = tempDir.toFile(); + // Simple sanity check unit test for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir); } - - //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator + // Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator ClassPathResource resource = new ClassPathResource("csvsequence_0.txt"); String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath(); - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); - 
- while(featureReader.hasNext()){ + while (featureReader.hasNext()) { featureReader.nextSequence(); } - } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java index cab012faf..148f8ff0b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.datavec.api.records.reader.SequenceRecordReader; @@ -25,94 +24,87 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; - import java.util.LinkedList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest { +@DisplayName("Csv Variable Sliding Window Record Reader Test") +class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest { @Test - public void testCSVVariableSlidingWindowRecordReader() throws Exception { + @DisplayName("Test CSV Variable Sliding Window Record Reader") + void testCSVVariableSlidingWindowRecordReader() throws Exception { int maxLinesPerSequence = 3; - SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence); seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - CSVRecordReader rr = new CSVRecordReader(); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - int count = 0; while (seqRR.hasNext()) { List> next = seqRR.sequenceRecord(); - - if(count==maxLinesPerSequence-1) { + if (count == maxLinesPerSequence - 1) { LinkedList> expected = new LinkedList<>(); for (int i = 0; i < maxLinesPerSequence; i++) { expected.addFirst(rr.next()); } assertEquals(expected, next); - } - if(count==maxLinesPerSequence) { + if (count == maxLinesPerSequence) { assertEquals(maxLinesPerSequence, next.size()); } - if(count==0) { // first seq should be length 1 + if (count == 0) { + // first seq should be length 1 assertEquals(1, next.size()); } - if(count>151) { // last seq should be length 1 + if (count > 151) { + // last seq should be length 1 assertEquals(1, next.size()); } - count++; } - assertEquals(152, count); } @Test - public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception { + @DisplayName("Test CSV Variable Sliding Window Record Reader Stride") + void testCSVVariableSlidingWindowRecordReaderStride() throws Exception { int maxLinesPerSequence = 3; int stride = 2; - SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride); seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - 
CSVRecordReader rr = new CSVRecordReader(); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - int count = 0; while (seqRR.hasNext()) { List> next = seqRR.sequenceRecord(); - - if(count==maxLinesPerSequence-1) { + if (count == maxLinesPerSequence - 1) { LinkedList> expected = new LinkedList<>(); - for(int s = 0; s < stride; s++) { + for (int s = 0; s < stride; s++) { expected = new LinkedList<>(); for (int i = 0; i < maxLinesPerSequence; i++) { expected.addFirst(rr.next()); } } assertEquals(expected, next); - } - if(count==maxLinesPerSequence) { + if (count == maxLinesPerSequence) { assertEquals(maxLinesPerSequence, next.size()); } - if(count==0) { // first seq should be length 2 + if (count == 0) { + // first seq should be length 2 assertEquals(2, next.size()); } - if(count>151) { // last seq should be length 1 + if (count > 151) { + // last seq should be length 1 assertEquals(1, next.size()); } - count++; } - assertEquals(76, count); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java index 036e23475..1acbf2fac 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.apache.commons.io.FileUtils; @@ -29,44 +28,38 @@ import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader; import org.datavec.api.writable.Writable; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.loader.FileBatch; - import java.io.File; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - -public class FileBatchRecordReaderTest extends BaseND4JTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); +@DisplayName("File Batch Record Reader Test") +class FileBatchRecordReaderTest extends BaseND4JTest { @Test - public void testCsv() throws Exception { - - //This is an unrealistic use case - one line/record per CSV - File baseDir = testDir.newFolder(); - + @DisplayName("Test Csv") + void testCsv(@TempDir Path testDir) throws Exception { + // This is an unrealistic use case - one line/record per CSV + File baseDir = testDir.toFile(); List fileList = new ArrayList<>(); - for( int i=0; i<10; i++ ){ + for (int i = 0; i < 10; i++) { String s = "file_" + i + "," + i + "," + i; File f = new File(baseDir, "origFile" + i + ".csv"); FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8); fileList.add(f); } - FileBatch fb = FileBatch.forFiles(fileList); - RecordReader rr = new CSVRecordReader(); FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb); - - - for( int test=0; test<3; test++) { + for (int test = 0; test < 3; 
test++) { for (int i = 0; i < 10; i++) { assertTrue(fbrr.hasNext()); List next = fbrr.next(); @@ -83,15 +76,15 @@ public class FileBatchRecordReaderTest extends BaseND4JTest { } @Test - public void testCsvSequence() throws Exception { - //CSV sequence - 3 lines per file, 10 files - File baseDir = testDir.newFolder(); - + @DisplayName("Test Csv Sequence") + void testCsvSequence(@TempDir Path testDir) throws Exception { + // CSV sequence - 3 lines per file, 10 files + File baseDir = testDir.toFile(); List fileList = new ArrayList<>(); - for( int i=0; i<10; i++ ){ + for (int i = 0; i < 10; i++) { StringBuilder sb = new StringBuilder(); - for( int j=0; j<3; j++ ){ - if(j > 0) + for (int j = 0; j < 3; j++) { + if (j > 0) sb.append("\n"); sb.append("file_" + i + "," + i + "," + j); } @@ -99,19 +92,16 @@ public class FileBatchRecordReaderTest extends BaseND4JTest { FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8); fileList.add(f); } - FileBatch fb = FileBatch.forFiles(fileList); SequenceRecordReader rr = new CSVSequenceRecordReader(); FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb); - - - for( int test=0; test<3; test++) { + for (int test = 0; test < 3; test++) { for (int i = 0; i < 10; i++) { assertTrue(fbrr.hasNext()); List> next = fbrr.sequenceRecord(); assertEquals(3, next.size()); int count = 0; - for(List step : next ){ + for (List step : next) { String s1 = "file_" + i; assertEquals(s1, step.get(0).toString()); assertEquals(String.valueOf(i), step.get(1).toString()); @@ -123,5 +113,4 @@ public class FileBatchRecordReaderTest extends BaseND4JTest { fbrr.reset(); } } - } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java index 910fc31b2..d914cd95f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.datavec.api.records.Record; @@ -26,28 +25,28 @@ import org.datavec.api.split.CollectionInputSplit; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; - import java.net.URI; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; - -public class FileRecordReaderTest extends BaseND4JTest { +@DisplayName("File Record Reader Test") +class FileRecordReaderTest extends BaseND4JTest { @Test - public void testReset() throws Exception { + @DisplayName("Test Reset") + void testReset() throws Exception { FileRecordReader rr = new FileRecordReader(); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile())); - int nResets = 5; for (int i = 0; i < nResets; i++) { - int lineCount = 0; while 
(rr.hasNext()) { List line = rr.next(); @@ -61,25 +60,20 @@ public class FileRecordReaderTest extends BaseND4JTest { } @Test - public void testMeta() throws Exception { + @DisplayName("Test Meta") + void testMeta() throws Exception { FileRecordReader rr = new FileRecordReader(); - - URI[] arr = new URI[3]; arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI(); arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI(); arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI(); - InputSplit is = new CollectionInputSplit(Arrays.asList(arr)); rr.initialize(is); - List> out = new ArrayList<>(); while (rr.hasNext()) { out.add(rr.next()); } - assertEquals(3, out.size()); - rr.reset(); List> out2 = new ArrayList<>(); List out3 = new ArrayList<>(); @@ -90,13 +84,10 @@ public class FileRecordReaderTest extends BaseND4JTest { out2.add(r.getRecord()); out3.add(r); meta.add(r.getMetaData()); - assertEquals(arr[count++], r.getMetaData().getURI()); } - assertEquals(out, out2); List fromMeta = rr.loadFromMetaData(meta); assertEquals(out3, fromMeta); } - } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java index 9d5b76688..4095d1af7 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.datavec.api.records.reader.RecordReader; @@ -29,96 +28,80 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; - import java.io.File; import java.net.URI; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; +@DisplayName("Jackson Line Record Reader Test") +class JacksonLineRecordReaderTest extends BaseND4JTest { -public class JacksonLineRecordReaderTest extends BaseND4JTest { + @TempDir + public Path testDir; - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - public JacksonLineRecordReaderTest() { - } + public JacksonLineRecordReaderTest() { + } private static FieldSelection getFieldSelection() { - return new FieldSelection.Builder().addField("value1"). - addField("value2"). - addField("value3"). - addField("value4"). - addField("value5"). - addField("value6"). - addField("value7"). - addField("value8"). - addField("value9"). 
- addField("value10").build(); + return new FieldSelection.Builder().addField("value1").addField("value2").addField("value3").addField("value4").addField("value5").addField("value6").addField("value7").addField("value8").addField("value9").addField("value10").build(); } - + @Test - public void testReadJSON() throws Exception { - + @DisplayName("Test Read JSON") + void testReadJSON() throws Exception { RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory())); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile())); - testJacksonRecordReader(rr); - } - - private static void testJacksonRecordReader(RecordReader rr) { - while (rr.hasNext()) { - List json0 = rr.next(); - //System.out.println(json0); - assert(json0.size() > 0); - } } + private static void testJacksonRecordReader(RecordReader rr) { + while (rr.hasNext()) { + List json0 = rr.next(); + // System.out.println(json0); + assert (json0.size() > 0); + } + } @Test - public void testJacksonLineSequenceRecordReader() throws Exception { - File dir = testDir.newFolder(); - new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir); - - FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b") - .addField(new Text("MISSING_CX"), "c", "x").build(); - - JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory())); - File[] files = dir.listFiles(); - Arrays.sort(files); - URI[] u = new URI[files.length]; - for( int i=0; i> expSeq0 = new ArrayList<>(); - expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"))); - expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"))); - expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"))); - - List> expSeq1 = new ArrayList<>(); - expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3"))); - - - int count = 0; - while(rr.hasNext()){ - List> next = rr.sequenceRecord(); - if(count++ == 0){ - assertEquals(expSeq0, next); - } else { - assertEquals(expSeq1, next); - } - } - - assertEquals(2, count); - } + @DisplayName("Test Jackson Line Sequence Record Reader") + void testJacksonLineSequenceRecordReader(@TempDir Path testDir) throws Exception { + File dir = testDir.toFile(); + new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir); + FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build(); + JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory())); + File[] files = dir.listFiles(); + Arrays.sort(files); + URI[] u = new URI[files.length]; + for (int i = 0; i < files.length; i++) { + u[i] = files[i].toURI(); + } + rr.initialize(new CollectionInputSplit(u)); + List> expSeq0 = new ArrayList<>(); + expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"))); + expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"))); + expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"))); + List> expSeq1 = new ArrayList<>(); + expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new 
Text("cxValue3"))); + int count = 0; + while (rr.hasNext()) { + List> next = rr.sequenceRecord(); + if (count++ == 0) { + assertEquals(expSeq0, next); + } else { + assertEquals(expSeq1, next); + } + } + assertEquals(2, count); + } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java index 5b91c4523..2e4a2261b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.datavec.api.io.labels.PathLabelGenerator; @@ -32,113 +31,94 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.dataformat.xml.XmlFactory; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; - import java.io.File; import java.net.URI; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +@DisplayName("Jackson Record Reader Test") +class JacksonRecordReaderTest extends BaseND4JTest { -public class JacksonRecordReaderTest extends BaseND4JTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Test - public void testReadingJson() throws Exception { - //Load 3 values from 3 JSON files - //stricture: a:value, b:value, c:x:value, c:y:value - //And we want to load only a:value, b:value and c:x:value - //For first JSON file: all values are present - //For second JSON file: b:value is missing - //For third JSON file: c:x:value is missing - + @DisplayName("Test Reading Json") + void testReadingJson(@TempDir Path testDir) throws Exception { + // Load 3 values from 3 JSON files + // stricture: a:value, b:value, c:x:value, c:y:value + // And we want to load only a:value, b:value and c:x:value + // For first JSON file: all values are present + // For second JSON file: b:value is missing + // For third JSON file: c:x:value is missing ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); - File f = testDir.newFolder(); + File f = testDir.toFile(); cpr.copyDirectory(f); String path = new File(f, "json_test_%d.txt").getAbsolutePath(); - InputSplit is = new NumberedFileInputSplit(path, 0, 2); - RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory())); rr.initialize(is); - testJacksonRecordReader(rr); } @Test - public void testReadingYaml() throws Exception { - //Exact same information as JSON format, but in YAML 
format - + @DisplayName("Test Reading Yaml") + void testReadingYaml(@TempDir Path testDir) throws Exception { + // Exact same information as JSON format, but in YAML format ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/"); - File f = testDir.newFolder(); + File f = testDir.toFile(); cpr.copyDirectory(f); String path = new File(f, "yaml_test_%d.txt").getAbsolutePath(); - - InputSplit is = new NumberedFileInputSplit(path, 0, 2); - RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory())); rr.initialize(is); - testJacksonRecordReader(rr); } @Test - public void testReadingXml() throws Exception { - //Exact same information as JSON format, but in XML format - + @DisplayName("Test Reading Xml") + void testReadingXml(@TempDir Path testDir) throws Exception { + // Exact same information as JSON format, but in XML format ClassPathResource cpr = new ClassPathResource("datavec-api/xml/"); - File f = testDir.newFolder(); + File f = testDir.toFile(); cpr.copyDirectory(f); String path = new File(f, "xml_test_%d.txt").getAbsolutePath(); - InputSplit is = new NumberedFileInputSplit(path, 0, 2); - RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory())); rr.initialize(is); - testJacksonRecordReader(rr); } - private static FieldSelection getFieldSelection() { - return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b") - .addField(new Text("MISSING_CX"), "c", "x").build(); + return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build(); } - - private static void testJacksonRecordReader(RecordReader rr) { - List json0 = rr.next(); List exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")); assertEquals(exp0, json0); - List json1 = rr.next(); - List exp1 = - Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")); + List exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")); assertEquals(exp1, json1); - List json2 = rr.next(); - List exp2 = - Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")); + List exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")); assertEquals(exp2, json2); - assertFalse(rr.hasNext()); - - //Test reset + // Test reset rr.reset(); assertEquals(exp0, rr.next()); assertEquals(exp1, rr.next()); @@ -147,72 +127,50 @@ public class JacksonRecordReaderTest extends BaseND4JTest { } @Test - public void testAppendingLabels() throws Exception { - + @DisplayName("Test Appending Labels") + void testAppendingLabels(@TempDir Path testDir) throws Exception { ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); - File f = testDir.newFolder(); + File f = testDir.toFile(); cpr.copyDirectory(f); String path = new File(f, "json_test_%d.txt").getAbsolutePath(); - InputSplit is = new NumberedFileInputSplit(path, 0, 2); - - //Insert at the end: - RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, - new LabelGen()); + // Insert at the end: + RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen()); rr.initialize(is); - - List exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), - new IntWritable(0)); + List exp0 = 
Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0)); assertEquals(exp0, rr.next()); - - List exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), - new IntWritable(1)); + List exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1)); assertEquals(exp1, rr.next()); - - List exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), - new IntWritable(2)); + List exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2)); assertEquals(exp2, rr.next()); - - //Insert at position 0: - rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, - new LabelGen(), 0); + // Insert at position 0: + rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0); rr.initialize(is); - - exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), - new Text("cxValue0")); + exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")); assertEquals(exp0, rr.next()); - - exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), - new Text("cxValue1")); + exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")); assertEquals(exp1, rr.next()); - - exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), - new Text("MISSING_CX")); + exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")); assertEquals(exp2, rr.next()); } @Test - public void testAppendingLabelsMetaData() throws Exception { + @DisplayName("Test Appending Labels Meta Data") + void testAppendingLabelsMetaData(@TempDir Path testDir) throws Exception { ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); - File f = testDir.newFolder(); + File f = testDir.toFile(); cpr.copyDirectory(f); String path = new File(f, "json_test_%d.txt").getAbsolutePath(); - InputSplit is = new NumberedFileInputSplit(path, 0, 2); - - //Insert at the end: - RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, - new LabelGen()); + // Insert at the end: + RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen()); rr.initialize(is); - List> out = new ArrayList<>(); while (rr.hasNext()) { out.add(rr.next()); } assertEquals(3, out.size()); - rr.reset(); - List> out2 = new ArrayList<>(); List outRecord = new ArrayList<>(); List meta = new ArrayList<>(); @@ -222,14 +180,12 @@ public class JacksonRecordReaderTest extends BaseND4JTest { outRecord.add(r); meta.add(r.getMetaData()); } - assertEquals(out, out2); - List fromMeta = rr.loadFromMetaData(meta); assertEquals(outRecord, fromMeta); } - + @DisplayName("Label Gen") private static class LabelGen implements PathLabelGenerator { @Override @@ -252,5 +208,4 @@ public class JacksonRecordReaderTest extends BaseND4JTest { return true; } } - } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java index 
e7fe410c8..9d3ae9663 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.datavec.api.conf.Configuration; @@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; - import java.io.IOException; import java.util.*; - import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; -public class LibSvmRecordReaderTest extends BaseND4JTest { +@DisplayName("Lib Svm Record Reader Test") +class LibSvmRecordReaderTest extends BaseND4JTest { @Test - public void testBasicRecord() throws IOException, InterruptedException { + @DisplayName("Test Basic Record") + void testBasicRecord() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 7 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5), - new IntWritable(7))); + correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7))); // 2 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO, - new IntWritable(2))); + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2))); // 33 - correct.put(2, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - new IntWritable(33))); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33))); LibSvmRecordReader rr = new LibSvmRecordReader(); Configuration config = new Configuration(); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); @@ -80,27 +66,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest { } @Test - public void testNoAppendLabel() throws IOException, InterruptedException { + @DisplayName("Test No Append Label") + void testNoAppendLabel() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 7 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5))); + correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5))); // 2 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new 
DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO)); + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO)); // 33 - correct.put(2, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO)); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO)); SVMLightRecordReader rr = new SVMLightRecordReader(); Configuration config = new Configuration(); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); @@ -117,33 +91,17 @@ public class LibSvmRecordReaderTest extends BaseND4JTest { } @Test - public void testNoLabel() throws IOException, InterruptedException { + @DisplayName("Test No Label") + void testNoLabel() throws IOException, InterruptedException { Map> correct = new HashMap<>(); - // 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5))); - // qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO)); - // 1:1.0 - correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO)); - // - correct.put(3, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO)); - + // 2:1 4:2 6:3 8:4 10:5 + correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5))); + // qid:42 1:0.1 2:2 6:6.6 8:80 + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO)); + // 1:1.0 + correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO)); + // + correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO)); SVMLightRecordReader rr = new SVMLightRecordReader(); Configuration config = new Configuration(); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); @@ -160,33 +118,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest { } @Test - public void testMultioutputRecord() throws IOException, InterruptedException { + @DisplayName("Test Multioutput Record") + void testMultioutputRecord() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 7 2.45,9 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5), - new IntWritable(7), new DoubleWritable(2.45), - new IntWritable(9))); + correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9))); // 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO, - new IntWritable(2), new IntWritable(3), - new IntWritable(4))); + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new 
DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4))); // 33,32.0,31.9 - correct.put(2, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - new IntWritable(33), new DoubleWritable(32.0), - new DoubleWritable(31.9))); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9))); LibSvmRecordReader rr = new LibSvmRecordReader(); Configuration config = new Configuration(); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); @@ -202,51 +142,20 @@ public class LibSvmRecordReaderTest extends BaseND4JTest { assertEquals(i, correct.size()); } - @Test - public void testMultilabelRecord() throws IOException, InterruptedException { + @DisplayName("Test Multilabel Record") + void testMultilabelRecord() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 1,3 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5), - LABEL_ONE, LABEL_ZERO, - LABEL_ONE, LABEL_ZERO)); + correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO)); // 2 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO, - LABEL_ZERO, LABEL_ONE, - LABEL_ZERO, LABEL_ZERO)); + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO)); // 1,2,4 - correct.put(2, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ONE, LABEL_ONE, - LABEL_ZERO, LABEL_ONE)); - // 1:1.0 - correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO)); - // - correct.put(4, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO)); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE)); + // 1:1.0 + correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO)); + // + correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO)); LibSvmRecordReader rr = new LibSvmRecordReader(); Configuration config = new Configuration(); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); @@ -265,63 +174,24 @@ public class LibSvmRecordReaderTest extends BaseND4JTest { } @Test - public void testZeroBasedIndexing() throws IOException, InterruptedException { + @DisplayName("Test Zero Based Indexing") + void testZeroBasedIndexing() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 1,3 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, - ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new 
DoubleWritable(5), - LABEL_ZERO, - LABEL_ONE, LABEL_ZERO, - LABEL_ONE, LABEL_ZERO)); + correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO)); // 2 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(ZERO, - new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO, - LABEL_ZERO, - LABEL_ZERO, LABEL_ONE, - LABEL_ZERO, LABEL_ZERO)); + correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO)); // 1,2,4 - correct.put(2, Arrays.asList(ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, - LABEL_ONE, LABEL_ONE, - LABEL_ZERO, LABEL_ONE)); - // 1:1.0 - correct.put(3, Arrays.asList(ZERO, - new DoubleWritable(1.0), ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO)); - // - correct.put(4, Arrays.asList(ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO)); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE)); + // 1:1.0 + correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO)); + // + correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO)); LibSvmRecordReader rr = new LibSvmRecordReader(); Configuration config = new Configuration(); // Zero-based indexing is default - config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! + // NOT STANDARD! 
+ config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); config.setInt(LibSvmRecordReader.NUM_FEATURES, 11); config.setBoolean(LibSvmRecordReader.MULTILABEL, true); @@ -336,87 +206,107 @@ public class LibSvmRecordReaderTest extends BaseND4JTest { assertEquals(i, correct.size()); } - @Test(expected = NoSuchElementException.class) - public void testNoSuchElementException() throws Exception { - LibSvmRecordReader rr = new LibSvmRecordReader(); - Configuration config = new Configuration(); - config.setInt(LibSvmRecordReader.NUM_FEATURES, 11); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); - while (rr.hasNext()) + @Test + @DisplayName("Test No Such Element Exception") + void testNoSuchElementException() { + assertThrows(NoSuchElementException.class, () -> { + LibSvmRecordReader rr = new LibSvmRecordReader(); + Configuration config = new Configuration(); + config.setInt(LibSvmRecordReader.NUM_FEATURES, 11); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); + while (rr.hasNext()) rr.next(); rr.next(); - rr.next(); + }); } - @Test(expected = UnsupportedOperationException.class) - public void failedToSetNumFeaturesException() throws Exception { - LibSvmRecordReader rr = new LibSvmRecordReader(); - Configuration config = new Configuration(); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); - while (rr.hasNext()) + @Test + @DisplayName("Failed To Set Num Features Exception") + void failedToSetNumFeaturesException() { + assertThrows(UnsupportedOperationException.class, () -> { + LibSvmRecordReader rr = new LibSvmRecordReader(); + Configuration config = new Configuration(); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); + while (rr.hasNext()) rr.next(); + }); + } + + @Test + @DisplayName("Test Inconsistent Num Labels Exception") + void testInconsistentNumLabelsException() { + assertThrows(UnsupportedOperationException.class, () -> { + LibSvmRecordReader rr = new LibSvmRecordReader(); + Configuration config = new Configuration(); + config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile())); + while (rr.hasNext()) rr.next(); + }); + } + + @Test + @DisplayName("Test Inconsistent Num Multiabels Exception") + void testInconsistentNumMultiabelsException() { + assertThrows(UnsupportedOperationException.class, () -> { + LibSvmRecordReader rr = new LibSvmRecordReader(); + Configuration config = new Configuration(); + config.setBoolean(LibSvmRecordReader.MULTILABEL, false); + config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile())); + while (rr.hasNext()) rr.next(); + }); + } + + @Test + @DisplayName("Test Feature Index Exceeds Num Features") + void testFeatureIndexExceedsNumFeatures() { + assertThrows(IndexOutOfBoundsException.class, () -> { + LibSvmRecordReader rr = new LibSvmRecordReader(); + Configuration config = new Configuration(); + config.setInt(LibSvmRecordReader.NUM_FEATURES, 9); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.next(); + }); } - @Test(expected = 
UnsupportedOperationException.class) - public void testInconsistentNumLabelsException() throws Exception { - LibSvmRecordReader rr = new LibSvmRecordReader(); - Configuration config = new Configuration(); - config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile())); - while (rr.hasNext()) + @Test + @DisplayName("Test Label Index Exceeds Num Labels") + void testLabelIndexExceedsNumLabels() { + assertThrows(IndexOutOfBoundsException.class, () -> { + LibSvmRecordReader rr = new LibSvmRecordReader(); + Configuration config = new Configuration(); + config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); + config.setInt(LibSvmRecordReader.NUM_FEATURES, 10); + config.setInt(LibSvmRecordReader.NUM_LABELS, 6); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.next(); + }); } - @Test(expected = UnsupportedOperationException.class) - public void testInconsistentNumMultiabelsException() throws Exception { - LibSvmRecordReader rr = new LibSvmRecordReader(); - Configuration config = new Configuration(); - config.setBoolean(LibSvmRecordReader.MULTILABEL, false); - config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile())); - while (rr.hasNext()) + @Test + @DisplayName("Test Zero Index Feature Without Using Zero Indexing") + void testZeroIndexFeatureWithoutUsingZeroIndexing() { + assertThrows(IndexOutOfBoundsException.class, () -> { + LibSvmRecordReader rr = new LibSvmRecordReader(); + Configuration config = new Configuration(); + config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); + config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); + config.setInt(LibSvmRecordReader.NUM_FEATURES, 10); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile())); rr.next(); + }); } - @Test(expected = IndexOutOfBoundsException.class) - public void testFeatureIndexExceedsNumFeatures() throws Exception { - LibSvmRecordReader rr = new LibSvmRecordReader(); - Configuration config = new Configuration(); - config.setInt(LibSvmRecordReader.NUM_FEATURES, 9); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); - rr.next(); - } - - @Test(expected = IndexOutOfBoundsException.class) - public void testLabelIndexExceedsNumLabels() throws Exception { - LibSvmRecordReader rr = new LibSvmRecordReader(); - Configuration config = new Configuration(); - config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); - config.setInt(LibSvmRecordReader.NUM_FEATURES, 10); - config.setInt(LibSvmRecordReader.NUM_LABELS, 6); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); - rr.next(); - } - - @Test(expected = IndexOutOfBoundsException.class) - public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception { - LibSvmRecordReader rr = new LibSvmRecordReader(); - Configuration config = new Configuration(); - config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); - config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); - config.setInt(LibSvmRecordReader.NUM_FEATURES, 10); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile())); - rr.next(); - } - - @Test(expected = 
IndexOutOfBoundsException.class) - public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception { - LibSvmRecordReader rr = new LibSvmRecordReader(); - Configuration config = new Configuration(); - config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); - config.setInt(LibSvmRecordReader.NUM_FEATURES, 10); - config.setBoolean(LibSvmRecordReader.MULTILABEL, true); - config.setInt(LibSvmRecordReader.NUM_LABELS, 2); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile())); - rr.next(); + @Test + @DisplayName("Test Zero Index Label Without Using Zero Indexing") + void testZeroIndexLabelWithoutUsingZeroIndexing() { + assertThrows(IndexOutOfBoundsException.class, () -> { + LibSvmRecordReader rr = new LibSvmRecordReader(); + Configuration config = new Configuration(); + config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); + config.setInt(LibSvmRecordReader.NUM_FEATURES, 10); + config.setBoolean(LibSvmRecordReader.MULTILABEL, true); + config.setInt(LibSvmRecordReader.NUM_LABELS, 2); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile())); + rr.next(); + }); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java index 18dc8b0fd..dd81758d0 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.apache.commons.io.FileUtils; @@ -31,10 +30,9 @@ import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.writable.Writable; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; - import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; @@ -45,34 +43,31 @@ import java.util.Arrays; import java.util.List; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; +@DisplayName("Line Reader Test") +class LineReaderTest extends BaseND4JTest { -public class LineReaderTest extends BaseND4JTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); @Test - public void testLineReader() throws Exception { - File tmpdir = testDir.newFolder(); + @DisplayName("Test Line Reader") + void testLineReader(@TempDir Path tmpDir) throws Exception { + File tmpdir = tmpDir.toFile(); if (tmpdir.exists()) tmpdir.delete(); tmpdir.mkdir(); - File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt")); File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt")); File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt")); - FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3")); FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6")); FileUtils.writeLines(tmp3, Arrays.asList("7", "8", 
"9")); - InputSplit split = new FileSplit(tmpdir); - RecordReader reader = new LineRecordReader(); reader.initialize(split); - int count = 0; List> list = new ArrayList<>(); while (reader.hasNext()) { @@ -81,34 +76,27 @@ public class LineReaderTest extends BaseND4JTest { list.add(l); count++; } - assertEquals(9, count); } @Test - public void testLineReaderMetaData() throws Exception { - File tmpdir = testDir.newFolder(); - + @DisplayName("Test Line Reader Meta Data") + void testLineReaderMetaData(@TempDir Path tmpDir) throws Exception { + File tmpdir = tmpDir.toFile(); File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt")); File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt")); File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt")); - FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3")); FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6")); FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9")); - InputSplit split = new FileSplit(tmpdir); - RecordReader reader = new LineRecordReader(); reader.initialize(split); - List> list = new ArrayList<>(); while (reader.hasNext()) { list.add(reader.next()); } assertEquals(9, list.size()); - - List> out2 = new ArrayList<>(); List out3 = new ArrayList<>(); List meta = new ArrayList<>(); @@ -124,13 +112,10 @@ public class LineReaderTest extends BaseND4JTest { assertEquals(uri, split.locations()[fileIdx]); count++; } - assertEquals(list, out2); - List fromMeta = reader.loadFromMetaData(meta); assertEquals(out3, fromMeta); - - //try: second line of second and third files only... + // try: second line of second and third files only... List subsetMeta = new ArrayList<>(); subsetMeta.add(meta.get(4)); subsetMeta.add(meta.get(7)); @@ -141,27 +126,22 @@ public class LineReaderTest extends BaseND4JTest { } @Test - public void testLineReaderWithInputStreamInputSplit() throws Exception { - File tmpdir = testDir.newFolder(); - + @DisplayName("Test Line Reader With Input Stream Input Split") + void testLineReaderWithInputStreamInputSplit(@TempDir Path testDir) throws Exception { + File tmpdir = testDir.toFile(); File tmp1 = new File(tmpdir, "tmp1.txt.gz"); - OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false)); IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os); os.flush(); os.close(); - InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1))); - RecordReader reader = new LineRecordReader(); reader.initialize(split); - int count = 0; while (reader.hasNext()) { assertEquals(1, reader.next().size()); count++; } - assertEquals(9, count); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java index 97e1a854a..997a6de10 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.datavec.api.records.Record; @@ -34,43 +33,40 @@ import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Rule; -import org.junit.Test; -import 
org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; - import java.io.File; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +@DisplayName("Regex Record Reader Test") +class RegexRecordReaderTest extends BaseND4JTest { -public class RegexRecordReaderTest extends BaseND4JTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Test - public void testRegexLineRecordReader() throws Exception { + @DisplayName("Test Regex Line Record Reader") + void testRegexLineRecordReader() throws Exception { String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; - RecordReader rr = new RegexLineRecordReader(regex, 1); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile())); - - List exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), - new Text("DEBUG"), new Text("First entry message!")); - List exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), - new Text("INFO"), new Text("Second entry message!")); - List exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), - new Text("WARN"), new Text("Third entry message!")); + List exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!")); + List exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!")); + List exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!")); assertEquals(exp0, rr.next()); assertEquals(exp1, rr.next()); assertEquals(exp2, rr.next()); assertFalse(rr.hasNext()); - - //Test reset: + // Test reset: rr.reset(); assertEquals(exp0, rr.next()); assertEquals(exp1, rr.next()); @@ -79,74 +75,57 @@ public class RegexRecordReaderTest extends BaseND4JTest { } @Test - public void testRegexLineRecordReaderMeta() throws Exception { + @DisplayName("Test Regex Line Record Reader Meta") + void testRegexLineRecordReaderMeta() throws Exception { String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; - RecordReader rr = new RegexLineRecordReader(regex, 1); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile())); - List> list = new ArrayList<>(); while (rr.hasNext()) { list.add(rr.next()); } assertEquals(3, list.size()); - List list2 = new ArrayList<>(); List> list3 = new ArrayList<>(); List meta = new ArrayList<>(); rr.reset(); - int count = 1; //Start by skipping 1 line + // Start by skipping 1 line + int count = 1; while (rr.hasNext()) { Record r = rr.nextRecord(); list2.add(r); list3.add(r.getRecord()); meta.add(r.getMetaData()); - assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber()); } - List fromMeta = rr.loadFromMetaData(meta); - assertEquals(list, list3); assertEquals(list2, fromMeta); 
} @Test - public void testRegexSequenceRecordReader() throws Exception { + @DisplayName("Test Regex Sequence Record Reader") + void testRegexSequenceRecordReader(@TempDir Path testDir) throws Exception { String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; - ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/"); - File f = testDir.newFolder(); + File f = testDir.toFile(); cpr.copyDirectory(f); String path = new File(f, "logtestfile%d.txt").getAbsolutePath(); - InputSplit is = new NumberedFileInputSplit(path, 0, 1); - SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1); rr.initialize(is); - List> exp0 = new ArrayList<>(); - exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), - new Text("First entry message!"))); - exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), - new Text("Second entry message!"))); - exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), - new Text("Third entry message!"))); - - + exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"))); + exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"))); + exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"))); List> exp1 = new ArrayList<>(); - exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), - new Text("First entry message!"))); - exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), - new Text("Second entry message!"))); - exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), - new Text("Third entry message!"))); - + exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!"))); + exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!"))); + exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!"))); assertEquals(exp0, rr.sequenceRecord()); assertEquals(exp1, rr.sequenceRecord()); assertFalse(rr.hasNext()); - - //Test resetting: + // Test resetting: rr.reset(); assertEquals(exp0, rr.sequenceRecord()); assertEquals(exp1, rr.sequenceRecord()); @@ -154,24 +133,20 @@ public class RegexRecordReaderTest extends BaseND4JTest { } @Test - public void testRegexSequenceRecordReaderMeta() throws Exception { + @DisplayName("Test Regex Sequence Record Reader Meta") + void testRegexSequenceRecordReaderMeta(@TempDir Path testDir) throws Exception { String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)"; - ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/"); - File f = testDir.newFolder(); + File f = testDir.toFile(); cpr.copyDirectory(f); String path = new File(f, "logtestfile%d.txt").getAbsolutePath(); - InputSplit is = new NumberedFileInputSplit(path, 0, 1); - SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1); rr.initialize(is); - List>> out = new ArrayList<>(); while (rr.hasNext()) { 
out.add(rr.sequenceRecord()); } - assertEquals(2, out.size()); List>> out2 = new ArrayList<>(); List out3 = new ArrayList<>(); @@ -183,11 +158,8 @@ public class RegexRecordReaderTest extends BaseND4JTest { out3.add(seqr); meta.add(seqr.getMetaData()); } - List fromMeta = rr.loadSequenceFromMetaData(meta); - assertEquals(out, out2); assertEquals(out3, fromMeta); } - } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java index 35b2d6a46..c072cea97 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; import org.datavec.api.conf.Configuration; @@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; - import java.io.IOException; import java.util.*; - import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; -public class SVMLightRecordReaderTest extends BaseND4JTest { +@DisplayName("Svm Light Record Reader Test") +class SVMLightRecordReaderTest extends BaseND4JTest { @Test - public void testBasicRecord() throws IOException, InterruptedException { + @DisplayName("Test Basic Record") + void testBasicRecord() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 7 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5), - new IntWritable(7))); + correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7))); // 2 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO, - new IntWritable(2))); + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2))); // 33 - correct.put(2, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - new IntWritable(33))); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33))); SVMLightRecordReader rr = new SVMLightRecordReader(); Configuration config = new Configuration(); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); @@ -79,27 +65,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest { } @Test - public void testNoAppendLabel() throws IOException, 
InterruptedException { + @DisplayName("Test No Append Label") + void testNoAppendLabel() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 7 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5))); + correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5))); // 2 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO)); + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO)); // 33 - correct.put(2, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO)); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO)); SVMLightRecordReader rr = new SVMLightRecordReader(); Configuration config = new Configuration(); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); @@ -116,33 +90,17 @@ public class SVMLightRecordReaderTest extends BaseND4JTest { } @Test - public void testNoLabel() throws IOException, InterruptedException { + @DisplayName("Test No Label") + void testNoLabel() throws IOException, InterruptedException { Map> correct = new HashMap<>(); - // 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5))); - // qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO)); - // 1:1.0 - correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO)); - // - correct.put(3, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO)); - + // 2:1 4:2 6:3 8:4 10:5 + correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5))); + // qid:42 1:0.1 2:2 6:6.6 8:80 + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO)); + // 1:1.0 + correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO)); + // + correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO)); SVMLightRecordReader rr = new SVMLightRecordReader(); Configuration config = new Configuration(); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); @@ -159,33 +117,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest { } @Test - public void testMultioutputRecord() throws IOException, InterruptedException { + @DisplayName("Test Multioutput Record") + void testMultioutputRecord() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 7 2.45,9 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5), - new IntWritable(7), new DoubleWritable(2.45), - new IntWritable(9))); + 
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9))); // 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO, - new IntWritable(2), new IntWritable(3), - new IntWritable(4))); + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4))); // 33,32.0,31.9 - correct.put(2, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - new IntWritable(33), new DoubleWritable(32.0), - new DoubleWritable(31.9))); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9))); SVMLightRecordReader rr = new SVMLightRecordReader(); Configuration config = new Configuration(); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); @@ -200,51 +140,20 @@ public class SVMLightRecordReaderTest extends BaseND4JTest { assertEquals(i, correct.size()); } - @Test - public void testMultilabelRecord() throws IOException, InterruptedException { + @DisplayName("Test Multilabel Record") + void testMultilabelRecord() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 1,3 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5), - LABEL_ONE, LABEL_ZERO, - LABEL_ONE, LABEL_ZERO)); + correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO)); // 2 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO, - LABEL_ZERO, LABEL_ONE, - LABEL_ZERO, LABEL_ZERO)); + correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO)); // 1,2,4 - correct.put(2, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ONE, LABEL_ONE, - LABEL_ZERO, LABEL_ONE)); - // 1:1.0 - correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO)); - // - correct.put(4, Arrays.asList(ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO)); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE)); + // 1:1.0 + correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO)); + // + correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO)); SVMLightRecordReader rr = new SVMLightRecordReader(); Configuration 
config = new Configuration(); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); @@ -262,63 +171,24 @@ public class SVMLightRecordReaderTest extends BaseND4JTest { } @Test - public void testZeroBasedIndexing() throws IOException, InterruptedException { + @DisplayName("Test Zero Based Indexing") + void testZeroBasedIndexing() throws IOException, InterruptedException { Map> correct = new HashMap<>(); // 1,3 2:1 4:2 6:3 8:4 10:5 - correct.put(0, Arrays.asList(ZERO, - ZERO, ONE, - ZERO, new DoubleWritable(2), - ZERO, new DoubleWritable(3), - ZERO, new DoubleWritable(4), - ZERO, new DoubleWritable(5), - LABEL_ZERO, - LABEL_ONE, LABEL_ZERO, - LABEL_ONE, LABEL_ZERO)); + correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO)); // 2 qid:42 1:0.1 2:2 6:6.6 8:80 - correct.put(1, Arrays.asList(ZERO, - new DoubleWritable(0.1), new DoubleWritable(2), - ZERO, ZERO, - ZERO, new DoubleWritable(6.6), - ZERO, new DoubleWritable(80), - ZERO, ZERO, - LABEL_ZERO, - LABEL_ZERO, LABEL_ONE, - LABEL_ZERO, LABEL_ZERO)); + correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO)); // 1,2,4 - correct.put(2, Arrays.asList(ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, - LABEL_ONE, LABEL_ONE, - LABEL_ZERO, LABEL_ONE)); - // 1:1.0 - correct.put(3, Arrays.asList(ZERO, - new DoubleWritable(1.0), ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO)); - // - correct.put(4, Arrays.asList(ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - ZERO, ZERO, - LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO, - LABEL_ZERO, LABEL_ZERO)); - + correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE)); + // 1:1.0 + correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO)); + // + correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO)); SVMLightRecordReader rr = new SVMLightRecordReader(); Configuration config = new Configuration(); // Zero-based indexing is default - config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! + // NOT STANDARD! 
+ config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); config.setInt(SVMLightRecordReader.NUM_FEATURES, 11); config.setBoolean(SVMLightRecordReader.MULTILABEL, true); config.setInt(SVMLightRecordReader.NUM_LABELS, 5); @@ -333,20 +203,19 @@ public class SVMLightRecordReaderTest extends BaseND4JTest { } @Test - public void testNextRecord() throws IOException, InterruptedException { + @DisplayName("Test Next Record") + void testNextRecord() throws IOException, InterruptedException { SVMLightRecordReader rr = new SVMLightRecordReader(); Configuration config = new Configuration(); config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false); rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); - Record record = rr.nextRecord(); List recordList = record.getRecord(); assertEquals(new DoubleWritable(1.0), recordList.get(1)); assertEquals(new DoubleWritable(3.0), recordList.get(5)); assertEquals(new DoubleWritable(4.0), recordList.get(7)); - record = rr.nextRecord(); recordList = record.getRecord(); assertEquals(new DoubleWritable(0.1), recordList.get(0)); @@ -354,82 +223,102 @@ public class SVMLightRecordReaderTest extends BaseND4JTest { assertEquals(new DoubleWritable(80.0), recordList.get(7)); } - @Test(expected = NoSuchElementException.class) - public void testNoSuchElementException() throws Exception { - SVMLightRecordReader rr = new SVMLightRecordReader(); - Configuration config = new Configuration(); - config.setInt(SVMLightRecordReader.NUM_FEATURES, 11); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); - while (rr.hasNext()) + @Test + @DisplayName("Test No Such Element Exception") + void testNoSuchElementException() { + assertThrows(NoSuchElementException.class, () -> { + SVMLightRecordReader rr = new SVMLightRecordReader(); + Configuration config = new Configuration(); + config.setInt(SVMLightRecordReader.NUM_FEATURES, 11); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); + while (rr.hasNext()) rr.next(); rr.next(); - rr.next(); + }); } - @Test(expected = UnsupportedOperationException.class) - public void failedToSetNumFeaturesException() throws Exception { - SVMLightRecordReader rr = new SVMLightRecordReader(); - Configuration config = new Configuration(); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); - while (rr.hasNext()) + @Test + @DisplayName("Failed To Set Num Features Exception") + void failedToSetNumFeaturesException() { + assertThrows(UnsupportedOperationException.class, () -> { + SVMLightRecordReader rr = new SVMLightRecordReader(); + Configuration config = new Configuration(); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); + while (rr.hasNext()) rr.next(); + }); + } + + @Test + @DisplayName("Test Inconsistent Num Labels Exception") + void testInconsistentNumLabelsException() { + assertThrows(UnsupportedOperationException.class, () -> { + SVMLightRecordReader rr = new SVMLightRecordReader(); + Configuration config = new Configuration(); + config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile())); + while (rr.hasNext()) 
rr.next(); + }); + } + + @Test + @DisplayName("Failed To Set Num Multiabels Exception") + void failedToSetNumMultiabelsException() { + assertThrows(UnsupportedOperationException.class, () -> { + SVMLightRecordReader rr = new SVMLightRecordReader(); + Configuration config = new Configuration(); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile())); + while (rr.hasNext()) rr.next(); + }); + } + + @Test + @DisplayName("Test Feature Index Exceeds Num Features") + void testFeatureIndexExceedsNumFeatures() { + assertThrows(IndexOutOfBoundsException.class, () -> { + SVMLightRecordReader rr = new SVMLightRecordReader(); + Configuration config = new Configuration(); + config.setInt(SVMLightRecordReader.NUM_FEATURES, 9); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.next(); + }); } - @Test(expected = UnsupportedOperationException.class) - public void testInconsistentNumLabelsException() throws Exception { - SVMLightRecordReader rr = new SVMLightRecordReader(); - Configuration config = new Configuration(); - config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile())); - while (rr.hasNext()) + @Test + @DisplayName("Test Label Index Exceeds Num Labels") + void testLabelIndexExceedsNumLabels() { + assertThrows(IndexOutOfBoundsException.class, () -> { + SVMLightRecordReader rr = new SVMLightRecordReader(); + Configuration config = new Configuration(); + config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); + config.setInt(SVMLightRecordReader.NUM_LABELS, 6); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); rr.next(); + }); } - @Test(expected = UnsupportedOperationException.class) - public void failedToSetNumMultiabelsException() throws Exception { - SVMLightRecordReader rr = new SVMLightRecordReader(); - Configuration config = new Configuration(); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile())); - while (rr.hasNext()) + @Test + @DisplayName("Test Zero Index Feature Without Using Zero Indexing") + void testZeroIndexFeatureWithoutUsingZeroIndexing() { + assertThrows(IndexOutOfBoundsException.class, () -> { + SVMLightRecordReader rr = new SVMLightRecordReader(); + Configuration config = new Configuration(); + config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); + config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile())); rr.next(); + }); } - @Test(expected = IndexOutOfBoundsException.class) - public void testFeatureIndexExceedsNumFeatures() throws Exception { - SVMLightRecordReader rr = new SVMLightRecordReader(); - Configuration config = new Configuration(); - config.setInt(SVMLightRecordReader.NUM_FEATURES, 9); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); - rr.next(); - } - - @Test(expected = IndexOutOfBoundsException.class) - public void testLabelIndexExceedsNumLabels() throws Exception { - SVMLightRecordReader rr = new SVMLightRecordReader(); - Configuration config = new Configuration(); - config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); - config.setInt(SVMLightRecordReader.NUM_LABELS, 6); - rr.initialize(config, new FileSplit(new 
ClassPathResource("datavec-api/svmlight/basic.txt").getFile())); - rr.next(); - } - - @Test(expected = IndexOutOfBoundsException.class) - public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception { - SVMLightRecordReader rr = new SVMLightRecordReader(); - Configuration config = new Configuration(); - config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); - config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile())); - rr.next(); - } - - @Test(expected = IndexOutOfBoundsException.class) - public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception { - SVMLightRecordReader rr = new SVMLightRecordReader(); - Configuration config = new Configuration(); - config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); - config.setBoolean(SVMLightRecordReader.MULTILABEL, true); - config.setInt(SVMLightRecordReader.NUM_LABELS, 2); - rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile())); - rr.next(); + @Test + @DisplayName("Test Zero Index Label Without Using Zero Indexing") + void testZeroIndexLabelWithoutUsingZeroIndexing() { + assertThrows(IndexOutOfBoundsException.class, () -> { + SVMLightRecordReader rr = new SVMLightRecordReader(); + Configuration config = new Configuration(); + config.setInt(SVMLightRecordReader.NUM_FEATURES, 10); + config.setBoolean(SVMLightRecordReader.MULTILABEL, true); + config.setInt(SVMLightRecordReader.NUM_LABELS, 2); + rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile())); + rr.next(); + }); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java index 5890722b3..c63240896 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.writer.impl; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; @@ -26,44 +25,42 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; - import java.io.File; import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class CSVRecordWriterTest extends BaseND4JTest { - - @Before - public void setUp() throws Exception { +@DisplayName("Csv Record Writer Test") +class CSVRecordWriterTest extends BaseND4JTest { + @BeforeEach + void setUp() throws Exception { } @Test - public void testWrite() throws Exception { + @DisplayName("Test Write") + void testWrite() throws Exception { File tempFile = File.createTempFile("datavec", "writer"); tempFile.deleteOnExit(); FileSplit fileSplit = new FileSplit(tempFile); 
CSVRecordWriter writer = new CSVRecordWriter(); - writer.initialize(fileSplit,new NumberOfRecordsPartitioner()); + writer.initialize(fileSplit, new NumberOfRecordsPartitioner()); List collection = new ArrayList<>(); collection.add(new Text("12")); collection.add(new Text("13")); collection.add(new Text("14")); - writer.write(collection); - CSVRecordReader reader = new CSVRecordReader(0); reader.initialize(new FileSplit(tempFile)); int cnt = 0; while (reader.hasNext()) { List line = new ArrayList<>(reader.next()); assertEquals(3, line.size()); - assertEquals(12, line.get(0).toInt()); assertEquals(13, line.get(1).toInt()); assertEquals(14, line.get(2).toInt()); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java index 0e80e10b7..66c9ab3d2 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.writer.impl; import org.apache.commons.io.FileUtils; @@ -30,93 +29,90 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; - import java.io.File; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.Assert.assertEquals; - -public class LibSvmRecordWriterTest extends BaseND4JTest { +@DisplayName("Lib Svm Record Writer Test") +class LibSvmRecordWriterTest extends BaseND4JTest { @Test - public void testBasic() throws Exception { + @DisplayName("Test Basic") + void testBasic() throws Exception { Configuration configWriter = new Configuration(); - Configuration configReader = new Configuration(); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); - File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @Test - public void testNoLabel() throws Exception { + @DisplayName("Test No Label") + void testNoLabel() throws Exception { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9); - Configuration configReader = new Configuration(); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); - File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @Test - public void testMultioutputRecord() throws Exception { + 
@DisplayName("Test Multioutput Record") + void testMultioutputRecord() throws Exception { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9); - Configuration configReader = new Configuration(); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); - File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @Test - public void testMultilabelRecord() throws Exception { + @DisplayName("Test Multilabel Record") + void testMultilabelRecord() throws Exception { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); - Configuration configReader = new Configuration(); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10); configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true); configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4); configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); - File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @Test - public void testZeroBasedIndexing() throws Exception { + @DisplayName("Test Zero Based Indexing") + void testZeroBasedIndexing() throws Exception { Configuration configWriter = new Configuration(); configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); - Configuration configReader = new Configuration(); configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11); configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true); configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5); - File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @@ -127,10 +123,9 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { - FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + FileSplit outputSplit = new FileSplit(tempFile); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); LibSvmRecordReader rr = new LibSvmRecordReader(); rr.initialize(configReader, new FileSplit(inputFile)); while (rr.hasNext()) { @@ -138,7 +133,6 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { writer.write(record); } } - Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX)); List linesOriginal = new ArrayList<>(); for (String line : FileUtils.readLines(inputFile)) { @@ -159,7 +153,8 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { } @Test - public void testNDArrayWritables() throws Exception { + @DisplayName("Test ND Array Writables") + void testNDArrayWritables() throws Exception { INDArray arr2 = Nd4j.zeros(2); arr2.putScalar(0, 11); arr2.putScalar(1, 12); @@ -167,35 +162,28 @@ public class LibSvmRecordWriterTest extends BaseND4JTest 
{ arr3.putScalar(0, 13); arr3.putScalar(1, 14); arr3.putScalar(2, 15); - List record = Arrays.asList((Writable) new DoubleWritable(1), - new NDArrayWritable(arr2), - new IntWritable(2), - new DoubleWritable(3), - new NDArrayWritable(arr3), - new IntWritable(4)); + List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; - try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3); FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); writer.write(record); } - String lineNew = FileUtils.readFileToString(tempFile).trim(); assertEquals(lineOriginal, lineNew); } @Test - public void testNDArrayWritablesMultilabel() throws Exception { + @DisplayName("Test ND Array Writables Multilabel") + void testNDArrayWritablesMultilabel() throws Exception { INDArray arr2 = Nd4j.zeros(2); arr2.putScalar(0, 11); arr2.putScalar(1, 12); @@ -203,36 +191,29 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 0); arr3.putScalar(1, 1); arr3.putScalar(2, 0); - List record = Arrays.asList((Writable) new DoubleWritable(1), - new NDArrayWritable(arr2), - new IntWritable(2), - new DoubleWritable(3), - new NDArrayWritable(arr3), - new DoubleWritable(1)); + List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; - try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3); FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); writer.write(record); } - String lineNew = FileUtils.readFileToString(tempFile).trim(); assertEquals(lineOriginal, lineNew); } @Test - public void testNDArrayWritablesZeroIndex() throws Exception { + @DisplayName("Test ND Array Writables Zero Index") + void testNDArrayWritablesZeroIndex() throws Exception { INDArray arr2 = Nd4j.zeros(2); arr2.putScalar(0, 11); arr2.putScalar(1, 12); @@ -240,99 +221,91 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 0); arr3.putScalar(1, 1); arr3.putScalar(2, 0); - List record = Arrays.asList((Writable) new DoubleWritable(1), - new NDArrayWritable(arr2), - new IntWritable(2), - new DoubleWritable(3), - new NDArrayWritable(arr3), - new DoubleWritable(1)); + List record = 
Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0"; - try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration(); - configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD! - configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! + // NOT STANDARD! + configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); + // NOT STANDARD! + configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3); FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); writer.write(record); } - String lineNew = FileUtils.readFileToString(tempFile).trim(); assertEquals(lineOriginal, lineNew); } @Test - public void testNonIntegerButValidMultilabel() throws Exception { - List record = Arrays.asList((Writable) new IntWritable(3), - new IntWritable(2), - new DoubleWritable(1.0)); + @DisplayName("Test Non Integer But Valid Multilabel") + void testNonIntegerButValidMultilabel() throws Exception { + List record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); writer.write(record); } } - @Test(expected = NumberFormatException.class) - public void nonIntegerMultilabel() throws Exception { - List record = Arrays.asList((Writable) new IntWritable(3), - new IntWritable(2), - new DoubleWritable(1.2)); - File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); - tempFile.setWritable(true); - tempFile.deleteOnExit(); - if (tempFile.exists()) - tempFile.delete(); - - try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { - Configuration configWriter = new Configuration(); - configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); - configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); - configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); - FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); - writer.write(record); - } + @Test + @DisplayName("Non Integer Multilabel") + void nonIntegerMultilabel() { + assertThrows(NumberFormatException.class, () -> { + List record = 
Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2)); + File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); + tempFile.setWritable(true); + tempFile.deleteOnExit(); + if (tempFile.exists()) + tempFile.delete(); + try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { + Configuration configWriter = new Configuration(); + configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); + configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); + configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); + FileSplit outputSplit = new FileSplit(tempFile); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); + writer.write(record); + } + }); } - @Test(expected = NumberFormatException.class) - public void nonBinaryMultilabel() throws Exception { - List record = Arrays.asList((Writable) new IntWritable(0), - new IntWritable(1), - new IntWritable(2)); - File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); - tempFile.setWritable(true); - tempFile.deleteOnExit(); - if (tempFile.exists()) - tempFile.delete(); - - try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { - Configuration configWriter = new Configuration(); - configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN,0); - configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN,1); - configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL,true); - FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); - writer.write(record); - } + @Test + @DisplayName("Non Binary Multilabel") + void nonBinaryMultilabel() { + assertThrows(NumberFormatException.class, () -> { + List record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2)); + File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); + tempFile.setWritable(true); + tempFile.deleteOnExit(); + if (tempFile.exists()) + tempFile.delete(); + try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { + Configuration configWriter = new Configuration(); + configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); + configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); + configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); + FileSplit outputSplit = new FileSplit(tempFile); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); + writer.write(record); + } + }); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java index 8efb2a539..56a130465 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.writer.impl; import org.apache.commons.io.FileUtils; @@ -27,93 +26,90 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter; import org.datavec.api.split.FileSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.writable.*; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; 
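The exception tests rewritten above (nonIntegerMultilabel, nonBinaryMultilabel, and the earlier SVMLightRecordReader cases) all follow the same mechanical translation: `@Test(expected = X.class)` on the method becomes an `assertThrows(X.class, () -> ...)` call wrapping the body. A self-contained sketch with hypothetical names, independent of this repository:

```java
// Illustrative only: the @Test(expected = ...) -> assertThrows translation used in this patch.
import static org.junit.jupiter.api.Assertions.assertThrows;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Exception Style Sketch")
class ExceptionStyleSketch {

    // JUnit 4 form:
    //   @Test(expected = NumberFormatException.class)
    //   public void rejectsNonNumeric() { Double.parseDouble("not a number"); }

    @Test
    @DisplayName("Rejects Non Numeric")
    void rejectsNonNumeric() {
        // assertThrows scopes the expectation to the lambda, so an unrelated failure
        // elsewhere in the method can no longer satisfy the expected exception by accident.
        assertThrows(NumberFormatException.class, () -> Double.parseDouble("not a number"));
    }
}
```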
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; - import java.io.File; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.Assert.assertEquals; - -public class SVMLightRecordWriterTest extends BaseND4JTest { +@DisplayName("Svm Light Record Writer Test") +class SVMLightRecordWriterTest extends BaseND4JTest { @Test - public void testBasic() throws Exception { + @DisplayName("Test Basic") + void testBasic() throws Exception { Configuration configWriter = new Configuration(); - Configuration configReader = new Configuration(); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); - File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @Test - public void testNoLabel() throws Exception { + @DisplayName("Test No Label") + void testNoLabel() throws Exception { Configuration configWriter = new Configuration(); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9); - Configuration configReader = new Configuration(); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); - File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @Test - public void testMultioutputRecord() throws Exception { + @DisplayName("Test Multioutput Record") + void testMultioutputRecord() throws Exception { Configuration configWriter = new Configuration(); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9); - Configuration configReader = new Configuration(); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); - File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @Test - public void testMultilabelRecord() throws Exception { + @DisplayName("Test Multilabel Record") + void testMultilabelRecord() throws Exception { Configuration configWriter = new Configuration(); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); - Configuration configReader = new Configuration(); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10); configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true); configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4); configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false); - File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @Test - public void testZeroBasedIndexing() throws Exception { + @DisplayName("Test Zero Based Indexing") + void testZeroBasedIndexing() throws Exception { Configuration configWriter = new Configuration(); 
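As with the LibSvm variant earlier in this patch, the testBasic/testNoLabel/testMultilabelRecord cases above delegate to an executeTest helper whose updated body appears a little further down in this hunk: every record read from the input file is replayed through the writer, then the written file is compared with the original line by line (after stripping any `qid:` prefixes). A condensed sketch of that round trip, assuming the configuration objects and files set up by the surrounding test:

```java
// Condensed from the executeTest helper visible in this diff; a sketch, not a drop-in copy.
static void roundTrip(Configuration configWriter, Configuration configReader,
                      File inputFile, File tempFile) throws Exception {
    try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
        writer.initialize(configWriter, new FileSplit(tempFile), new NumberOfRecordsPartitioner());
        SVMLightRecordReader rr = new SVMLightRecordReader();
        rr.initialize(configReader, new FileSplit(inputFile));
        while (rr.hasNext()) {
            writer.write(rr.next());   // replay every record through the writer
        }
    }
    // The real helper then reads tempFile back and asserts each emitted line matches the original.
}
```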
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); - Configuration configReader = new Configuration(); configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11); configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true); configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5); - File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile(); executeTest(configWriter, configReader, inputFile); } @@ -124,10 +120,9 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); SVMLightRecordReader rr = new SVMLightRecordReader(); rr.initialize(configReader, new FileSplit(inputFile)); while (rr.hasNext()) { @@ -135,7 +130,6 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { writer.write(record); } } - Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX)); List linesOriginal = new ArrayList<>(); for (String line : FileUtils.readLines(inputFile)) { @@ -156,7 +150,8 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { } @Test - public void testNDArrayWritables() throws Exception { + @DisplayName("Test ND Array Writables") + void testNDArrayWritables() throws Exception { INDArray arr2 = Nd4j.zeros(2); arr2.putScalar(0, 11); arr2.putScalar(1, 12); @@ -164,35 +159,28 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 13); arr3.putScalar(1, 14); arr3.putScalar(2, 15); - List record = Arrays.asList((Writable) new DoubleWritable(1), - new NDArrayWritable(arr2), - new IntWritable(2), - new DoubleWritable(3), - new NDArrayWritable(arr3), - new IntWritable(4)); + List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4)); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; - try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3); FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); writer.write(record); } - String lineNew = FileUtils.readFileToString(tempFile).trim(); assertEquals(lineOriginal, lineNew); } @Test - public void testNDArrayWritablesMultilabel() throws Exception { + @DisplayName("Test ND Array Writables Multilabel") + void testNDArrayWritablesMultilabel() throws Exception { INDArray arr2 = Nd4j.zeros(2); arr2.putScalar(0, 11); arr2.putScalar(1, 12); @@ -200,36 +188,29 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 0); arr3.putScalar(1, 1); 
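The expected string in testNDArrayWritables above is easier to read once the column bookkeeping is spelled out. The breakdown below is my reading of the values already present in this diff: each NDArrayWritable is flattened into consecutive columns, the configured feature range becomes 1-based `index:value` pairs, and the leftover columns form the comma-separated label prefix.

```java
// Record columns (from the test):  0 -> 1.0, 1 -> NDArray[11,12], 2 -> 2, 3 -> 3.0,
//                                  4 -> NDArray[13,14,15], 5 -> 4
// FEATURE_FIRST_COLUMN = 0, FEATURE_LAST_COLUMN = 3
//   features, flattened and 1-based: 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0
//   remaining columns -> labels:     13.0,14.0,15.0,4
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
```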
arr3.putScalar(2, 0); - List record = Arrays.asList((Writable) new DoubleWritable(1), - new NDArrayWritable(arr2), - new IntWritable(2), - new DoubleWritable(3), - new NDArrayWritable(arr3), - new DoubleWritable(1)); + List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1)); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0"; - try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3); FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); writer.write(record); } - String lineNew = FileUtils.readFileToString(tempFile).trim(); assertEquals(lineOriginal, lineNew); } @Test - public void testNDArrayWritablesZeroIndex() throws Exception { + @DisplayName("Test ND Array Writables Zero Index") + void testNDArrayWritablesZeroIndex() throws Exception { INDArray arr2 = Nd4j.zeros(2); arr2.putScalar(0, 11); arr2.putScalar(1, 12); @@ -237,99 +218,91 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 0); arr3.putScalar(1, 1); arr3.putScalar(2, 0); - List record = Arrays.asList((Writable) new DoubleWritable(1), - new NDArrayWritable(arr2), - new IntWritable(2), - new DoubleWritable(3), - new NDArrayWritable(arr3), - new DoubleWritable(1)); + List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1)); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0"; - try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { Configuration configWriter = new Configuration(); - configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD! - configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD! + // NOT STANDARD! + configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); + // NOT STANDARD! 
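The two "NOT STANDARD!" flags here deserve a gloss, since SVMLight/LibSVM files are conventionally 1-based. Comparing the expected strings already in this diff, ZERO_BASED_INDEXING shifts the emitted feature indices down by one, and ZERO_BASED_LABEL_INDEXING does the same for the multilabel label indices; the summary below is only a reading of those strings, not a statement about other writer options.

```java
// Effect on the emitted line, using the expected strings from the two tests in this file:
// default 1-based convention (testNDArrayWritablesMultilabel):
String oneBased  = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
// ZERO_BASED_INDEXING and ZERO_BASED_LABEL_INDEXING both true (testNDArrayWritablesZeroIndex):
String zeroBased = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
// i.e. feature indices and multilabel label indices each shift down by one.
```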
+ configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3); FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); writer.write(record); } - String lineNew = FileUtils.readFileToString(tempFile).trim(); assertEquals(lineOriginal, lineNew); } @Test - public void testNonIntegerButValidMultilabel() throws Exception { - List record = Arrays.asList((Writable) new IntWritable(3), - new IntWritable(2), - new DoubleWritable(1.0)); + @DisplayName("Test Non Integer But Valid Multilabel") + void testNonIntegerButValidMultilabel() throws Exception { + List record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0)); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); - try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); writer.write(record); } } - @Test(expected = NumberFormatException.class) - public void nonIntegerMultilabel() throws Exception { - List record = Arrays.asList((Writable) new IntWritable(3), - new IntWritable(2), - new DoubleWritable(1.2)); - File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); - tempFile.setWritable(true); - tempFile.deleteOnExit(); - if (tempFile.exists()) - tempFile.delete(); - - try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { - Configuration configWriter = new Configuration(); - configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); - configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); - configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); - FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); - writer.write(record); - } + @Test + @DisplayName("Non Integer Multilabel") + void nonIntegerMultilabel() { + assertThrows(NumberFormatException.class, () -> { + List record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2)); + File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); + tempFile.setWritable(true); + tempFile.deleteOnExit(); + if (tempFile.exists()) + tempFile.delete(); + try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { + Configuration configWriter = new Configuration(); + configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); + configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); + configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); + FileSplit outputSplit = new FileSplit(tempFile); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); + writer.write(record); + } + }); } - @Test(expected = 
NumberFormatException.class) - public void nonBinaryMultilabel() throws Exception { - List record = Arrays.asList((Writable) new IntWritable(0), - new IntWritable(1), - new IntWritable(2)); - File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); - tempFile.setWritable(true); - tempFile.deleteOnExit(); - if (tempFile.exists()) - tempFile.delete(); - - try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { - Configuration configWriter = new Configuration(); - configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); - configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); - configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); - FileSplit outputSplit = new FileSplit(tempFile); - writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); - writer.write(record); - } + @Test + @DisplayName("Non Binary Multilabel") + void nonBinaryMultilabel() { + assertThrows(NumberFormatException.class, () -> { + List record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2)); + File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); + tempFile.setWritable(true); + tempFile.deleteOnExit(); + if (tempFile.exists()) + tempFile.delete(); + try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) { + Configuration configWriter = new Configuration(); + configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0); + configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1); + configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true); + FileSplit outputSplit = new FileSplit(tempFile); + writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner()); + writer.write(record); + } + }); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java index 79c799fd5..253eb98f4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java @@ -17,44 +17,43 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.split; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; - import java.net.URI; import java.net.URISyntaxException; import java.util.Collection; - import static java.util.Arrays.asList; -import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Ede Meijer */ -public class TransformSplitTest extends BaseND4JTest { - @Test - public void testTransform() throws URISyntaxException { - Collection inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv")); +@DisplayName("Transform Split Test") +class TransformSplitTest extends BaseND4JTest { + @Test + @DisplayName("Test Transform") + void testTransform() throws URISyntaxException { + Collection inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv")); InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() { + @Override public URI apply(URI uri) throws URISyntaxException { return uri.normalize(); } }); - - assertArrayEquals(new URI[] {new URI("file:///foo/0.csv"), new 
URI("file:///foo/1.csv")}, SUT.locations()); + assertArrayEquals(new URI[] { new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv") }, SUT.locations()); } @Test - public void testSearchReplace() throws URISyntaxException { + @DisplayName("Test Search Replace") + void testSearchReplace() throws URISyntaxException { Collection inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv")); - InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv"); - - assertArrayEquals(new URI[] {new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv")}, - SUT.locations()); + assertArrayEquals(new URI[] { new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv") }, SUT.locations()); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java index e67722f78..42351fd9a 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java @@ -17,32 +17,25 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.transform.ops; import com.tngtech.archunit.core.importer.ImportOption; import com.tngtech.archunit.junit.AnalyzeClasses; import com.tngtech.archunit.junit.ArchTest; -import com.tngtech.archunit.junit.ArchUnitRunner; import com.tngtech.archunit.lang.ArchRule; +import com.tngtech.archunit.lang.extension.ArchUnitExtension; +import com.tngtech.archunit.lang.extension.ArchUnitExtensions; import org.junit.runner.RunWith; import org.nd4j.common.tests.BaseND4JTest; - import java.io.Serializable; - import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -@RunWith(ArchUnitRunner.class) -@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = {ImportOption.DoNotIncludeTests.class}) -public class AggregableMultiOpArchTest extends BaseND4JTest { +@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = { ImportOption.DoNotIncludeTests.class }) +@DisplayName("Aggregable Multi Op Arch Test") +class AggregableMultiOpArchTest extends BaseND4JTest { @ArchTest - public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes() - .that().resideInAPackage("org.datavec.api.transform.ops") - .and().doNotHaveSimpleName("AggregatorImpls") - .and().doNotHaveSimpleName("IAggregableReduceOp") - .and().doNotHaveSimpleName("StringAggregatorImpls") - .and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1") - .should().implement(Serializable.class) - .because("All aggregate ops must be serializable."); -} \ No newline at end of file + public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes().that().resideInAPackage("org.datavec.api.transform.ops").and().doNotHaveSimpleName("AggregatorImpls").and().doNotHaveSimpleName("IAggregableReduceOp").and().doNotHaveSimpleName("StringAggregatorImpls").and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1").should().implement(Serializable.class).because("All aggregate ops must be serializable."); +} diff --git 
a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java index cb4fdeb04..acd2971ac 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java @@ -17,52 +17,46 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.transform.ops; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; - import java.util.*; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertTrue; - -public class AggregableMultiOpTest extends BaseND4JTest { +@DisplayName("Aggregable Multi Op Test") +class AggregableMultiOpTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @Test - public void testMulti() throws Exception { + @DisplayName("Test Multi") + void testMulti() throws Exception { AggregatorImpls.AggregableFirst af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableSum as = new AggregatorImpls.AggregableSum<>(); AggregableMultiOp multi = new AggregableMultiOp<>(Arrays.asList(af, as)); - assertTrue(multi.getOperations().size() == 2); for (int i = 0; i < intList.size(); i++) { multi.accept(intList.get(i)); } - // mutablility assertTrue(as.get().toDouble() == 45D); assertTrue(af.get().toInt() == 1); - List res = multi.get(); assertTrue(res.get(1).toDouble() == 45D); assertTrue(res.get(0).toInt() == 1); - AggregatorImpls.AggregableFirst rf = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableSum rs = new AggregatorImpls.AggregableSum<>(); AggregableMultiOp reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs)); - for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); } - List revRes = reverse.get(); assertTrue(revRes.get(1).toDouble() == 45D); assertTrue(revRes.get(0).toInt() == 9); - multi.combine(reverse); List combinedRes = multi.get(); assertTrue(combinedRes.get(1).toDouble() == 90D); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java index 47da27bdc..e7c8de557 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java @@ -17,41 +17,39 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.transform.ops; import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.rules.ExpectedException; import org.nd4j.common.tests.BaseND4JTest; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class 
AggregatorImplsTest extends BaseND4JTest { +@DisplayName("Aggregator Impls Test") +class AggregatorImplsTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); @Test - public void aggregableFirstTest() { + @DisplayName("Aggregable First Test") + void aggregableFirstTest() { AggregatorImpls.AggregableFirst first = new AggregatorImpls.AggregableFirst<>(); for (int i = 0; i < intList.size(); i++) { first.accept(intList.get(i)); } assertEquals(1, first.get().toInt()); - AggregatorImpls.AggregableFirst firstS = new AggregatorImpls.AggregableFirst<>(); for (int i = 0; i < stringList.size(); i++) { firstS.accept(stringList.get(i)); } assertTrue(firstS.get().toString().equals("arakoa")); - - AggregatorImpls.AggregableFirst reverse = new AggregatorImpls.AggregableFirst<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); @@ -60,22 +58,19 @@ public class AggregatorImplsTest extends BaseND4JTest { assertEquals(1, first.get().toInt()); } - @Test - public void aggregableLastTest() { + @DisplayName("Aggregable Last Test") + void aggregableLastTest() { AggregatorImpls.AggregableLast last = new AggregatorImpls.AggregableLast<>(); for (int i = 0; i < intList.size(); i++) { last.accept(intList.get(i)); } assertEquals(9, last.get().toInt()); - AggregatorImpls.AggregableLast lastS = new AggregatorImpls.AggregableLast<>(); for (int i = 0; i < stringList.size(); i++) { lastS.accept(stringList.get(i)); } assertTrue(lastS.get().toString().equals("acceptance")); - - AggregatorImpls.AggregableLast reverse = new AggregatorImpls.AggregableLast<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); @@ -85,20 +80,18 @@ public class AggregatorImplsTest extends BaseND4JTest { } @Test - public void aggregableCountTest() { + @DisplayName("Aggregable Count Test") + void aggregableCountTest() { AggregatorImpls.AggregableCount cnt = new AggregatorImpls.AggregableCount<>(); for (int i = 0; i < intList.size(); i++) { cnt.accept(intList.get(i)); } assertEquals(9, cnt.get().toInt()); - AggregatorImpls.AggregableCount lastS = new AggregatorImpls.AggregableCount<>(); for (int i = 0; i < stringList.size(); i++) { lastS.accept(stringList.get(i)); } assertEquals(4, lastS.get().toInt()); - - AggregatorImpls.AggregableCount reverse = new AggregatorImpls.AggregableCount<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); @@ -108,14 +101,13 @@ public class AggregatorImplsTest extends BaseND4JTest { } @Test - public void aggregableMaxTest() { + @DisplayName("Aggregable Max Test") + void aggregableMaxTest() { AggregatorImpls.AggregableMax mx = new AggregatorImpls.AggregableMax<>(); for (int i = 0; i < intList.size(); i++) { mx.accept(intList.get(i)); } assertEquals(9, mx.get().toInt()); - - AggregatorImpls.AggregableMax reverse = new AggregatorImpls.AggregableMax<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); @@ -124,16 +116,14 @@ public class AggregatorImplsTest extends BaseND4JTest { assertEquals(9, mx.get().toInt()); } - @Test - public void aggregableRangeTest() { + @DisplayName("Aggregable Range Test") + void aggregableRangeTest() { AggregatorImpls.AggregableRange mx = new AggregatorImpls.AggregableRange<>(); for (int i = 0; i < intList.size(); i++) { mx.accept(intList.get(i)); } 
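One Jupiter detail that is easy to misread in the AggregatorImplsTest changes that follow: JUnit 4's `Assert.assertTrue(message, condition)` puts the message first, while Jupiter's `Assertions.assertTrue(condition, message)` puts it last, which is why the rewritten standard-deviation assertions end in `, "" + sd.get().toDouble())`. A tiny isolated sketch with hypothetical names:

```java
// Argument order of assertion messages, JUnit 4 vs JUnit 5 (illustrative values only).
import static org.junit.jupiter.api.Assertions.assertTrue;

class MessageOrderSketch {
    static void checkCloseTo(double actual, double expected) {
        // JUnit 4: assertTrue("" + actual, Math.abs(actual - expected) < 1e-4);
        // JUnit 5: the message (or a Supplier<String>) comes last.
        assertTrue(Math.abs(actual - expected) < 1e-4, "unexpected value: " + actual);
    }
}
```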
assertEquals(8, mx.get().toInt()); - - AggregatorImpls.AggregableRange reverse = new AggregatorImpls.AggregableRange<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1) + 9); @@ -143,14 +133,13 @@ public class AggregatorImplsTest extends BaseND4JTest { } @Test - public void aggregableMinTest() { + @DisplayName("Aggregable Min Test") + void aggregableMinTest() { AggregatorImpls.AggregableMin mn = new AggregatorImpls.AggregableMin<>(); for (int i = 0; i < intList.size(); i++) { mn.accept(intList.get(i)); } assertEquals(1, mn.get().toInt()); - - AggregatorImpls.AggregableMin reverse = new AggregatorImpls.AggregableMin<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); @@ -160,14 +149,13 @@ public class AggregatorImplsTest extends BaseND4JTest { } @Test - public void aggregableSumTest() { + @DisplayName("Aggregable Sum Test") + void aggregableSumTest() { AggregatorImpls.AggregableSum sm = new AggregatorImpls.AggregableSum<>(); for (int i = 0; i < intList.size(); i++) { sm.accept(intList.get(i)); } assertEquals(45, sm.get().toInt()); - - AggregatorImpls.AggregableSum reverse = new AggregatorImpls.AggregableSum<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); @@ -176,17 +164,15 @@ public class AggregatorImplsTest extends BaseND4JTest { assertEquals(90, sm.get().toInt()); } - @Test - public void aggregableMeanTest() { + @DisplayName("Aggregable Mean Test") + void aggregableMeanTest() { AggregatorImpls.AggregableMean mn = new AggregatorImpls.AggregableMean<>(); for (int i = 0; i < intList.size(); i++) { mn.accept(intList.get(i)); } assertEquals(9l, (long) mn.getCount()); assertEquals(5D, mn.get().toDouble(), 0.001); - - AggregatorImpls.AggregableMean reverse = new AggregatorImpls.AggregableMean<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); @@ -197,80 +183,73 @@ public class AggregatorImplsTest extends BaseND4JTest { } @Test - public void aggregableStdDevTest() { + @DisplayName("Aggregable Std Dev Test") + void aggregableStdDevTest() { AggregatorImpls.AggregableStdDev sd = new AggregatorImpls.AggregableStdDev<>(); for (int i = 0; i < intList.size(); i++) { sd.accept(intList.get(i)); } assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001); - - AggregatorImpls.AggregableStdDev reverse = new AggregatorImpls.AggregableStdDev<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); } sd.combine(reverse); - assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8787) < 0.0001); + assertTrue(Math.abs(sd.get().toDouble() - 1.8787) < 0.0001,"" + sd.get().toDouble()); } @Test - public void aggregableVariance() { + @DisplayName("Aggregable Variance") + void aggregableVariance() { AggregatorImpls.AggregableVariance sd = new AggregatorImpls.AggregableVariance<>(); for (int i = 0; i < intList.size(); i++) { sd.accept(intList.get(i)); } assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001); - - AggregatorImpls.AggregableVariance reverse = new AggregatorImpls.AggregableVariance<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); } sd.combine(reverse); - assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 3.5294) < 0.0001); + assertTrue(Math.abs(sd.get().toDouble() - 3.5294) < 0.0001,"" + sd.get().toDouble()); } @Test - public void aggregableUncorrectedStdDevTest() { + 
@DisplayName("Aggregable Uncorrected Std Dev Test") + void aggregableUncorrectedStdDevTest() { AggregatorImpls.AggregableUncorrectedStdDev sd = new AggregatorImpls.AggregableUncorrectedStdDev<>(); for (int i = 0; i < intList.size(); i++) { sd.accept(intList.get(i)); } assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001); - - - AggregatorImpls.AggregableUncorrectedStdDev reverse = - new AggregatorImpls.AggregableUncorrectedStdDev<>(); + AggregatorImpls.AggregableUncorrectedStdDev reverse = new AggregatorImpls.AggregableUncorrectedStdDev<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); } sd.combine(reverse); - assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8257) < 0.0001); + assertTrue(Math.abs(sd.get().toDouble() - 1.8257) < 0.0001,"" + sd.get().toDouble()); } - @Test - public void aggregablePopulationVariance() { + @DisplayName("Aggregable Population Variance") + void aggregablePopulationVariance() { AggregatorImpls.AggregablePopulationVariance sd = new AggregatorImpls.AggregablePopulationVariance<>(); for (int i = 0; i < intList.size(); i++) { sd.accept(intList.get(i)); } assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001); - - - AggregatorImpls.AggregablePopulationVariance reverse = - new AggregatorImpls.AggregablePopulationVariance<>(); + AggregatorImpls.AggregablePopulationVariance reverse = new AggregatorImpls.AggregablePopulationVariance<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); } sd.combine(reverse); - assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001); + assertTrue(Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001,"" + sd.get().toDouble()); } @Test - public void aggregableCountUniqueTest() { + @DisplayName("Aggregable Count Unique Test") + void aggregableCountUniqueTest() { // at this low range, it's linear counting - AggregatorImpls.AggregableCountUnique cu = new AggregatorImpls.AggregableCountUnique<>(); for (int i = 0; i < intList.size(); i++) { cu.accept(intList.get(i)); @@ -278,7 +257,6 @@ public class AggregatorImplsTest extends BaseND4JTest { assertEquals(9, cu.get().toInt()); cu.accept(1); assertEquals(9, cu.get().toInt()); - AggregatorImpls.AggregableCountUnique reverse = new AggregatorImpls.AggregableCountUnique<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); @@ -290,16 +268,14 @@ public class AggregatorImplsTest extends BaseND4JTest { @Rule public final ExpectedException exception = ExpectedException.none(); - @Test - public void incompatibleAggregatorTest() { + @DisplayName("Incompatible Aggregator Test") + void incompatibleAggregatorTest() { AggregatorImpls.AggregableSum sm = new AggregatorImpls.AggregableSum<>(); for (int i = 0; i < intList.size(); i++) { sm.accept(intList.get(i)); } assertEquals(45, sm.get().toInt()); - - AggregatorImpls.AggregableMean reverse = new AggregatorImpls.AggregableMean<>(); for (int i = 0; i < intList.size(); i++) { reverse.accept(intList.get(intList.size() - i - 1)); @@ -308,5 +284,4 @@ public class AggregatorImplsTest extends BaseND4JTest { sm.combine(reverse); assertEquals(45, sm.get().toInt()); } - } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java index 098c5635a..6a444923d 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java 
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java @@ -17,77 +17,65 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.transform.ops; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertTrue; - -public class DispatchOpTest extends BaseND4JTest { +@DisplayName("Dispatch Op Test") +class DispatchOpTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); @Test - public void testDispatchSimple() { + @DisplayName("Test Dispatch Simple") + void testDispatchSimple() { AggregatorImpls.AggregableFirst af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableSum as = new AggregatorImpls.AggregableSum<>(); - AggregableMultiOp multiaf = - new AggregableMultiOp<>(Collections.>singletonList(af)); - AggregableMultiOp multias = - new AggregableMultiOp<>(Collections.>singletonList(as)); - - DispatchOp parallel = - new DispatchOp<>(Arrays.>>asList(multiaf, multias)); - + AggregableMultiOp multiaf = new AggregableMultiOp<>(Collections.>singletonList(af)); + AggregableMultiOp multias = new AggregableMultiOp<>(Collections.>singletonList(as)); + DispatchOp parallel = new DispatchOp<>(Arrays.>>asList(multiaf, multias)); assertTrue(multiaf.getOperations().size() == 1); assertTrue(multias.getOperations().size() == 1); assertTrue(parallel.getOperations().size() == 2); for (int i = 0; i < intList.size(); i++) { parallel.accept(Arrays.asList(intList.get(i), intList.get(i))); } - List res = parallel.get(); assertTrue(res.get(1).toDouble() == 45D); assertTrue(res.get(0).toInt() == 1); - } @Test - public void testDispatchFlatMap() { + @DisplayName("Test Dispatch Flat Map") + void testDispatchFlatMap() { AggregatorImpls.AggregableFirst af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableSum as = new AggregatorImpls.AggregableSum<>(); AggregableMultiOp multi = new AggregableMultiOp<>(Arrays.asList(af, as)); - AggregatorImpls.AggregableLast al = new AggregatorImpls.AggregableLast<>(); AggregatorImpls.AggregableMax amax = new AggregatorImpls.AggregableMax<>(); AggregableMultiOp otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax)); - - - DispatchOp parallel = new DispatchOp<>( - Arrays.>>asList(multi, otherMulti)); - + DispatchOp parallel = new DispatchOp<>(Arrays.>>asList(multi, otherMulti)); assertTrue(multi.getOperations().size() == 2); assertTrue(otherMulti.getOperations().size() == 2); assertTrue(parallel.getOperations().size() == 2); for (int i = 0; i < intList.size(); i++) { parallel.accept(Arrays.asList(intList.get(i), intList.get(i))); } - List res = parallel.get(); assertTrue(res.get(1).toDouble() == 45D); assertTrue(res.get(0).toInt() == 1); assertTrue(res.get(3).toDouble() == 9); assertTrue(res.get(2).toInt() == 9); - } - } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java 
b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java index 14fbf7ca8..a42b273e2 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java @@ -17,29 +17,29 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.transform.transform.parse; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; +@DisplayName("Parse Double Transform Test") +class ParseDoubleTransformTest extends BaseND4JTest { -public class ParseDoubleTransformTest extends BaseND4JTest { @Test - public void testDoubleTransform() { + @DisplayName("Test Double Transform") + void testDoubleTransform() { List record = new ArrayList<>(); record.add(new Text("0.0")); List transformed = Arrays.asList(new DoubleWritable(0.0)); assertEquals(transformed, new ParseDoubleTransform().map(record)); } - - } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java index 48f214cb2..b0a283563 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java @@ -17,30 +17,31 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.util; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; - import java.io.BufferedReader; import java.io.File; import java.io.InputStream; import java.io.InputStreamReader; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.AnyOf.anyOf; import static org.hamcrest.core.IsEqual.equalTo; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -public class ClassPathResourceTest extends BaseND4JTest { +@DisplayName("Class Path Resource Test") +class ClassPathResourceTest extends BaseND4JTest { - private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows + // File sizes are reported slightly different on Linux vs. 
Windows + private boolean isWindows = false; - @Before - public void setUp() throws Exception { + @BeforeEach + void setUp() throws Exception { String osname = System.getProperty("os.name"); if (osname != null && osname.toLowerCase().contains("win")) { isWindows = true; @@ -48,9 +49,9 @@ public class ClassPathResourceTest extends BaseND4JTest { } @Test - public void testGetFile1() throws Exception { + @DisplayName("Test Get File 1") + void testGetFile1() throws Exception { File intFile = new ClassPathResource("datavec-api/iris.dat").getFile(); - assertTrue(intFile.exists()); if (isWindows) { assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L))); @@ -60,9 +61,9 @@ public class ClassPathResourceTest extends BaseND4JTest { } @Test - public void testGetFileSlash1() throws Exception { + @DisplayName("Test Get File Slash 1") + void testGetFileSlash1() throws Exception { File intFile = new ClassPathResource("datavec-api/iris.dat").getFile(); - assertTrue(intFile.exists()); if (isWindows) { assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L))); @@ -72,11 +73,10 @@ public class ClassPathResourceTest extends BaseND4JTest { } @Test - public void testGetFileWithSpace1() throws Exception { + @DisplayName("Test Get File With Space 1") + void testGetFileWithSpace1() throws Exception { File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile(); - assertTrue(intFile.exists()); - if (isWindows) { assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L))); } else { @@ -85,16 +85,15 @@ public class ClassPathResourceTest extends BaseND4JTest { } @Test - public void testInputStream() throws Exception { + @DisplayName("Test Input Stream") + void testInputStream() throws Exception { ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt"); File intFile = resource.getFile(); - if (isWindows) { assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L))); } else { assertEquals(60, intFile.length()); } - InputStream stream = resource.getInputStream(); BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); String line = ""; @@ -102,21 +101,19 @@ public class ClassPathResourceTest extends BaseND4JTest { while ((line = reader.readLine()) != null) { cnt++; } - assertEquals(5, cnt); } @Test - public void testInputStreamSlash() throws Exception { + @DisplayName("Test Input Stream Slash") + void testInputStreamSlash() throws Exception { ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt"); File intFile = resource.getFile(); - if (isWindows) { assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L))); } else { assertEquals(60, intFile.length()); } - InputStream stream = resource.getInputStream(); BufferedReader reader = new BufferedReader(new InputStreamReader(stream)); String line = ""; @@ -124,7 +121,6 @@ public class ClassPathResourceTest extends BaseND4JTest { while ((line = reader.readLine()) != null) { cnt++; } - assertEquals(5, cnt); } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java index 48a815a63..53dbbb5f7 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java @@ -17,44 +17,41 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package 
org.datavec.api.util; import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; - import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertArrayEquals; - -public class TimeSeriesUtilsTest extends BaseND4JTest { +@DisplayName("Time Series Utils Test") +class TimeSeriesUtilsTest extends BaseND4JTest { @Test - public void testTimeSeriesCreation() { + @DisplayName("Test Time Series Creation") + void testTimeSeriesCreation() { List>> test = new ArrayList<>(); List> timeStep = new ArrayList<>(); - for(int i = 0; i < 5; i++) { + for (int i = 0; i < 5; i++) { timeStep.add(getRecord(5)); } - test.add(timeStep); - INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst(); - assertArrayEquals(new long[]{1,5,5},arr.shape()); - } + assertArrayEquals(new long[] { 1, 5, 5 }, arr.shape()); + } - private List getRecord(int length) { + private List getRecord(int length) { List ret = new ArrayList<>(); - for(int i = 0; i < length; i++) { + for (int i = 0; i < length; i++) { ret.add(new DoubleWritable(1.0)); } - return ret; - } - + } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java index bcabc2910..f84229ceb 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java @@ -17,52 +17,50 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.writable; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.shade.guava.collect.Lists; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; - import java.util.Arrays; import java.util.List; import java.util.TimeZone; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; +@DisplayName("Record Converter Test") +class RecordConverterTest extends BaseND4JTest { -public class RecordConverterTest extends BaseND4JTest { @Test - public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() { - INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); - INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT); - INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT); - INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT); - DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), - Nd4j.vstack(Lists.newArrayList(label1, label2))); - + @DisplayName("To Records _ 
Pass In Classification Data Set _ Expect ND Array And Int Writables") + void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() { + INDArray feature1 = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT); + INDArray feature2 = Nd4j.create(new double[] { 11, .7, -1.3, 4 }, new long[] { 1, 4 }, DataType.FLOAT); + INDArray label1 = Nd4j.create(new double[] { 0, 0, 1, 0 }, new long[] { 1, 4 }, DataType.FLOAT); + INDArray label2 = Nd4j.create(new double[] { 0, 1, 0, 0 }, new long[] { 1, 4 }, DataType.FLOAT); + DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), Nd4j.vstack(Lists.newArrayList(label1, label2))); List> writableList = RecordConverter.toRecords(dataSet); - assertEquals(2, writableList.size()); testClassificationWritables(feature1, 2, writableList.get(0)); testClassificationWritables(feature2, 1, writableList.get(1)); } @Test - public void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() { - INDArray feature = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); - INDArray label = Nd4j.create(new double[]{.5, 2, 3, .5}, new long[]{1, 4}, DataType.FLOAT); + @DisplayName("To Records _ Pass In Regression Data Set _ Expect ND Array And Double Writables") + void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() { + INDArray feature = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT); + INDArray label = Nd4j.create(new double[] { .5, 2, 3, .5 }, new long[] { 1, 4 }, DataType.FLOAT); DataSet dataSet = new DataSet(feature, label); - List> writableList = RecordConverter.toRecords(dataSet); List results = writableList.get(0); NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0); - assertEquals(1, writableList.size()); assertEquals(5, results.size()); assertEquals(feature, ndArrayWritable.get()); @@ -72,62 +70,39 @@ public class RecordConverterTest extends BaseND4JTest { } } - private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, - List writables) { + private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, List writables) { NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0); IntWritable intWritable = (IntWritable) writables.get(1); - assertEquals(2, writables.size()); assertEquals(expectedFeatureVector, ndArrayWritable.get()); assertEquals(expectLabelIndex, intWritable.get()); } - @Test - public void testNDArrayWritableConcat() { - List l = Arrays.asList(new DoubleWritable(1), - new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1, 3}, DataType.FLOAT)), new DoubleWritable(5), - new NDArrayWritable(Nd4j.create(new double[]{6, 7, 8}, new long[]{1, 3}, DataType.FLOAT)), new IntWritable(9), - new IntWritable(1)); - - INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT); + @DisplayName("Test ND Array Writable Concat") + void testNDArrayWritableConcat() { + List l = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 }, new long[] { 1, 3 }, DataType.FLOAT)), new IntWritable(9), new IntWritable(1)); + INDArray exp = Nd4j.create(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1 }, new long[] { 1, 10 }, DataType.FLOAT); INDArray act = RecordConverter.toArray(DataType.FLOAT, l); - assertEquals(exp, 
act); } @Test - public void testNDArrayWritableConcatToMatrix(){ - - List l1 = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5)); - List l2 = Arrays.asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[]{7, 8, 9}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(10)); - - INDArray exp = Nd4j.create(new double[][]{ - {1,2,3,4,5}, - {6,7,8,9,10}}).castTo(DataType.FLOAT); - - INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2)); - + @DisplayName("Test ND Array Writable Concat To Matrix") + void testNDArrayWritableConcatToMatrix() { + List l1 = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5)); + List l2 = Arrays.asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(10)); + INDArray exp = Nd4j.create(new double[][] { { 1, 2, 3, 4, 5 }, { 6, 7, 8, 9, 10 } }).castTo(DataType.FLOAT); + INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1, l2)); assertEquals(exp, act); } @Test - public void testToRecordWithListOfObject(){ - final List list = Arrays.asList((Object)3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L); - final Schema schema = new Schema.Builder() - .addColumnInteger("a") - .addColumnFloat("b") - .addColumnString("c") - .addColumnCategorical("d", "Bar", "Baz") - .addColumnDouble("e") - .addColumnFloat("f") - .addColumnLong("g") - .addColumnInteger("h") - .addColumnTime("i", TimeZone.getDefault()) - .build(); - + @DisplayName("Test To Record With List Of Object") + void testToRecordWithListOfObject() { + final List list = Arrays.asList((Object) 3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L); + final Schema schema = new Schema.Builder().addColumnInteger("a").addColumnFloat("b").addColumnString("c").addColumnCategorical("d", "Bar", "Baz").addColumnDouble("e").addColumnFloat("f").addColumnLong("g").addColumnInteger("h").addColumnTime("i", TimeZone.getDefault()).build(); final List record = RecordConverter.toRecord(schema, list); - assertEquals(record.get(0).toInt(), 3); assertEquals(record.get(1).toFloat(), 7f, 1e-6); assertEquals(record.get(2).toString(), "Foo"); @@ -137,7 +112,5 @@ public class RecordConverterTest extends BaseND4JTest { assertEquals(record.get(6).toLong(), 3L); assertEquals(record.get(7).toInt(), 7); assertEquals(record.get(8).toLong(), 0); - - } } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java index d9861cc92..f3daccd04 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java @@ -17,38 +17,38 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.writable; import org.datavec.api.writable.batch.NDArrayRecordBatch; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; - import java.nio.Buffer; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; 
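A note on the assertion rewrites in these hunks (illustrative, not part of the patch): JUnit 4's org.junit.Assert methods take the optional failure message as the first argument, while the JUnit 5 org.junit.jupiter.api.Assertions equivalents take it as the last, which is why the message operands are swapped throughout this migration. Likewise, a JUnit 4 @Rule such as the ExpectedException rule kept in AggregatorImplsTest above is not evaluated by the Jupiter engine (unless junit-jupiter-migrationsupport is added); assertThrows is the Jupiter-native replacement. A minimal sketch, with an illustrative class name and exception type:

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

class AssertionMigrationSketch {

    void messageMovesLast(double actual) {
        // JUnit 4: assertTrue("" + actual, Math.abs(actual - 1.8257) < 1e-4);
        // JUnit 5: the condition comes first, the optional message comes last.
        assertTrue(Math.abs(actual - 1.8257) < 1e-4, "" + actual);
    }

    void expectedExceptionBecomesAssertThrows() {
        // The Jupiter engine does not run JUnit 4 rules such as ExpectedException;
        // assertThrows is the JUnit 5 way to assert that a call fails.
        assertThrows(IllegalStateException.class, () -> {
            throw new IllegalStateException("expected");
        });
    }
}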
import java.util.Iterator; import java.util.List; +import org.junit.jupiter.api.DisplayName; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -public class WritableTest extends BaseND4JTest { +@DisplayName("Writable Test") +class WritableTest extends BaseND4JTest { @Test - public void testWritableEqualityReflexive() { + @DisplayName("Test Writable Equality Reflexive") + void testWritableEqualityReflexive() { assertEquals(new IntWritable(1), new IntWritable(1)); assertEquals(new LongWritable(1), new LongWritable(1)); assertEquals(new DoubleWritable(1), new DoubleWritable(1)); assertEquals(new FloatWritable(1), new FloatWritable(1)); assertEquals(new Text("Hello"), new Text("Hello")); - assertEquals(new BytesWritable("Hello".getBytes()),new BytesWritable("Hello".getBytes())); - INDArray ndArray = Nd4j.rand(new int[]{1, 100}); - + assertEquals(new BytesWritable("Hello".getBytes()), new BytesWritable("Hello".getBytes())); + INDArray ndArray = Nd4j.rand(new int[] { 1, 100 }); assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray)); assertEquals(new NullWritable(), new NullWritable()); assertEquals(new BooleanWritable(true), new BooleanWritable(true)); @@ -56,9 +56,9 @@ public class WritableTest extends BaseND4JTest { assertEquals(new ByteWritable(b), new ByteWritable(b)); } - @Test - public void testBytesWritableIndexing() { + @DisplayName("Test Bytes Writable Indexing") + void testBytesWritableIndexing() { byte[] doubleWrite = new byte[16]; ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite); Buffer buffer = (Buffer) wrapped; @@ -66,53 +66,51 @@ public class WritableTest extends BaseND4JTest { wrapped.putDouble(2.0); buffer.rewind(); BytesWritable byteWritable = new BytesWritable(doubleWrite); - assertEquals(2,byteWritable.getDouble(1),1e-1); - DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2}); + assertEquals(2, byteWritable.getDouble(1), 1e-1); + DataBuffer dataBuffer = Nd4j.createBuffer(new double[] { 1, 2 }); double[] d1 = dataBuffer.asDouble(); - double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE,8).asDouble(); + double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE, 8).asDouble(); assertArrayEquals(d1, d2, 0.0); } @Test - public void testByteWritable() { + @DisplayName("Test Byte Writable") + void testByteWritable() { byte b = 0xfffffffe; assertEquals(new IntWritable(-2), new ByteWritable(b)); assertEquals(new LongWritable(-2), new ByteWritable(b)); assertEquals(new ByteWritable(b), new IntWritable(-2)); assertEquals(new ByteWritable(b), new LongWritable(-2)); - // those would cast to the same Int byte minus126 = 0xffffff82; assertNotEquals(new ByteWritable(minus126), new IntWritable(130)); } @Test - public void testIntLongWritable() { + @DisplayName("Test Int Long Writable") + void testIntLongWritable() { assertEquals(new IntWritable(1), new LongWritable(1l)); assertEquals(new LongWritable(2l), new IntWritable(2)); - long l = 1L << 34; // those would cast to the same Int assertNotEquals(new LongWritable(l), new IntWritable(4)); } - @Test - public void testDoubleFloatWritable() { + @DisplayName("Test Double Float Writable") + void testDoubleFloatWritable() { assertEquals(new DoubleWritable(1d), new FloatWritable(1f)); assertEquals(new FloatWritable(2f), new DoubleWritable(2d)); - // we defer to Java equality for Floats assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f)); // same idea as above - assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float)1.1d)); - - assertNotEquals(new 
DoubleWritable((double)Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY)); + assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float) 1.1d)); + assertNotEquals(new DoubleWritable((double) Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY)); } - @Test - public void testFuzzies() { + @DisplayName("Test Fuzzies") + void testFuzzies() { assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d)); assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d)); byte b = 0xfffffffe; @@ -122,62 +120,57 @@ public class WritableTest extends BaseND4JTest { assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d)); } - @Test - public void testNDArrayRecordBatch(){ + @DisplayName("Test ND Array Record Batch") + void testNDArrayRecordBatch() { Nd4j.getRandom().setSeed(12345); - - List> orig = new ArrayList<>(); //Outer list over writables/columns, inner list over examples - for( int i=0; i<3; i++ ){ + // Outer list over writables/columns, inner list over examples + List> orig = new ArrayList<>(); + for (int i = 0; i < 3; i++) { orig.add(new ArrayList()); } - - for( int i=0; i<5; i++ ){ - orig.get(0).add(Nd4j.rand(1,10)); - orig.get(1).add(Nd4j.rand(new int[]{1,5,6})); - orig.get(2).add(Nd4j.rand(new int[]{1,3,4,5})); + for (int i = 0; i < 5; i++) { + orig.get(0).add(Nd4j.rand(1, 10)); + orig.get(1).add(Nd4j.rand(new int[] { 1, 5, 6 })); + orig.get(2).add(Nd4j.rand(new int[] { 1, 3, 4, 5 })); } - - List> origByExample = new ArrayList<>(); //Outer list over examples, inner list over writables - for( int i=0; i<5; i++ ){ + // Outer list over examples, inner list over writables + List> origByExample = new ArrayList<>(); + for (int i = 0; i < 5; i++) { origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i))); } - List batched = new ArrayList<>(); - for(List l : orig){ + for (List l : orig) { batched.add(Nd4j.concat(0, l.toArray(new INDArray[5]))); } - NDArrayRecordBatch batch = new NDArrayRecordBatch(batched); assertEquals(5, batch.size()); - for( int i=0; i<5; i++ ){ + for (int i = 0; i < 5; i++) { List act = batch.get(i); List unboxed = new ArrayList<>(); - for(Writable w : act){ - unboxed.add(((NDArrayWritable)w).get()); + for (Writable w : act) { + unboxed.add(((NDArrayWritable) w).get()); } List exp = origByExample.get(i); assertEquals(exp.size(), unboxed.size()); - for( int j=0; j> iter = batch.iterator(); int count = 0; - while(iter.hasNext()){ + while (iter.hasNext()) { List next = iter.next(); List unboxed = new ArrayList<>(); - for(Writable w : next){ - unboxed.add(((NDArrayWritable)w).get()); + for (Writable w : next) { + unboxed.add(((NDArrayWritable) w).get()); } List exp = origByExample.get(count++); assertEquals(exp.size(), unboxed.size()); - for( int j=0; j> ret = new ArrayList<>(numRows); - for(int i = 0; i < numRows; i++) { - ret.add(Arrays.asList(new NDArrayWritable(Nd4j.linspace(1,4,4).reshape(1, 4)))); + for (int i = 0; i < numRows; i++) { + ret.add(Arrays.asList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4)))); } - List fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret); - ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors,schema); + ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors, schema); INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch); - assertArrayEquals(new long[]{4,4},array.shape()); - - INDArray assertion = 
Nd4j.repeat(Nd4j.linspace(1,4,4),4).reshape(4,4); - assertEquals(assertion,array); + assertArrayEquals(new long[] { 4, 4 }, array.shape()); + INDArray assertion = Nd4j.repeat(Nd4j.linspace(1, 4, 4), 4).reshape(4, 4); + assertEquals(assertion, array); } @Test - public void testArrowColumnINDArray() { + @DisplayName("Test Arrow Column IND Array") + void testArrowColumnINDArray() { Schema.Builder schema = new Schema.Builder(); List single = new ArrayList<>(); int numCols = 2; - INDArray arr = Nd4j.linspace(1,4,4); - for(int i = 0; i < numCols; i++) { - schema.addColumnNDArray(String.valueOf(i),new long[]{1,4}); + INDArray arr = Nd4j.linspace(1, 4, 4); + for (int i = 0; i < numCols; i++) { + schema.addColumnNDArray(String.valueOf(i), new long[] { 1, 4 }); single.add(String.valueOf(i)); } - Schema buildSchema = schema.build(); List> list = new ArrayList<>(); List firstRow = new ArrayList<>(); - for(int i = 0 ; i < numCols; i++) { + for (int i = 0; i < numCols; i++) { firstRow.add(new NDArrayWritable(arr)); } - list.add(firstRow); - List fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list); - assertEquals(numCols,fieldVectors.size()); - assertEquals(1,fieldVectors.get(0).getValueCount()); + assertEquals(numCols, fieldVectors.size()); + assertEquals(1, fieldVectors.get(0).getValueCount()); assertFalse(fieldVectors.get(0).isNull(0)); - ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema); - assertEquals(1,arrowWritableRecordBatch.size()); - + assertEquals(1, arrowWritableRecordBatch.size()); Writable writable = arrowWritableRecordBatch.get(0).get(0); assertTrue(writable instanceof NDArrayWritable); NDArrayWritable ndArrayWritable = (NDArrayWritable) writable; - assertEquals(arr,ndArrayWritable.get()); - + assertEquals(arr, ndArrayWritable.get()); Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray); NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1; System.out.println(ndArrayWritablewritable1.get()); - } @Test - public void testArrowColumnString() { + @DisplayName("Test Arrow Column String") + void testArrowColumnString() { Schema.Builder schema = new Schema.Builder(); List single = new ArrayList<>(); - for(int i = 0; i < 2; i++) { + for (int i = 0; i < 2; i++) { schema.addColumnInteger(String.valueOf(i)); single.add(String.valueOf(i)); } - - List fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single); List> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build()); List> assertion = new ArrayList<>(); - assertion.add(Arrays.asList(new IntWritable(0),new IntWritable(1))); - assertEquals(assertion,records); - + assertion.add(Arrays.asList(new IntWritable(0), new IntWritable(1))); + assertEquals(assertion, records); List> batch = new ArrayList<>(); - for(int i = 0; i < 2; i++) { - batch.add(Arrays.asList(String.valueOf(i),String.valueOf(i))); + for (int i = 0; i < 2; i++) { + batch.add(Arrays.asList(String.valueOf(i), String.valueOf(i))); } - List fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch); List> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build()); - List> assertionBatch = new ArrayList<>(); - assertionBatch.add(Arrays.asList(new IntWritable(0),new IntWritable(0))); - assertionBatch.add(Arrays.asList(new IntWritable(1),new IntWritable(1))); - assertEquals(assertionBatch,batchRecords); - - + 
assertionBatch.add(Arrays.asList(new IntWritable(0), new IntWritable(0))); + assertionBatch.add(Arrays.asList(new IntWritable(1), new IntWritable(1))); + assertEquals(assertionBatch, batchRecords); } - - @Test - public void testArrowBatchSetTime() { + @DisplayName("Test Arrow Batch Set Time") + void testArrowBatchSetTime() { Schema.Builder schema = new Schema.Builder(); List single = new ArrayList<>(); - for(int i = 0; i < 2; i++) { - schema.addColumnTime(String.valueOf(i),TimeZone.getDefault()); + for (int i = 0; i < 2; i++) { + schema.addColumnTime(String.valueOf(i), TimeZone.getDefault()); single.add(String.valueOf(i)); } - - List> input = Arrays.asList( - Arrays.asList(new LongWritable(0),new LongWritable(1)), - Arrays.asList(new LongWritable(2),new LongWritable(3)) - ); - - List fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input); - ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build()); + List> input = Arrays.asList(Arrays.asList(new LongWritable(0), new LongWritable(1)), Arrays.asList(new LongWritable(2), new LongWritable(3))); + List fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input); + ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build()); List assertion = Arrays.asList(new LongWritable(4), new LongWritable(5)); - writableRecordBatch.set(1, Arrays.asList(new LongWritable(4),new LongWritable(5))); + writableRecordBatch.set(1, Arrays.asList(new LongWritable(4), new LongWritable(5))); List recordTest = writableRecordBatch.get(1); - assertEquals(assertion,recordTest); + assertEquals(assertion, recordTest); } @Test - public void testArrowBatchSet() { + @DisplayName("Test Arrow Batch Set") + void testArrowBatchSet() { Schema.Builder schema = new Schema.Builder(); List single = new ArrayList<>(); - for(int i = 0; i < 2; i++) { + for (int i = 0; i < 2; i++) { schema.addColumnInteger(String.valueOf(i)); single.add(String.valueOf(i)); } - - List> input = Arrays.asList( - Arrays.asList(new IntWritable(0),new IntWritable(1)), - Arrays.asList(new IntWritable(2),new IntWritable(3)) - ); - - List fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input); - ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build()); + List> input = Arrays.asList(Arrays.asList(new IntWritable(0), new IntWritable(1)), Arrays.asList(new IntWritable(2), new IntWritable(3))); + List fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input); + ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build()); List assertion = Arrays.asList(new IntWritable(4), new IntWritable(5)); - writableRecordBatch.set(1, Arrays.asList(new IntWritable(4),new IntWritable(5))); + writableRecordBatch.set(1, Arrays.asList(new IntWritable(4), new IntWritable(5))); List recordTest = writableRecordBatch.get(1); - assertEquals(assertion,recordTest); + assertEquals(assertion, recordTest); } @Test - public void testArrowColumnsStringTimeSeries() { + @DisplayName("Test Arrow Columns String Time Series") + void testArrowColumnsStringTimeSeries() { Schema.Builder schema = new Schema.Builder(); List>> entries = new ArrayList<>(); - for(int i = 0; i < 3; i++) { + for (int i = 0; i < 3; i++) { schema.addColumnInteger(String.valueOf(i)); } - - for(int i = 0; i < 5; i++) { + for (int i = 0; i < 5; i++) { List> arr = 
Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); entries.add(arr); } - List fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries); - assertEquals(3,fieldVectors.size()); - assertEquals(5,fieldVectors.get(0).getValueCount()); - - + assertEquals(3, fieldVectors.size()); + assertEquals(5, fieldVectors.get(0).getValueCount()); INDArray exp = Nd4j.create(5, 3); - for( int i = 0; i < 5; i++) { + for (int i = 0; i < 5; i++) { exp.getRow(i).assign(i); } - //Convert to ArrowWritableRecordBatch - note we can't do this in general with time series... + // Convert to ArrowWritableRecordBatch - note we can't do this in general with time series... ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build()); INDArray arr = ArrowConverter.toArray(wri); - assertArrayEquals(new long[] {5,3}, arr.shape()); - - + assertArrayEquals(new long[] { 5, 3 }, arr.shape()); assertEquals(exp, arr); } @Test - public void testConvertVector() { + @DisplayName("Test Convert Vector") + void testConvertVector() { Schema.Builder schema = new Schema.Builder(); List>> entries = new ArrayList<>(); - for(int i = 0; i < 3; i++) { + for (int i = 0; i < 3; i++) { schema.addColumnInteger(String.valueOf(i)); } - - for(int i = 0; i < 5; i++) { + for (int i = 0; i < 5; i++) { List> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); entries.add(arr); } - List fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries); - INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0),schema.build().getType(0)); - assertEquals(5,arr.length()); + INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0), schema.build().getType(0)); + assertEquals(5, arr.length()); } @Test - public void testCreateNDArray() throws Exception { + @DisplayName("Test Create ND Array") + void testCreateNDArray() throws Exception { val recordsToWrite = recordToWrite(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream); - - File f = testDir.newFolder(); - + ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream); + File f = testDir.toFile(); File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrorw"); FileOutputStream outputStream = new FileOutputStream(tmpFile); tmpFile.deleteOnExit(); - ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),outputStream); + ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), outputStream); outputStream.flush(); outputStream.close(); - Pair schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile); - assertEquals(recordsToWrite.getFirst(),schemaArrowWritableRecordBatchPair.getFirst()); - assertEquals(recordsToWrite.getRight(),schemaArrowWritableRecordBatchPair.getRight().toArrayList()); - + assertEquals(recordsToWrite.getFirst(), schemaArrowWritableRecordBatchPair.getFirst()); + assertEquals(recordsToWrite.getRight(), schemaArrowWritableRecordBatchPair.getRight().toArrayList()); byte[] arr = byteArrayOutputStream.toByteArray(); val read = ArrowConverter.readFromBytes(arr); - assertEquals(recordsToWrite,read); - - //send file - File tmp = tmpDataFile(recordsToWrite); + assertEquals(recordsToWrite, read); + // send file 
+ File tmp = tmpDataFile(recordsToWrite); ArrowRecordReader recordReader = new ArrowRecordReader(); - recordReader.initialize(new FileSplit(tmp)); - recordReader.next(); ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch(); INDArray arr2 = ArrowConverter.toArray(currentBatch); - assertEquals(2,arr2.rows()); - assertEquals(2,arr2.columns()); - } - - - @Test - public void testConvertToArrowVectors() { - INDArray matrix = Nd4j.linspace(1,4,4).reshape(2,2); - val vectors = ArrowConverter.convertToArrowVector(matrix,Arrays.asList("test","test2"), ColumnType.Double,bufferAllocator); - assertEquals(matrix.rows(),vectors.size()); - - INDArray vector = Nd4j.linspace(1,4,4); - val vectors2 = ArrowConverter.convertToArrowVector(vector,Arrays.asList("test"), ColumnType.Double,bufferAllocator); - assertEquals(1,vectors2.size()); - assertEquals(matrix.length(),vectors2.get(0).getValueCount()); - + assertEquals(2, arr2.rows()); + assertEquals(2, arr2.columns()); } @Test - public void testSchemaConversionBasic() { + @DisplayName("Test Convert To Arrow Vectors") + void testConvertToArrowVectors() { + INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2); + val vectors = ArrowConverter.convertToArrowVector(matrix, Arrays.asList("test", "test2"), ColumnType.Double, bufferAllocator); + assertEquals(matrix.rows(), vectors.size()); + INDArray vector = Nd4j.linspace(1, 4, 4); + val vectors2 = ArrowConverter.convertToArrowVector(vector, Arrays.asList("test"), ColumnType.Double, bufferAllocator); + assertEquals(1, vectors2.size()); + assertEquals(matrix.length(), vectors2.get(0).getValueCount()); + } + + @Test + @DisplayName("Test Schema Conversion Basic") + void testSchemaConversionBasic() { Schema.Builder schemaBuilder = new Schema.Builder(); - for(int i = 0; i < 2; i++) { + for (int i = 0; i < 2; i++) { schemaBuilder.addColumnDouble("test-" + i); schemaBuilder.addColumnInteger("testi-" + i); schemaBuilder.addColumnLong("testl-" + i); schemaBuilder.addColumnFloat("testf-" + i); } - - Schema schema = schemaBuilder.build(); val schema2 = ArrowConverter.toArrowSchema(schema); - assertEquals(8,schema2.getFields().size()); + assertEquals(8, schema2.getFields().size()); val convertedSchema = ArrowConverter.toDatavecSchema(schema2); - assertEquals(schema,convertedSchema); + assertEquals(schema, convertedSchema); } @Test - public void testReadSchemaAndRecordsFromByteArray() throws Exception { + @DisplayName("Test Read Schema And Records From Byte Array") + void testReadSchemaAndRecordsFromByteArray() throws Exception { BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - int valueCount = 3; List fields = new ArrayList<>(); - fields.add(ArrowConverter.field("field1",new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))); + fields.add(ArrowConverter.field("field1", new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))); fields.add(ArrowConverter.intField("field2")); - List fieldVectors = new ArrayList<>(); - fieldVectors.add(ArrowConverter.vectorFor(allocator,"field1",new float[] {1,2,3})); - fieldVectors.add(ArrowConverter.vectorFor(allocator,"field2",new int[] {1,2,3})); - - + fieldVectors.add(ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 })); + fieldVectors.add(ArrowConverter.vectorFor(allocator, "field2", new int[] { 1, 2, 3 })); org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields); - VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount); VectorUnloader vectorUnloader = 
new VectorUnloader(schemaRoot1); vectorUnloader.getRecordBatch(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) { + try (ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1, null, newChannel(byteArrayOutputStream))) { arrowFileWriter.writeBatch(); } catch (IOException e) { - log.error("",e); + log.error("", e); } - byte[] arr = byteArrayOutputStream.toByteArray(); val arr2 = ArrowConverter.readFromBytes(arr); - assertEquals(2,arr2.getFirst().numColumns()); - assertEquals(3,arr2.getRight().size()); - - val arrowCols = ArrowConverter.toArrowColumns(allocator,arr2.getFirst(),arr2.getRight()); - assertEquals(2,arrowCols.size()); - assertEquals(valueCount,arrowCols.get(0).getValueCount()); + assertEquals(2, arr2.getFirst().numColumns()); + assertEquals(3, arr2.getRight().size()); + val arrowCols = ArrowConverter.toArrowColumns(allocator, arr2.getFirst(), arr2.getRight()); + assertEquals(2, arrowCols.size()); + assertEquals(valueCount, arrowCols.get(0).getValueCount()); } - @Test - public void testVectorForEdgeCases() { + @DisplayName("Test Vector For Edge Cases") + void testVectorForEdgeCases() { BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{Float.MIN_VALUE,Float.MAX_VALUE}); - assertEquals(Float.MIN_VALUE,vector.get(0),1e-2); - assertEquals(Float.MAX_VALUE,vector.get(1),1e-2); - - val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{Integer.MIN_VALUE,Integer.MAX_VALUE}); - assertEquals(Integer.MIN_VALUE,vectorInt.get(0),1e-2); - assertEquals(Integer.MAX_VALUE,vectorInt.get(1),1e-2); - + val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { Float.MIN_VALUE, Float.MAX_VALUE }); + assertEquals(Float.MIN_VALUE, vector.get(0), 1e-2); + assertEquals(Float.MAX_VALUE, vector.get(1), 1e-2); + val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { Integer.MIN_VALUE, Integer.MAX_VALUE }); + assertEquals(Integer.MIN_VALUE, vectorInt.get(0), 1e-2); + assertEquals(Integer.MAX_VALUE, vectorInt.get(1), 1e-2); } @Test - public void testVectorFor() { + @DisplayName("Test Vector For") + void testVectorFor() { BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - - val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{1,2,3}); - assertEquals(3,vector.getValueCount()); - assertEquals(1,vector.get(0),1e-2); - assertEquals(2,vector.get(1),1e-2); - assertEquals(3,vector.get(2),1e-2); - - val vectorLong = ArrowConverter.vectorFor(allocator,"field1",new long[]{1,2,3}); - assertEquals(3,vectorLong.getValueCount()); - assertEquals(1,vectorLong.get(0),1e-2); - assertEquals(2,vectorLong.get(1),1e-2); - assertEquals(3,vectorLong.get(2),1e-2); - - - val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{1,2,3}); - assertEquals(3,vectorInt.getValueCount()); - assertEquals(1,vectorInt.get(0),1e-2); - assertEquals(2,vectorInt.get(1),1e-2); - assertEquals(3,vectorInt.get(2),1e-2); - - val vectorDouble = ArrowConverter.vectorFor(allocator,"field1",new double[]{1,2,3}); - assertEquals(3,vectorDouble.getValueCount()); - assertEquals(1,vectorDouble.get(0),1e-2); - assertEquals(2,vectorDouble.get(1),1e-2); - assertEquals(3,vectorDouble.get(2),1e-2); - - - val vectorBool = ArrowConverter.vectorFor(allocator,"field1",new boolean[]{true,true,false}); - assertEquals(3,vectorBool.getValueCount()); - 
assertEquals(1,vectorBool.get(0),1e-2); - assertEquals(1,vectorBool.get(1),1e-2); - assertEquals(0,vectorBool.get(2),1e-2); + val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 }); + assertEquals(3, vector.getValueCount()); + assertEquals(1, vector.get(0), 1e-2); + assertEquals(2, vector.get(1), 1e-2); + assertEquals(3, vector.get(2), 1e-2); + val vectorLong = ArrowConverter.vectorFor(allocator, "field1", new long[] { 1, 2, 3 }); + assertEquals(3, vectorLong.getValueCount()); + assertEquals(1, vectorLong.get(0), 1e-2); + assertEquals(2, vectorLong.get(1), 1e-2); + assertEquals(3, vectorLong.get(2), 1e-2); + val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { 1, 2, 3 }); + assertEquals(3, vectorInt.getValueCount()); + assertEquals(1, vectorInt.get(0), 1e-2); + assertEquals(2, vectorInt.get(1), 1e-2); + assertEquals(3, vectorInt.get(2), 1e-2); + val vectorDouble = ArrowConverter.vectorFor(allocator, "field1", new double[] { 1, 2, 3 }); + assertEquals(3, vectorDouble.getValueCount()); + assertEquals(1, vectorDouble.get(0), 1e-2); + assertEquals(2, vectorDouble.get(1), 1e-2); + assertEquals(3, vectorDouble.get(2), 1e-2); + val vectorBool = ArrowConverter.vectorFor(allocator, "field1", new boolean[] { true, true, false }); + assertEquals(3, vectorBool.getValueCount()); + assertEquals(1, vectorBool.get(0), 1e-2); + assertEquals(1, vectorBool.get(1), 1e-2); + assertEquals(0, vectorBool.get(2), 1e-2); } @Test - public void testRecordReaderAndWriteFile() throws Exception { + @DisplayName("Test Record Reader And Write File") + void testRecordReaderAndWriteFile() throws Exception { val recordsToWrite = recordToWrite(); ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream); + ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream); byte[] arr = byteArrayOutputStream.toByteArray(); val read = ArrowConverter.readFromBytes(arr); - assertEquals(recordsToWrite,read); - - //send file - File tmp = tmpDataFile(recordsToWrite); + assertEquals(recordsToWrite, read); + // send file + File tmp = tmpDataFile(recordsToWrite); RecordReader recordReader = new ArrowRecordReader(); - recordReader.initialize(new FileSplit(tmp)); - List record = recordReader.next(); - assertEquals(2,record.size()); - + assertEquals(2, record.size()); } @Test - public void testRecordReaderMetaDataList() throws Exception { + @DisplayName("Test Record Reader Meta Data List") + void testRecordReaderMetaDataList() throws Exception { val recordsToWrite = recordToWrite(); - //send file - File tmp = tmpDataFile(recordsToWrite); + // send file + File tmp = tmpDataFile(recordsToWrite); RecordReader recordReader = new ArrowRecordReader(); - RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class); + RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class); recordReader.loadFromMetaData(Arrays.asList(recordMetaDataIndex)); - Record record = recordReader.nextRecord(); - assertEquals(2,record.getRecord().size()); - + assertEquals(2, record.getRecord().size()); } @Test - public void testDates() { + @DisplayName("Test Dates") + void testDates() { Date now = new Date(); BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); - TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, 
"col1", new Date[]{now}); - assertEquals(now.getTime(),timeStampMilliVector.get(0)); + TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[] { now }); + assertEquals(now.getTime(), timeStampMilliVector.get(0)); } - @Test - public void testRecordReaderMetaData() throws Exception { + @DisplayName("Test Record Reader Meta Data") + void testRecordReaderMetaData() throws Exception { val recordsToWrite = recordToWrite(); - //send file - File tmp = tmpDataFile(recordsToWrite); + // send file + File tmp = tmpDataFile(recordsToWrite); RecordReader recordReader = new ArrowRecordReader(); - RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class); + RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class); recordReader.loadFromMetaData(recordMetaDataIndex); - Record record = recordReader.nextRecord(); - assertEquals(2,record.getRecord().size()); + assertEquals(2, record.getRecord().size()); } - private File tmpDataFile(Pair>> recordsToWrite) throws IOException { - - File f = testDir.newFolder(); - - //send file - File tmp = new File(f,"tmp-file-" + UUID.randomUUID().toString()); + private File tmpDataFile(Pair>> recordsToWrite) throws IOException { + File f = testDir.toFile(); + // send file + File tmp = new File(f, "tmp-file-" + UUID.randomUUID().toString()); tmp.mkdirs(); - File tmpFile = new File(tmp,"data.arrow"); + File tmpFile = new File(tmp, "data.arrow"); tmpFile.deleteOnExit(); FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile); - ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),bufferedOutputStream); + ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), bufferedOutputStream); bufferedOutputStream.flush(); bufferedOutputStream.close(); return tmp; } - private Pair>> recordToWrite() { + private Pair>> recordToWrite() { List> records = new ArrayList<>(); - records.add(Arrays.asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); - records.add(Arrays.asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); + records.add(Arrays.asList(new DoubleWritable(0.0), new DoubleWritable(0.0))); + records.add(Arrays.asList(new DoubleWritable(0.0), new DoubleWritable(0.0))); Schema.Builder schemaBuilder = new Schema.Builder(); - for(int i = 0; i < 2; i++) { + for (int i = 0; i < 2; i++) { schemaBuilder.addColumnFloat("col-" + i); } - - return Pair.of(schemaBuilder.build(),records); + return Pair.of(schemaBuilder.build(), records); } - - - - } diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java index 42abee0b3..5eec05c93 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.arrow; import lombok.val; @@ -34,132 +33,98 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordWriter; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.primitives.Triple; - import 
java.io.File; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class RecordMapperTest extends BaseND4JTest { +@DisplayName("Record Mapper Test") +class RecordMapperTest extends BaseND4JTest { @Test - public void testMultiWrite() throws Exception { + @DisplayName("Test Multi Write") + void testMultiWrite() throws Exception { val recordsPair = records(); - Path p = Files.createTempFile("arrowwritetest", ".arrow"); - FileUtils.write(p.toFile(),recordsPair.getFirst()); + FileUtils.write(p.toFile(), recordsPair.getFirst()); p.toFile().deleteOnExit(); - int numReaders = 2; RecordReader[] readers = new RecordReader[numReaders]; InputSplit[] splits = new InputSplit[numReaders]; - for(int i = 0; i < readers.length; i++) { + for (int i = 0; i < readers.length; i++) { FileSplit split = new FileSplit(p.toFile()); ArrowRecordReader arrowRecordReader = new ArrowRecordReader(); readers[i] = arrowRecordReader; splits[i] = split; } - ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle()); FileSplit split = new FileSplit(p.toFile()); - arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner()); + arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner()); arrowRecordWriter.writeBatch(recordsPair.getRight()); - - CSVRecordWriter csvRecordWriter = new CSVRecordWriter(); Path p2 = Files.createTempFile("arrowwritetest", ".csv"); - FileUtils.write(p2.toFile(),recordsPair.getFirst()); + FileUtils.write(p2.toFile(), recordsPair.getFirst()); p.toFile().deleteOnExit(); FileSplit outputCsv = new FileSplit(p2.toFile()); - - RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split) - .outputUrl(outputCsv) - .partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers) - .splitPerReader(splits) - .recordWriter(csvRecordWriter) - .build(); + RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers).splitPerReader(splits).recordWriter(csvRecordWriter).build(); mapper.copy(); - - } - @Test - public void testCopyFromArrowToCsv() throws Exception { + @DisplayName("Test Copy From Arrow To Csv") + void testCopyFromArrowToCsv() throws Exception { val recordsPair = records(); - Path p = Files.createTempFile("arrowwritetest", ".arrow"); - FileUtils.write(p.toFile(),recordsPair.getFirst()); + FileUtils.write(p.toFile(), recordsPair.getFirst()); p.toFile().deleteOnExit(); - ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle()); FileSplit split = new FileSplit(p.toFile()); - arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner()); + arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner()); arrowRecordWriter.writeBatch(recordsPair.getRight()); - - ArrowRecordReader arrowRecordReader = new ArrowRecordReader(); arrowRecordReader.initialize(split); - - CSVRecordWriter csvRecordWriter = new CSVRecordWriter(); Path p2 = Files.createTempFile("arrowwritetest", ".csv"); - FileUtils.write(p2.toFile(),recordsPair.getFirst()); + FileUtils.write(p2.toFile(), recordsPair.getFirst()); p.toFile().deleteOnExit(); FileSplit outputCsv = new FileSplit(p2.toFile()); - - RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split) - 
.outputUrl(outputCsv) - .partitioner(new NumberOfRecordsPartitioner()) - .recordReader(arrowRecordReader).recordWriter(csvRecordWriter) - .build(); + RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).recordReader(arrowRecordReader).recordWriter(csvRecordWriter).build(); mapper.copy(); - CSVRecordReader recordReader = new CSVRecordReader(); recordReader.initialize(outputCsv); - - List> loadedCSvRecords = recordReader.next(10); - assertEquals(10,loadedCSvRecords.size()); + assertEquals(10, loadedCSvRecords.size()); } - @Test - public void testCopyFromCsvToArrow() throws Exception { + @DisplayName("Test Copy From Csv To Arrow") + void testCopyFromCsvToArrow() throws Exception { val recordsPair = records(); - Path p = Files.createTempFile("csvwritetest", ".csv"); - FileUtils.write(p.toFile(),recordsPair.getFirst()); + FileUtils.write(p.toFile(), recordsPair.getFirst()); p.toFile().deleteOnExit(); - - CSVRecordReader recordReader = new CSVRecordReader(); FileSplit fileSplit = new FileSplit(p.toFile()); - ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle()); - File outputFile = Files.createTempFile("outputarrow","arrow").toFile(); + File outputFile = Files.createTempFile("outputarrow", "arrow").toFile(); FileSplit outputFileSplit = new FileSplit(outputFile); - RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit) - .outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner()) - .recordReader(recordReader).recordWriter(arrowRecordWriter) - .build(); + RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit).outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner()).recordReader(recordReader).recordWriter(arrowRecordWriter).build(); mapper.copy(); - ArrowRecordReader arrowRecordReader = new ArrowRecordReader(); arrowRecordReader.initialize(outputFileSplit); List> next = arrowRecordReader.next(10); System.out.println(next); - assertEquals(10,next.size()); - + assertEquals(10, next.size()); } - private Triple>> records() { + private Triple>> records() { List> list = new ArrayList<>(); StringBuilder sb = new StringBuilder(); int numColumns = 3; @@ -176,15 +141,10 @@ public class RecordMapperTest extends BaseND4JTest { } list.add(temp); } - - Schema.Builder schemaBuilder = new Schema.Builder(); - for(int i = 0; i < numColumns; i++) { + for (int i = 0; i < numColumns; i++) { schemaBuilder.addColumnInteger(String.valueOf(i)); } - - return Triple.of(sb.toString(),schemaBuilder.build(),list); + return Triple.of(sb.toString(), schemaBuilder.build(), list); } - - } diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java index f5e62341c..5cdc2bf40 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.image; import org.apache.commons.io.FileUtils; @@ -25,33 +24,32 @@ import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.split.FileSplit; import org.datavec.image.recordreader.ImageRecordReader; import org.junit.Rule; -import 
org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; - import java.io.File; import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +@DisplayName("Label Generator Test") +class LabelGeneratorTest { -public class LabelGeneratorTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Test - public void testParentPathLabelGenerator() throws Exception { - //https://github.com/deeplearning4j/DataVec/issues/273 + @DisplayName("Test Parent Path Label Generator") + void testParentPathLabelGenerator(@TempDir Path testDir) throws Exception { File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile(); - - for(String dirPrefix : new String[]{"m.", "m"}) { - File f = testDir.newFolder(); - + for (String dirPrefix : new String[] { "m.", "m" }) { + File f = testDir.toFile(); int numDirs = 3; int filesPerDir = 4; - for (int i = 0; i < numDirs; i++) { File currentLabelDir = new File(f, dirPrefix + i); currentLabelDir.mkdirs(); @@ -61,14 +59,11 @@ public class LabelGeneratorTest { assertTrue(f3.exists()); } } - ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator()); rr.initialize(new FileSplit(f)); - List labelsAct = rr.getLabels(); List labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2"); assertEquals(labelsExp, labelsAct); - int expCount = numDirs * filesPerDir; int actCount = 0; while (rr.hasNext()) { diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/FileBatchRecordReaderTest.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/FileBatchRecordReaderTest.java index d54b32b0e..5676dd020 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/FileBatchRecordReaderTest.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/FileBatchRecordReaderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.image.recordreader; import org.apache.commons.io.FileUtils; @@ -29,60 +28,55 @@ import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Writable; import org.datavec.image.loader.NativeImageLoader; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.loader.FileBatch; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.io.ClassPathResource; - import java.io.File; import java.util.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; +@DisplayName("File Batch Record Reader Test") +class FileBatchRecordReaderTest { -public class FileBatchRecordReaderTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); 
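The TemporaryFolder rule removed here is replaced by JUnit 5's @TempDir, which can be injected into a field or a test-method parameter. Worth noting: TemporaryFolder.newFolder() returned a fresh subdirectory on every call, whereas a single @TempDir Path refers to one directory, so tests that need several distinct folders must create subdirectories themselves. A minimal sketch under those assumptions (class, method, and directory names are illustrative, not part of the patch):

import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirSketch {

    // Field injection: JUnit creates one temporary directory per test instance.
    @TempDir
    Path testDir;

    // Parameter injection: a temporary directory scoped to this test method.
    @Test
    void writesFixturesIntoSubfolders(@TempDir Path dir) throws Exception {
        // Roughly equivalent to calling TemporaryFolder.newFolder() twice in JUnit 4:
        File first = Files.createDirectories(dir.resolve("case-0")).toFile();
        File second = Files.createDirectories(dir.resolve("case-1")).toFile();
        assertTrue(first.exists() && second.exists());
        // ... write per-case fixtures under first and second ...
    }
}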
+ @TempDir + public Path testDir; @Test - public void testCsv() throws Exception { - File extractedSourceDir = testDir.newFolder(); + @DisplayName("Test Csv") + void testCsv(@TempDir Path testDir,@TempDir Path baseDirPath) throws Exception { + File extractedSourceDir = testDir.toFile(); new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir); - File baseDir = testDir.newFolder(); - - + File baseDir = baseDirPath.toFile(); List c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true)); assertEquals(6, c.size()); - Collections.sort(c, new Comparator() { + @Override public int compare(File o1, File o2) { return o1.getPath().compareTo(o2.getPath()); } }); - - FileBatch fb = FileBatch.forFiles(c); File saveFile = new File(baseDir, "saved.zip"); fb.writeAsZip(saveFile); fb = FileBatch.readFromZip(saveFile); - PathLabelGenerator labelMaker = new ParentPathLabelGenerator(); ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker); rr.setLabels(Arrays.asList("class0", "class1")); FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb); - - NativeImageLoader il = new NativeImageLoader(32, 32, 1); - for( int test=0; test<3; test++) { + for (int test = 0; test < 3; test++) { for (int i = 0; i < 6; i++) { assertTrue(fbrr.hasNext()); List next = fbrr.next(); assertEquals(2, next.size()); - INDArray exp; - switch (i){ + switch(i) { case 0: exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg")); break; @@ -105,8 +99,7 @@ public class FileBatchRecordReaderTest { throw new RuntimeException(); } Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1)); - - assertEquals(((NDArrayWritable)next.get(0)).get(), exp); + assertEquals(((NDArrayWritable) next.get(0)).get(), exp); assertEquals(expLabel, next.get(1)); } assertFalse(fbrr.hasNext()); @@ -114,5 +107,4 @@ public class FileBatchRecordReaderTest { fbrr.reset(); } } - } diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java index 2d9bab6ea..60d354d9e 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java @@ -17,106 +17,70 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.image.transform; import org.datavec.image.data.ImageWritable; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.io.IOException; import java.util.List; import java.util.Random; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +@DisplayName("Json Yaml Test") +class JsonYamlTest { -public class JsonYamlTest { @Test - public void testJsonYamlImageTransformProcess() throws IOException { + @DisplayName("Test Json Yaml Image Transform Process") + void testJsonYamlImageTransformProcess() throws IOException { int seed = 12345; Random random = new Random(seed); - - //from org.bytedeco.javacpp.opencv_imgproc + // from org.bytedeco.javacpp.opencv_imgproc int COLOR_BGR2Luv = 50; int CV_BGR2GRAY = 6; - - - ImageTransformProcess 
itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv) - .cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0) - .resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3) - .warpImageTransform((float) 0.5) - - // Note : since randomCropTransform use random value - // the results from each case(json, yaml, ImageTransformProcess) - // can be different - // don't use the below line - // if you uncomment it, you will get fail from below assertions - // .randomCropTransform(seed, 50, 50) - - // Note : you will get "java.lang.NoClassDefFoundError: Could not initialize class org.bytedeco.javacpp.avutil" - // it needs to add the below dependency - // - // org.bytedeco - // ffmpeg-platform - // - // FFmpeg has license issues, be careful to use it - //.filterImageTransform("noise=alls=20:allf=t+u,format=rgba", 100, 100, 4) - - .build(); - + ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv).cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0).resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3).warpImageTransform((float) 0.5).build(); String asJson = itp.toJson(); String asYaml = itp.toYaml(); - -// System.out.println(asJson); -// System.out.println("\n\n\n"); -// System.out.println(asYaml); - + // System.out.println(asJson); + // System.out.println("\n\n\n"); + // System.out.println(asYaml); ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3); ImageWritable imgJson = new ImageWritable(img.getFrame().clone()); ImageWritable imgYaml = new ImageWritable(img.getFrame().clone()); ImageWritable imgAll = new ImageWritable(img.getFrame().clone()); - ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson); ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml); - List transformList = itp.getTransformList(); List transformListJson = itpFromJson.getTransformList(); List transformListYaml = itpFromYaml.getTransformList(); - for (int i = 0; i < transformList.size(); i++) { ImageTransform it = transformList.get(i); ImageTransform itJson = transformListJson.get(i); ImageTransform itYaml = transformListYaml.get(i); - System.out.println(i + "\t" + it); - img = it.transform(img); imgJson = itJson.transform(imgJson); imgYaml = itYaml.transform(imgYaml); - if (it instanceof RandomCropTransform) { assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight); assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth); - assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight); assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth); } else if (it instanceof FilterImageTransform) { assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight); assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth); assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels); - assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight); assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth); assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels); } else { assertEquals(img, imgJson); - assertEquals(img, imgYaml); } } - imgAll = itp.execute(imgAll); - assertEquals(imgAll, img); } } diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/ResizeImageTransformTest.java 
b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/ResizeImageTransformTest.java index 33dae8c19..47ce04ec3 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/ResizeImageTransformTest.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/transform/ResizeImageTransformTest.java @@ -17,56 +17,50 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.image.transform; import org.bytedeco.javacv.Frame; import org.datavec.image.data.ImageWritable; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class ResizeImageTransformTest { - @Before - public void setUp() throws Exception { +@DisplayName("Resize Image Transform Test") +class ResizeImageTransformTest { + @BeforeEach + void setUp() throws Exception { } @Test - public void testResizeUpscale1() throws Exception { + @DisplayName("Test Resize Upscale 1") + void testResizeUpscale1() throws Exception { ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3); - ResizeImageTransform transform = new ResizeImageTransform(200, 200); - ImageWritable dstImg = transform.transform(srcImg); - Frame f = dstImg.getFrame(); assertEquals(f.imageWidth, 200); assertEquals(f.imageHeight, 200); - - float[] coordinates = {100, 200}; + float[] coordinates = { 100, 200 }; float[] transformed = transform.query(coordinates); assertEquals(200f * 100 / 32, transformed[0], 0); assertEquals(200f * 200 / 32, transformed[1], 0); } @Test - public void testResizeDownscale() throws Exception { + @DisplayName("Test Resize Downscale") + void testResizeDownscale() throws Exception { ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3); - ResizeImageTransform transform = new ResizeImageTransform(200, 200); - ImageWritable dstImg = transform.transform(srcImg); - Frame f = dstImg.getFrame(); assertEquals(f.imageWidth, 200); assertEquals(f.imageHeight, 200); - - float[] coordinates = {300, 400}; + float[] coordinates = { 300, 400 }; float[] transformed = transform.query(coordinates); assertEquals(200f * 300 / 443, transformed[0], 0); assertEquals(200f * 400 / 571, transformed[1], 0); } - } diff --git a/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordReaderTest.java b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordReaderTest.java index 12e0b97c8..97de530c9 100644 --- a/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordReaderTest.java +++ b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordReaderTest.java @@ -17,37 +17,34 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.poi.excel; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; - import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; 
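ResizeImageTransformTest above shows the mechanical annotation renames applied across these files: `org.junit.Before` becomes `org.junit.jupiter.api.BeforeEach`, assertions move from `org.junit.Assert` to `org.junit.jupiter.api.Assertions`, and each class and test method gains a `@DisplayName`. A short sketch of the resulting shape, with hypothetical names:

```java
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

// Illustrative sketch of the JUnit 5 lifecycle and naming annotations; not part of the patch.
@DisplayName("Lifecycle Sketch")
class LifecycleSketchTest {

    private int targetWidth;

    @BeforeEach // formerly org.junit.Before
    void setUp() {
        targetWidth = 200;
    }

    @Test
    @DisplayName("Test Target Width")
    void testTargetWidth() {
        assertEquals(200, targetWidth);
    }
}
```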
+import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class ExcelRecordReaderTest { +@DisplayName("Excel Record Reader Test") +class ExcelRecordReaderTest { @Test - public void testSimple() throws Exception { + @DisplayName("Test Simple") + void testSimple() throws Exception { RecordReader excel = new ExcelRecordReader(); excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile())); assertTrue(excel.hasNext()); List next = excel.next(); - assertEquals(3,next.size()); - + assertEquals(3, next.size()); RecordReader headerReader = new ExcelRecordReader(1); headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile())); assertTrue(excel.hasNext()); List next2 = excel.next(); - assertEquals(3,next2.size()); - - + assertEquals(3, next2.size()); } - } diff --git a/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java index ae132be87..3d03f764e 100644 --- a/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java +++ b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.poi.excel; import lombok.val; @@ -27,43 +26,44 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.primitives.Triple; - import java.io.File; import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; +@DisplayName("Excel Record Writer Test") +class ExcelRecordWriterTest { -public class ExcelRecordWriterTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Test - public void testWriter() throws Exception { + @DisplayName("Test Writer") + void testWriter() throws Exception { ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter(); val records = records(); - File tmpDir = testDir.newFolder(); - File outputFile = new File(tmpDir,"testexcel.xlsx"); + File tmpDir = testDir.toFile(); + File outputFile = new File(tmpDir, "testexcel.xlsx"); outputFile.deleteOnExit(); FileSplit fileSplit = new FileSplit(outputFile); - excelRecordWriter.initialize(fileSplit,new NumberOfRecordsPartitioner()); + excelRecordWriter.initialize(fileSplit, new NumberOfRecordsPartitioner()); excelRecordWriter.writeBatch(records.getRight()); excelRecordWriter.close(); File parentFile = outputFile.getParentFile(); - assertEquals(1,parentFile.list().length); - + assertEquals(1, parentFile.list().length); ExcelRecordReader excelRecordReader = new ExcelRecordReader(); excelRecordReader.initialize(fileSplit); List> next = excelRecordReader.next(10); - assertEquals(10,next.size()); - + assertEquals(10, next.size()); } - private Triple>> records() { + private Triple>> records() { List> list = new ArrayList<>(); 
StringBuilder sb = new StringBuilder(); int numColumns = 3; @@ -80,13 +80,10 @@ public class ExcelRecordWriterTest { } list.add(temp); } - - Schema.Builder schemaBuilder = new Schema.Builder(); - for(int i = 0; i < numColumns; i++) { + for (int i = 0; i < numColumns; i++) { schemaBuilder.addColumnInteger(String.valueOf(i)); } - - return Triple.of(sb.toString(),schemaBuilder.build(),list); + return Triple.of(sb.toString(), schemaBuilder.build(), list); } } diff --git a/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/impl/JDBCRecordReaderTest.java b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/impl/JDBCRecordReaderTest.java index ebd832dbc..fb7daa5e9 100644 --- a/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/impl/JDBCRecordReaderTest.java +++ b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/impl/JDBCRecordReaderTest.java @@ -17,14 +17,12 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.api.records.reader.impl; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.File; import java.net.URI; import java.sql.Connection; @@ -49,53 +47,57 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; -import org.junit.After; -import org.junit.Before; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; -public class JDBCRecordReaderTest { +@DisplayName("Jdbc Record Reader Test") +class JDBCRecordReaderTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; Connection conn; + EmbeddedDataSource dataSource; private final String dbName = "datavecTests"; + private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver"; - @Before - public void setUp() throws Exception { - File f = testDir.newFolder(); + @BeforeEach + void setUp() throws Exception { + File f = testDir.toFile(); System.setProperty("derby.system.home", f.getAbsolutePath()); - dataSource = new EmbeddedDataSource(); dataSource.setDatabaseName(dbName); dataSource.setCreateDatabase("create"); conn = dataSource.getConnection(); - TestDb.dropTables(conn); TestDb.buildCoffeeTable(conn); } - @After - public void tearDown() throws Exception { + @AfterEach + void tearDown() throws Exception { DbUtils.closeQuietly(conn); } @Test - public void testSimpleIter() throws Exception { + @DisplayName("Test Simple Iter") + void testSimpleIter() throws Exception { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { List> records = new ArrayList<>(); while (reader.hasNext()) { List values = reader.next(); records.add(values); } 
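Further down in JDBCRecordReaderTest, the JUnit 4 `@Test(expected = SomeException.class)` attribute is replaced by an explicit `assertThrows(...)` around the failing call, since the JUnit 5 `@Test` annotation takes no such argument. A standalone sketch of that idiom; the exception and message here are placeholders for illustration, not taken from JDBCRecordReader:

```java
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

// Illustrative sketch of the assertThrows conversion; names and messages are hypothetical.
@DisplayName("Assert Throws Sketch")
class AssertThrowsSketchTest {

    @Test
    @DisplayName("Fails When Not Initialized")
    void failsWhenNotInitialized() {
        IllegalStateException e = assertThrows(IllegalStateException.class, () -> {
            throw new IllegalStateException("Cannot iterate: no DataSource configured");
        });
        // Unlike @Test(expected = ...), the thrown exception is returned and can be inspected.
        assertEquals("Cannot iterate: no DataSource configured", e.getMessage());
    }
}
```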
- assertFalse(records.isEmpty()); - List first = records.get(0); assertEquals(new Text("Bolivian Dark"), first.get(0)); assertEquals(new Text("14-001"), first.get(1)); @@ -104,39 +106,43 @@ public class JDBCRecordReaderTest { } @Test - public void testSimpleWithListener() throws Exception { + @DisplayName("Test Simple With Listener") + void testSimpleWithListener() throws Exception { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { RecordListener recordListener = new LogRecordListener(); reader.setListeners(recordListener); reader.next(); - assertTrue(recordListener.invoked()); } } @Test - public void testReset() throws Exception { + @DisplayName("Test Reset") + void testReset() throws Exception { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { List> records = new ArrayList<>(); records.add(reader.next()); reader.reset(); records.add(reader.next()); - assertEquals(2, records.size()); assertEquals(new Text("Bolivian Dark"), records.get(0).get(0)); assertEquals(new Text("Bolivian Dark"), records.get(1).get(0)); } } - @Test(expected = IllegalStateException.class) - public void testLackingDataSourceShouldFail() throws Exception { - try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { - reader.initialize(null); - } + @Test + @DisplayName("Test Lacking Data Source Should Fail") + void testLackingDataSourceShouldFail() { + assertThrows(IllegalStateException.class, () -> { + try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { + reader.initialize(null); + } + }); } @Test - public void testConfigurationDataSourceInitialization() throws Exception { + @DisplayName("Test Configuration Data Source Initialization") + void testConfigurationDataSourceInitialization() throws Exception { try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { Configuration conf = new Configuration(); conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true"); @@ -146,28 +152,33 @@ public class JDBCRecordReaderTest { } } - @Test(expected = IllegalArgumentException.class) - public void testInitConfigurationMissingParametersShouldFail() throws Exception { - try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { - Configuration conf = new Configuration(); - conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway"); - reader.initialize(conf, null); - } - } - - @Test(expected = UnsupportedOperationException.class) - public void testRecordDataInputStreamShouldFail() throws Exception { - try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { - reader.record(null, null); - } + @Test + @DisplayName("Test Init Configuration Missing Parameters Should Fail") + void testInitConfigurationMissingParametersShouldFail() { + assertThrows(IllegalArgumentException.class, () -> { + try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) { + Configuration conf = new Configuration(); + conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway"); + reader.initialize(conf, null); + } + }); } @Test - public void testLoadFromMetaData() throws Exception { - try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { - RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), - "SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass()); + @DisplayName("Test Record Data Input Stream Should Fail") + void testRecordDataInputStreamShouldFail() { + 
assertThrows(UnsupportedOperationException.class, () -> { + try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { + reader.record(null, null); + } + }); + } + @Test + @DisplayName("Test Load From Meta Data") + void testLoadFromMetaData() throws Exception { + try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { + RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), "SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass()); Record res = reader.loadFromMetaData(rmd); assertNotNull(res); assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0)); @@ -177,7 +188,8 @@ public class JDBCRecordReaderTest { } @Test - public void testNextRecord() throws Exception { + @DisplayName("Test Next Record") + void testNextRecord() throws Exception { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { Record r = reader.nextRecord(); List fields = r.getRecord(); @@ -193,7 +205,8 @@ public class JDBCRecordReaderTest { } @Test - public void testNextRecordAndRecover() throws Exception { + @DisplayName("Test Next Record And Recover") + void testNextRecordAndRecover() throws Exception { try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { Record r = reader.nextRecord(); List fields = r.getRecord(); @@ -208,69 +221,91 @@ public class JDBCRecordReaderTest { } // Resetting the record reader when initialized as forward only should fail - @Test(expected = RuntimeException.class) - public void testResetForwardOnlyShouldFail() throws Exception { - try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) { - Configuration conf = new Configuration(); - conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY); - reader.initialize(conf, null); - reader.next(); - reader.reset(); - } + @Test + @DisplayName("Test Reset Forward Only Should Fail") + void testResetForwardOnlyShouldFail() { + assertThrows(RuntimeException.class, () -> { + try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) { + Configuration conf = new Configuration(); + conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY); + reader.initialize(conf, null); + reader.next(); + reader.reset(); + } + }); } @Test - public void testReadAllTypes() throws Exception { + @DisplayName("Test Read All Types") + void testReadAllTypes() throws Exception { TestDb.buildAllTypesTable(conn); try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM AllTypes", dataSource)) { reader.initialize(null); List item = reader.next(); - assertEquals(item.size(), 15); - assertEquals(BooleanWritable.class, item.get(0).getClass()); // boolean to boolean - assertEquals(Text.class, item.get(1).getClass()); // date to text - assertEquals(Text.class, item.get(2).getClass()); // time to text - assertEquals(Text.class, item.get(3).getClass()); // timestamp to text - assertEquals(Text.class, item.get(4).getClass()); // char to text - assertEquals(Text.class, item.get(5).getClass()); // long varchar to text - assertEquals(Text.class, item.get(6).getClass()); // varchar to text - assertEquals(DoubleWritable.class, - item.get(7).getClass()); // float to double (derby's float is an alias of double by default) - assertEquals(FloatWritable.class, item.get(8).getClass()); // real to float - assertEquals(DoubleWritable.class, item.get(9).getClass()); // decimal to double - assertEquals(DoubleWritable.class, 
item.get(10).getClass()); // numeric to double - assertEquals(DoubleWritable.class, item.get(11).getClass()); // double to double - assertEquals(IntWritable.class, item.get(12).getClass()); // integer to integer - assertEquals(IntWritable.class, item.get(13).getClass()); // small int to integer - assertEquals(LongWritable.class, item.get(14).getClass()); // bigint to long - + // boolean to boolean + assertEquals(BooleanWritable.class, item.get(0).getClass()); + // date to text + assertEquals(Text.class, item.get(1).getClass()); + // time to text + assertEquals(Text.class, item.get(2).getClass()); + // timestamp to text + assertEquals(Text.class, item.get(3).getClass()); + // char to text + assertEquals(Text.class, item.get(4).getClass()); + // long varchar to text + assertEquals(Text.class, item.get(5).getClass()); + // varchar to text + assertEquals(Text.class, item.get(6).getClass()); + assertEquals(DoubleWritable.class, // float to double (derby's float is an alias of double by default) + item.get(7).getClass()); + // real to float + assertEquals(FloatWritable.class, item.get(8).getClass()); + // decimal to double + assertEquals(DoubleWritable.class, item.get(9).getClass()); + // numeric to double + assertEquals(DoubleWritable.class, item.get(10).getClass()); + // double to double + assertEquals(DoubleWritable.class, item.get(11).getClass()); + // integer to integer + assertEquals(IntWritable.class, item.get(12).getClass()); + // small int to integer + assertEquals(IntWritable.class, item.get(13).getClass()); + // bigint to long + assertEquals(LongWritable.class, item.get(14).getClass()); } } - @Test(expected = RuntimeException.class) - public void testNextNoMoreShouldFail() throws Exception { - try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { - while (reader.hasNext()) { + @Test + @DisplayName("Test Next No More Should Fail") + void testNextNoMoreShouldFail() { + assertThrows(RuntimeException.class, () -> { + try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { + while (reader.hasNext()) { + reader.next(); + } reader.next(); } - reader.next(); - } + }); } - @Test(expected = IllegalArgumentException.class) - public void testInvalidMetadataShouldFail() throws Exception { - try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { - RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class); - reader.loadFromMetaData(md); - } + @Test + @DisplayName("Test Invalid Metadata Should Fail") + void testInvalidMetadataShouldFail() { + assertThrows(IllegalArgumentException.class, () -> { + try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) { + RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class); + reader.loadFromMetaData(md); + } + }); } private JDBCRecordReader getInitializedReader(String query) throws Exception { - int[] indices = {1}; // ProdNum column - JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?", - indices); + // ProdNum column + int[] indices = { 1 }; + JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?", indices); reader.setTrimStrings(true); reader.initialize(null); return reader; } -} \ No newline at end of file +} diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java 
b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java index 67c6ace3d..4a85c255b 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java @@ -17,10 +17,8 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.local.transforms.transform; - import org.datavec.api.transform.MathFunction; import org.datavec.api.transform.MathOp; import org.datavec.api.transform.ReduceOp; @@ -32,107 +30,86 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.writable.*; import org.datavec.python.PythonTransform; - import org.datavec.local.transforms.LocalTransformExecutor; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; - import java.util.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static java.time.Duration.ofMillis; +import static org.junit.jupiter.api.Assertions.assertTimeout; -import static org.junit.Assert.assertEquals; - -public class ExecutionTest { +@DisplayName("Execution Test") +class ExecutionTest { @Test - public void testExecutionNdarray() { - Schema schema = new Schema.Builder() - .addColumnNDArray("first",new long[]{1,32577}) - .addColumnNDArray("second",new long[]{1,32577}).build(); - - TransformProcess transformProcess = new TransformProcess.Builder(schema) - .ndArrayMathFunctionTransform("first", MathFunction.SIN) - .ndArrayMathFunctionTransform("second",MathFunction.COS) - .build(); - + @DisplayName("Test Execution Ndarray") + void testExecutionNdarray() { + Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build(); + TransformProcess transformProcess = new TransformProcess.Builder(schema).ndArrayMathFunctionTransform("first", MathFunction.SIN).ndArrayMathFunctionTransform("second", MathFunction.COS).build(); List> functions = new ArrayList<>(); List firstRow = new ArrayList<>(); - INDArray firstArr = Nd4j.linspace(1,4,4); - INDArray secondArr = Nd4j.linspace(1,4,4); + INDArray firstArr = Nd4j.linspace(1, 4, 4); + INDArray secondArr = Nd4j.linspace(1, 4, 4); firstRow.add(new NDArrayWritable(firstArr)); firstRow.add(new NDArrayWritable(secondArr)); functions.add(firstRow); - List> execute = LocalTransformExecutor.execute(functions, transformProcess); INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get(); INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get(); - INDArray expected = Transforms.sin(firstArr); INDArray secondExpected = Transforms.cos(secondArr); - assertEquals(expected,firstResult); - assertEquals(secondExpected,secondResult); - + assertEquals(expected, firstResult); + assertEquals(secondExpected, secondResult); } @Test - public void testExecutionSimple() { - Schema schema = new Schema.Builder().addColumnInteger("col0") - .addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2"). 
- addColumnFloat("col3").build(); - - TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1") - .doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build(); - + @DisplayName("Test Execution Simple") + void testExecutionSimple() { + Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build(); + TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build(); List> inputData = new ArrayList<>(); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f))); inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f))); - List> rdd = (inputData); - List> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp)); - Collections.sort(out, new Comparator>() { + @Override public int compare(List o1, List o2) { return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); } }); - List> expected = new ArrayList<>(); expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f))); expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f))); expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f))); - assertEquals(expected, out); } @Test - public void testFilter() { - Schema filterSchema = new Schema.Builder() - .addColumnDouble("col1").addColumnDouble("col2") - .addColumnDouble("col3").build(); + @DisplayName("Test Filter") + void testFilter() { + Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build(); List> inputData = new ArrayList<>(); inputData.add(Arrays.asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1))); inputData.add(Arrays.asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1))); - TransformProcess transformProcess = new TransformProcess.Builder(filterSchema) - .filter(new DoubleColumnCondition("col1",ConditionOp.LessThan,1)).build(); + TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build(); List> execute = LocalTransformExecutor.execute(inputData, transformProcess); - assertEquals(2,execute.size()); + assertEquals(2, execute.size()); } @Test - public void testExecutionSequence() { - - Schema schema = new SequenceSchema.Builder().addColumnInteger("col0") - .addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); - - TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1") - .doubleMathOp("col2", MathOp.Add, 10.0).build(); - + @DisplayName("Test Execution Sequence") + void testExecutionSequence() { + Schema schema = new SequenceSchema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); + TransformProcess tp = new 
TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build(); List>> inputSequences = new ArrayList<>(); List> seq1 = new ArrayList<>(); seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); @@ -141,21 +118,17 @@ public class ExecutionTest { List> seq2 = new ArrayList<>(); seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); - inputSequences.add(seq1); inputSequences.add(seq2); - - List>> rdd = (inputSequences); - + List>> rdd = (inputSequences); List>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp); - Collections.sort(out, new Comparator>>() { + @Override public int compare(List> o1, List> o2) { return -Integer.compare(o1.size(), o2.size()); } }); - List>> expectedSequence = new ArrayList<>(); List> seq1e = new ArrayList<>(); seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); @@ -164,121 +137,66 @@ public class ExecutionTest { List> seq2e = new ArrayList<>(); seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); - expectedSequence.add(seq1e); expectedSequence.add(seq2e); - assertEquals(expectedSequence, out); } - @Test - public void testReductionGlobal() { - - List> in = Arrays.asList( - Arrays.asList(new Text("first"), new DoubleWritable(3.0)), - Arrays.asList(new Text("second"), new DoubleWritable(5.0)) - ); - + @DisplayName("Test Reduction Global") + void testReductionGlobal() { + List> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.asList(new Text("second"), new DoubleWritable(5.0))); List> inData = in; - - Schema s = new Schema.Builder() - .addColumnString("textCol") - .addColumnDouble("doubleCol") - .build(); - - TransformProcess tp = new TransformProcess.Builder(s) - .reduce(new Reducer.Builder(ReduceOp.TakeFirst) - .takeFirstColumns("textCol") - .meanColumns("doubleCol").build()) - .build(); - + Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build(); + TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); List> outRdd = LocalTransformExecutor.execute(inData, tp); - List> out = outRdd; - List> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0))); - assertEquals(expOut, out); } @Test - public void testReductionByKey(){ - - List> in = Arrays.asList( - Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), - Arrays.asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), - Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), - Arrays.asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)) - ); - + @DisplayName("Test Reduction By Key") + void testReductionByKey() { + List> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))); List> inData = in; - - Schema s = new Schema.Builder() - .addColumnInteger("intCol") - 
.addColumnString("textCol") - .addColumnDouble("doubleCol") - .build(); - - TransformProcess tp = new TransformProcess.Builder(s) - .reduce(new Reducer.Builder(ReduceOp.TakeFirst) - .keyColumns("intCol") - .takeFirstColumns("textCol") - .meanColumns("doubleCol").build()) - .build(); - + Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build(); + TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); List> outRdd = LocalTransformExecutor.execute(inData, tp); - List> out = outRdd; - - List> expOut = Arrays.asList( - Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), - Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); - + List> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); out = new ArrayList<>(out); - Collections.sort( - out, new Comparator>() { - @Override - public int compare(List o1, List o2) { - return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); - } - } - ); + Collections.sort(out, new Comparator>() { + @Override + public int compare(List o1, List o2) { + return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); + } + }); assertEquals(expOut, out); } - @Test(timeout = 60000L) - @Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") - public void testPythonExecutionNdarray()throws Exception{ - Schema schema = new Schema.Builder() - .addColumnNDArray("first",new long[]{1,32577}) - .addColumnNDArray("second",new long[]{1,32577}).build(); - - TransformProcess transformProcess = new TransformProcess.Builder(schema) - .transform( - PythonTransform.builder().code( - "first = np.sin(first)\nsecond = np.cos(second)") - .outputSchema(schema).build()) - .build(); - - List> functions = new ArrayList<>(); - List firstRow = new ArrayList<>(); - INDArray firstArr = Nd4j.linspace(1,4,4); - INDArray secondArr = Nd4j.linspace(1,4,4); - firstRow.add(new NDArrayWritable(firstArr)); - firstRow.add(new NDArrayWritable(secondArr)); - functions.add(firstRow); - - List> execute = LocalTransformExecutor.execute(functions, transformProcess); - INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get(); - INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get(); - - INDArray expected = Transforms.sin(firstArr); - INDArray secondExpected = Transforms.cos(secondArr); - assertEquals(expected,firstResult); - assertEquals(secondExpected,secondResult); - + @Test + @Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") + @DisplayName("Test Python Execution Ndarray") + void testPythonExecutionNdarray() { + assertTimeout(ofMillis(60000), () -> { + Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build(); + TransformProcess transformProcess = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build(); + List> functions = new ArrayList<>(); + List firstRow = new ArrayList<>(); + INDArray firstArr = Nd4j.linspace(1, 4, 4); + INDArray secondArr = Nd4j.linspace(1, 4, 4); + firstRow.add(new NDArrayWritable(firstArr)); + firstRow.add(new 
NDArrayWritable(secondArr)); + functions.add(firstRow); + List> execute = LocalTransformExecutor.execute(functions, transformProcess); + INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get(); + INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get(); + INDArray expected = Transforms.sin(firstArr); + INDArray secondExpected = Transforms.cos(secondArr); + assertEquals(expected, firstResult); + assertEquals(secondExpected, secondResult); + }); } - } diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/BaseSparkTest.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/BaseSparkTest.java index 3dc0e3bff..701ca7b04 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/BaseSparkTest.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/BaseSparkTest.java @@ -17,36 +17,38 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.datavec.spark; import lombok.extern.slf4j.Slf4j; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; -import org.junit.After; -import org.junit.Before; - +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import java.io.Serializable; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j +@DisplayName("Base Spark Test") public abstract class BaseSparkTest implements Serializable { + protected static JavaSparkContext sc; - @Before - public void before() { + @BeforeEach + void before() { sc = getContext(); } - @After - public synchronized void after() { + @AfterEach + synchronized void after() { sc.close(); - //Wait until it's stopped, to avoid race conditions during tests + // Wait until it's stopped, to avoid race conditions during tests for (int i = 0; i < 100; i++) { if (!sc.sc().stopped().get()) { try { Thread.sleep(100L); } catch (InterruptedException e) { - log.error("",e); + log.error("", e); } } else { break; @@ -55,29 +57,21 @@ public abstract class BaseSparkTest implements Serializable { if (!sc.sc().stopped().get()) { throw new RuntimeException("Spark context is not stopped after 10s"); } - - sc = null; } public synchronized JavaSparkContext getContext() { if (sc != null) return sc; - - SparkConf sparkConf = new SparkConf().setMaster("local[*]").set("spark.driver.host", "localhost") - .set("spark.driverEnv.SPARK_LOCAL_IP", "127.0.0.1") - .set("spark.executorEnv.SPARK_LOCAL_IP", "127.0.0.1").setAppName("sparktest"); + SparkConf sparkConf = new SparkConf().setMaster("local[*]").set("spark.driver.host", "localhost").set("spark.driverEnv.SPARK_LOCAL_IP", "127.0.0.1").set("spark.executorEnv.SPARK_LOCAL_IP", "127.0.0.1").setAppName("sparktest"); if (useKryo()) { sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); } - - sc = new JavaSparkContext(sparkConf); - return sc; } - public boolean useKryo(){ + public boolean useKryo() { return false; } } diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java index 0b93af28a..6a1015197 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/ExecutionTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - 
package org.datavec.spark.transform; import org.apache.spark.api.java.JavaRDD; @@ -35,59 +34,51 @@ import org.datavec.api.writable.Writable; import org.datavec.api.writable.NDArrayWritable; import org.datavec.spark.BaseSparkTest; import org.datavec.python.PythonTransform; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; - import java.util.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static java.time.Duration.ofMillis; +import static org.junit.jupiter.api.Assertions.assertTimeout; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class ExecutionTest extends BaseSparkTest { +@DisplayName("Execution Test") +class ExecutionTest extends BaseSparkTest { @Test - public void testExecutionSimple() { - Schema schema = new Schema.Builder().addColumnInteger("col0") - .addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); - - TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1") - .doubleMathOp("col2", MathOp.Add, 10.0).build(); - + @DisplayName("Test Execution Simple") + void testExecutionSimple() { + Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); + TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build(); List> inputData = new ArrayList<>(); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); - JavaRDD> rdd = sc.parallelize(inputData); - List> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); - Collections.sort(out, new Comparator>() { + @Override public int compare(List o1, List o2) { return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); } }); - List> expected = new ArrayList<>(); expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); - assertEquals(expected, out); } @Test - public void testExecutionSequence() { - - Schema schema = new SequenceSchema.Builder().addColumnInteger("col0") - .addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); - - TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1") - .doubleMathOp("col2", MathOp.Add, 10.0).build(); - + @DisplayName("Test Execution Sequence") + void testExecutionSequence() { + Schema schema = new SequenceSchema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); + TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build(); List>> inputSequences = new ArrayList<>(); List> seq1 = new ArrayList<>(); 
seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); @@ -96,22 +87,17 @@ public class ExecutionTest extends BaseSparkTest { List> seq2 = new ArrayList<>(); seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); - inputSequences.add(seq1); inputSequences.add(seq2); - JavaRDD>> rdd = sc.parallelize(inputSequences); - - List>> out = - new ArrayList<>(SparkTransformExecutor.executeSequenceToSequence(rdd, tp).collect()); - + List>> out = new ArrayList<>(SparkTransformExecutor.executeSequenceToSequence(rdd, tp).collect()); Collections.sort(out, new Comparator>>() { + @Override public int compare(List> o1, List> o2) { return -Integer.compare(o1.size(), o2.size()); } }); - List>> expectedSequence = new ArrayList<>(); List> seq1e = new ArrayList<>(); seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); @@ -120,99 +106,49 @@ public class ExecutionTest extends BaseSparkTest { List> seq2e = new ArrayList<>(); seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); - expectedSequence.add(seq1e); expectedSequence.add(seq2e); - assertEquals(expectedSequence, out); } - @Test - public void testReductionGlobal() { - - List> in = Arrays.asList( - Arrays.asList(new Text("first"), new DoubleWritable(3.0)), - Arrays.asList(new Text("second"), new DoubleWritable(5.0)) - ); - + @DisplayName("Test Reduction Global") + void testReductionGlobal() { + List> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.asList(new Text("second"), new DoubleWritable(5.0))); JavaRDD> inData = sc.parallelize(in); - - Schema s = new Schema.Builder() - .addColumnString("textCol") - .addColumnDouble("doubleCol") - .build(); - - TransformProcess tp = new TransformProcess.Builder(s) - .reduce(new Reducer.Builder(ReduceOp.TakeFirst) - .takeFirstColumns("textCol") - .meanColumns("doubleCol").build()) - .build(); - + Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build(); + TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); JavaRDD> outRdd = SparkTransformExecutor.execute(inData, tp); - List> out = outRdd.collect(); - List> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0))); - assertEquals(expOut, out); } @Test - public void testReductionByKey(){ - - List> in = Arrays.asList( - Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), - Arrays.asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), - Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), - Arrays.asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)) - ); - + @DisplayName("Test Reduction By Key") + void testReductionByKey() { + List> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))); JavaRDD> inData = sc.parallelize(in); - - Schema s = new Schema.Builder() - 
.addColumnInteger("intCol") - .addColumnString("textCol") - .addColumnDouble("doubleCol") - .build(); - - TransformProcess tp = new TransformProcess.Builder(s) - .reduce(new Reducer.Builder(ReduceOp.TakeFirst) - .keyColumns("intCol") - .takeFirstColumns("textCol") - .meanColumns("doubleCol").build()) - .build(); - + Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build(); + TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); JavaRDD> outRdd = SparkTransformExecutor.execute(inData, tp); - List> out = outRdd.collect(); - - List> expOut = Arrays.asList( - Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), - Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); - + List> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); out = new ArrayList<>(out); - Collections.sort( - out, new Comparator>() { - @Override - public int compare(List o1, List o2) { - return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); - } - } - ); + Collections.sort(out, new Comparator>() { + @Override + public int compare(List o1, List o2) { + return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); + } + }); assertEquals(expOut, out); } - @Test - public void testUniqueMultiCol(){ - - Schema schema = new Schema.Builder() - .addColumnInteger("col0") - .addColumnCategorical("col1", "state0", "state1", "state2") - .addColumnDouble("col2").build(); - + @DisplayName("Test Unique Multi Col") + void testUniqueMultiCol() { + Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); List> inputData = new ArrayList<>(); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); @@ -223,149 +159,103 @@ public class ExecutionTest extends BaseSparkTest { inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); - JavaRDD> rdd = sc.parallelize(inputData); - - Map> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd); - + Map> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd); assertEquals(2, l.size()); List c0 = l.get("col0"); assertEquals(3, c0.size()); assertTrue(c0.contains(new IntWritable(0)) && c0.contains(new IntWritable(1)) && c0.contains(new IntWritable(2))); - List c1 = l.get("col1"); assertEquals(3, c1.size()); assertTrue(c1.contains(new Text("state0")) && c1.contains(new Text("state1")) && c1.contains(new Text("state2"))); } - @Test(timeout = 60000L) - @Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") - public void testPythonExecution() throws Exception { - Schema schema = new Schema.Builder().addColumnInteger("col0") - .addColumnString("col1").addColumnDouble("col2").build(); + @Test + @Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") + @DisplayName("Test Python Execution") + void 
testPythonExecution() { + assertTimeout(ofMillis(60000), () -> { + Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnString("col1").addColumnDouble("col2").build(); + Schema finalSchema = new Schema.Builder().addColumnInteger("col0").addColumnInteger("col1").addColumnDouble("col2").build(); + String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0"; + TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(finalSchema).build()).build(); + List> inputData = new ArrayList<>(); + inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); + inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); + inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); + JavaRDD> rdd = sc.parallelize(inputData); + List> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); + Collections.sort(out, new Comparator>() { - Schema finalSchema = new Schema.Builder().addColumnInteger("col0") - .addColumnInteger("col1").addColumnDouble("col2").build(); - String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0"; - TransformProcess tp = new TransformProcess.Builder(schema).transform( - PythonTransform.builder().code( - "first = np.sin(first)\nsecond = np.cos(second)") - .outputSchema(finalSchema).build() - ).build(); - List> inputData = new ArrayList<>(); - inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); - inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); - inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); - - JavaRDD> rdd = sc.parallelize(inputData); - - List> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); - - Collections.sort(out, new Comparator>() { - @Override - public int compare(List o1, List o2) { - return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); - } + @Override + public int compare(List o1, List o2) { + return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); + } + }); + List> expected = new ArrayList<>(); + expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); + expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); + expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); + assertEquals(expected, out); }); - - List> expected = new ArrayList<>(); - expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); - expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); - expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); - - assertEquals(expected, out); - } - - @Test(timeout = 60000L) - @Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") - public void testPythonExecutionWithNDArrays() throws Exception { - long[] shape = new long[]{3, 2}; - Schema schema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape) - .addColumnNDArray("col2", shape).build(); - - Schema finalSchema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape) - .addColumnNDArray("col2", shape).addColumnNDArray("col3", shape).build(); - - String 
pythonCode = "col3 = col1 + col2"; - TransformProcess tp = new TransformProcess.Builder(schema).transform( - PythonTransform.builder().code( - "first = np.sin(first)\nsecond = np.cos(second)") - .outputSchema(schema).build() - ).build(); - - INDArray zeros = Nd4j.zeros(shape); - INDArray ones = Nd4j.ones(shape); - INDArray twos = ones.add(ones); - - List> inputData = new ArrayList<>(); - inputData.add(Arrays.asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros))); - inputData.add(Arrays.asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones))); - inputData.add(Arrays.asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones))); - - JavaRDD> rdd = sc.parallelize(inputData); - - List> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); - - Collections.sort(out, new Comparator>() { - @Override - public int compare(List o1, List o2) { - return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); - } - }); - - List> expected = new ArrayList<>(); - expected.add(Arrays.asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros))); - expected.add(Arrays.asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones))); - expected.add(Arrays.asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos))); } @Test - public void testFirstDigitTransformBenfordsLaw(){ - Schema s = new Schema.Builder() - .addColumnString("data") - .addColumnDouble("double") - .addColumnString("stringNumber") - .build(); + @Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") + @DisplayName("Test Python Execution With ND Arrays") + void testPythonExecutionWithNDArrays() { + assertTimeout(ofMillis(60000), () -> { + long[] shape = new long[] { 3, 2 }; + Schema schema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).build(); + Schema finalSchema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).addColumnNDArray("col3", shape).build(); + String pythonCode = "col3 = col1 + col2"; + TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build(); + INDArray zeros = Nd4j.zeros(shape); + INDArray ones = Nd4j.ones(shape); + INDArray twos = ones.add(ones); + List> inputData = new ArrayList<>(); + inputData.add(Arrays.asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros))); + inputData.add(Arrays.asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones))); + inputData.add(Arrays.asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones))); + JavaRDD> rdd = sc.parallelize(inputData); + List> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); + Collections.sort(out, new Comparator>() { - List> in = Arrays.asList( - Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")), - Arrays.asList(new Text("a2"), new DoubleWritable(3.14159), new Text("7e-4")), - Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")), - Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")), - Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.0")), - Arrays.asList(new Text("c"), new 
DoubleWritable(1.61803), new Text("2.1")), - Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.2")), - Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical"))); - - //Test Benfords law use case: - TransformProcess tp = new TransformProcess.Builder(s) - .firstDigitTransform("double", "fdDouble", FirstDigitTransform.Mode.EXCEPTION_ON_INVALID) - .firstDigitTransform("stringNumber", "stringNumber", FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY) - .removeAllColumnsExceptFor("stringNumber") - .categoricalToOneHot("stringNumber") - .reduce(new Reducer.Builder(ReduceOp.Sum).build()) - .build(); - - JavaRDD> rdd = sc.parallelize(in); - - - List> out = SparkTransformExecutor.execute(rdd, tp).collect(); - assertEquals(1, out.size()); - - List l = out.get(0); - List exp = Arrays.asList( - new IntWritable(0), //0 - new IntWritable(0), //1 - new IntWritable(3), //2 - new IntWritable(0), //3 - new IntWritable(0), //4 - new IntWritable(0), //5 - new IntWritable(1), //6 - new IntWritable(2), //7 - new IntWritable(1), //8 - new IntWritable(0), //9 - new IntWritable(1)); //Other - assertEquals(exp, l); + @Override + public int compare(List o1, List o2) { + return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); + } + }); + List> expected = new ArrayList<>(); + expected.add(Arrays.asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros))); + expected.add(Arrays.asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones))); + expected.add(Arrays.asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos))); + }); } + @Test + @DisplayName("Test First Digit Transform Benfords Law") + void testFirstDigitTransformBenfordsLaw() { + Schema s = new Schema.Builder().addColumnString("data").addColumnDouble("double").addColumnString("stringNumber").build(); + List> in = Arrays.asList(Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")), Arrays.asList(new Text("a2"), new DoubleWritable(3.14159), new Text("7e-4")), Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")), Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")), Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.0")), Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.1")), Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.2")), Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical"))); + // Test Benfords law use case: + TransformProcess tp = new TransformProcess.Builder(s).firstDigitTransform("double", "fdDouble", FirstDigitTransform.Mode.EXCEPTION_ON_INVALID).firstDigitTransform("stringNumber", "stringNumber", FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY).removeAllColumnsExceptFor("stringNumber").categoricalToOneHot("stringNumber").reduce(new Reducer.Builder(ReduceOp.Sum).build()).build(); + JavaRDD> rdd = sc.parallelize(in); + List> out = SparkTransformExecutor.execute(rdd, tp).collect(); + assertEquals(1, out.size()); + List l = out.get(0); + List exp = Arrays.asList(// 0 + new IntWritable(0), // 1 + new IntWritable(0), // 2 + new IntWritable(3), // 3 + new IntWritable(0), // 4 + new IntWritable(0), // 5 + new IntWritable(0), // 6 + new IntWritable(1), // 7 + new IntWritable(2), // 8 + new IntWritable(1), // 9 + new IntWritable(0), // Other + new IntWritable(1)); + assertEquals(exp, l); + } } diff --git 
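The ExecutionTest changes above follow the annotation mapping used throughout this patch: org.junit.Test and @Ignore become org.junit.jupiter.api.Test and @Disabled, @Test(timeout = ...) is rewritten as an assertTimeout(ofMillis(...), () -> { ... }) block, assertEquals is imported from org.junit.jupiter.api.Assertions, test classes and methods drop the public modifier, and a @DisplayName is generated from each method name. A minimal sketch of that mapping follows, using hypothetical class and method names that are not part of this patch:

import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTimeout;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Example Migrated Test")
class ExampleMigratedTest {

    @Test
    @Disabled("mirrors a reason string carried over from a JUnit 4 @Ignore")
    @DisplayName("Timed Test")
    void timedTest() {
        // JUnit 4's @Test(timeout = 60000L) becomes an explicit assertTimeout block
        assertTimeout(ofMillis(60_000), () -> assertEquals(4, 2 + 2));
    }
}

The declarative org.junit.jupiter.api.Timeout annotation is an equivalent alternative to assertTimeout; this patch uses it on the beforeTest lifecycle method of BaseDL4JTest further below.
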
a/datavec/pom.xml b/datavec/pom.xml index 1ec358c4b..65f1afc61 100644 --- a/datavec/pom.xml +++ b/datavec/pom.xml @@ -89,14 +89,22 @@ - junit - junit - ${junit.version} + org.junit.jupiter + junit-jupiter-api + + + org.junit.vintage + junit-vintage-engine + + + com.tngtech.archunit + archunit-junit5-engine + ${archunit.version} test com.tngtech.archunit - archunit-junit4 + archunit-junit5-api ${archunit.version} test diff --git a/deeplearning4j/deeplearning4j-common-tests/pom.xml b/deeplearning4j/deeplearning4j-common-tests/pom.xml index 852471025..cce6ea55d 100644 --- a/deeplearning4j/deeplearning4j-common-tests/pom.xml +++ b/deeplearning4j/deeplearning4j-common-tests/pom.xml @@ -34,10 +34,18 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} provided + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + provided + + org.nd4j nd4j-api diff --git a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index e95993a79..98c0e328b 100644 --- a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -17,17 +17,13 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j; import ch.qos.logback.classic.LoggerContext; import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.Pointer; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.rules.TestName; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.*; + import org.nd4j.common.base.Preconditions; import org.nd4j.common.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.DataType; @@ -37,23 +33,22 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.profiler.ProfilerConfig; import org.slf4j.ILoggerFactory; import org.slf4j.LoggerFactory; - import java.lang.management.ManagementFactory; import java.util.List; import java.util.Map; import java.util.Properties; +import static org.junit.jupiter.api.Assumptions.assumeTrue; -import static org.junit.Assume.assumeTrue; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j +@DisplayName("Base DL 4 J Test") public abstract class BaseDL4JTest { - @Rule - public TestName name = new TestName(); - @Rule - public Timeout timeout = Timeout.millis(getTimeoutMilliseconds()); + protected long startTime; + protected int threadCountBefore; private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors(); @@ -63,32 +58,32 @@ public abstract class BaseDL4JTest { * {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)} * @return Number of threads to use for C++ op execution */ - public int numThreads(){ + public int numThreads() { return DEFAULT_THREADS; } /** * Override this method to set the default timeout for methods in the test class */ - public long getTimeoutMilliseconds(){ + public long getTimeoutMilliseconds() { return 90_000; } /** * Override this to set the profiling mode for the tests defined in the child class */ - public OpExecutioner.ProfilingMode getProfilingMode(){ + public OpExecutioner.ProfilingMode getProfilingMode() { return OpExecutioner.ProfilingMode.SCOPE_PANIC; } /** * Override this to set the datatype of the tests defined in the child class */ - public DataType getDataType(){ + 
public DataType getDataType() { return DataType.DOUBLE; } - public DataType getDefaultFPDataType(){ + public DataType getDefaultFPDataType() { return getDataType(); } @@ -97,8 +92,8 @@ public abstract class BaseDL4JTest { /** * @return True if integration tests maven profile is enabled, false otherwise. */ - public static boolean isIntegrationTests(){ - if(integrationTest == null){ + public static boolean isIntegrationTests() { + if (integrationTest == null) { String prop = System.getenv("DL4J_INTEGRATION_TESTS"); integrationTest = Boolean.parseBoolean(prop); } @@ -110,14 +105,15 @@ public abstract class BaseDL4JTest { * This can be used to dynamically skip integration tests when the integration test profile is not enabled. * Note that the integration test profile is not enabled by default - "integration-tests" profile */ - public static void skipUnlessIntegrationTests(){ - assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests()); + public static void skipUnlessIntegrationTests() { + assumeTrue(isIntegrationTests(), "Skipping integration test - integration profile is not enabled"); } - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); - //Suppress ND4J initialization - don't need this logged for every test... + @BeforeEach + @Timeout(90000L) + void beforeTest(TestInfo testInfo) { + log.info("{}.{}", getClass().getSimpleName(), testInfo.getTestMethod().get().getName()); + // Suppress ND4J initialization - don't need this logged for every test... System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true"); Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); @@ -128,83 +124,71 @@ public abstract class BaseDL4JTest { Nd4j.getExecutioner().enableVerboseMode(false); int numThreads = numThreads(); Preconditions.checkState(numThreads > 0, "Number of threads must be > 0"); - if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) { + if (numThreads != Nd4j.getEnvironment().maxMasterThreads()) { Nd4j.getEnvironment().setMaxMasterThreads(numThreads); } startTime = System.currentTimeMillis(); threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); } - @After - public void afterTest(){ - //Attempt to keep workspaces isolated between tests + @AfterEach + void afterTest(TestInfo testInfo) { + // Attempt to keep workspaces isolated between tests Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure + if (currWS != null) { + // Not really safe to continue testing under this situation... other tests will likely fail with obscure // errors that are hard to track back to this log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); System.out.println("Open workspace leaked from test! 
Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS); System.out.flush(); - //Try to flush logs also: - try{ Thread.sleep(1000); } catch (InterruptedException e){ } - ILoggerFactory lf = LoggerFactory.getILoggerFactory(); - if( lf instanceof LoggerContext){ - ((LoggerContext)lf).stop(); + // Try to flush logs also: + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + } + ILoggerFactory lf = LoggerFactory.getILoggerFactory(); + if (lf instanceof LoggerContext) { + ((LoggerContext) lf).stop(); + } + try { + Thread.sleep(1000); + } catch (InterruptedException e) { } - try{ Thread.sleep(1000); } catch (InterruptedException e){ } System.exit(1); } - StringBuilder sb = new StringBuilder(); long maxPhys = Pointer.maxPhysicalBytes(); long maxBytes = Pointer.maxBytes(); long currPhys = Pointer.physicalBytes(); long currBytes = Pointer.totalBytes(); - long jvmTotal = Runtime.getRuntime().totalMemory(); long jvmMax = Runtime.getRuntime().maxMemory(); - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) - .append(": ").append(duration).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - + sb.append(getClass().getSimpleName()).append(".").append(testInfo.getTestMethod().get().getName()).append(": ").append(duration).append(" ms").append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")").append(", jvmTotal=").append(jvmTotal).append(", jvmMax=").append(jvmMax).append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes).append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ + if (ws != null && ws.size() > 0) { long currSize = 0; - for(MemoryWorkspace w : ws){ + for (MemoryWorkspace w : ws) { currSize += w.getCurrentSize(); } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); + if (currSize > 0) { + sb.append(", threadWSSize=").append(currSize).append(" (").append(ws.size()).append(" WSs)"); } } - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - + if (o instanceof List) { + List> l = (List>) o; + if (l.size() > 0) { + sb.append(" [").append(l.size()).append(" GPUs: "); for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) + Map m = l.get(i); + if (i > 0) sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); + sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ").append(m.get("cuda.totalMemory")).append(" total)"); } sb.append("]"); } diff --git a/deeplearning4j/deeplearning4j-common/pom.xml b/deeplearning4j/deeplearning4j-common/pom.xml index bf250b0af..c63939b27 100644 --- a/deeplearning4j/deeplearning4j-common/pom.xml +++ 
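The BaseDL4JTest rewrite above covers the lifecycle side of the migration: @Before and @After become @BeforeEach and @AfterEach, the @Rule TestName field is replaced by an injected TestInfo parameter, the @Rule Timeout field is replaced by the @Timeout annotation, and assumeTrue now takes the assumption first and the message second. The @Rule TemporaryFolder fields in the test classes that follow are replaced the same way, by @TempDir-injected java.nio.file.Path fields or parameters. One caveat: org.junit.jupiter.api.Timeout interprets a bare value in seconds by default, whereas the old Timeout.millis(...) rule was millisecond-based, so carried-over literals need an explicit unit to keep the same meaning. A minimal sketch of the lifecycle pattern, with hypothetical names that are not part of this patch:

import static org.junit.jupiter.api.Assumptions.assumeTrue;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout;

abstract class ExampleLifecycleBase {

    protected long startTime;

    @BeforeEach
    @Timeout(90) // default unit is seconds; replaces a JUnit 4 @Rule Timeout field
    void before(TestInfo testInfo) {
        // argument order is flipped relative to JUnit 4: assumption first, message second
        assumeTrue(true, "skipped unless the assumption holds");
        // TestInfo replaces the JUnit 4 TestName rule for logging the running method
        startTime = System.currentTimeMillis();
        System.out.println("Running " + testInfo.getTestMethod().map(m -> m.getName()).orElse("<unknown>"));
    }

    @AfterEach
    void after() {
        System.out.println("Took " + (System.currentTimeMillis() - startTime) + " ms");
    }
}
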
b/deeplearning4j/deeplearning4j-common/pom.xml @@ -41,8 +41,15 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} test diff --git a/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/DL4JClassLoadingTest.java b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/DL4JClassLoadingTest.java index d3740a8d1..73757e214 100644 --- a/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/DL4JClassLoadingTest.java +++ b/deeplearning4j/deeplearning4j-common/src/test/java/org/deeplearning4j/common/config/DL4JClassLoadingTest.java @@ -17,70 +17,56 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.common.config; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import org.deeplearning4j.common.config.dummies.TestAbstract; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; + +@DisplayName("Dl 4 J Class Loading Test") +class DL4JClassLoadingTest { -public class DL4JClassLoadingTest { private static final String PACKAGE_PREFIX = "org.deeplearning4j.common.config.dummies."; @Test - public void testCreateNewInstance_constructorWithoutArguments() { - + @DisplayName("Test Create New Instance _ constructor Without Arguments") + void testCreateNewInstance_constructorWithoutArguments() { /* Given */ String className = PACKAGE_PREFIX + "TestDummy"; - /* When */ Object instance = DL4JClassLoading.createNewInstance(className); - /* Then */ assertNotNull(instance); assertEquals(className, instance.getClass().getName()); } @Test - public void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() { - + @DisplayName("Test Create New Instance _ constructor With Argument _ implicit Argument Types") + void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() { /* Given */ String className = PACKAGE_PREFIX + "TestColor"; - /* When */ TestAbstract instance = DL4JClassLoading.createNewInstance(className, TestAbstract.class, "white"); - /* Then */ assertNotNull(instance); assertEquals(className, instance.getClass().getName()); } @Test - public void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() { - + @DisplayName("Test Create New Instance _ constructor With Argument _ explicit Argument Types") + void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() { /* Given */ String colorClassName = PACKAGE_PREFIX + "TestColor"; String rectangleClassName = PACKAGE_PREFIX + "TestRectangle"; - /* When */ - TestAbstract color = DL4JClassLoading.createNewInstance( - colorClassName, - Object.class, - new Class[]{ int.class, int.class, int.class }, - 45, 175, 200); - - TestAbstract rectangle = DL4JClassLoading.createNewInstance( - rectangleClassName, - Object.class, - new Class[]{ int.class, int.class, TestAbstract.class }, - 10, 15, color); - + TestAbstract color = DL4JClassLoading.createNewInstance(colorClassName, Object.class, new Class[] { int.class, int.class, int.class }, 45, 175, 200); + TestAbstract rectangle = DL4JClassLoading.createNewInstance(rectangleClassName, Object.class, new Class[] { 
int.class, int.class, TestAbstract.class }, 10, 15, color); /* Then */ assertNotNull(color); assertEquals(colorClassName, color.getClass().getName()); - assertNotNull(rectangle); assertEquals(rectangleClassName, rectangle.getClass().getName()); } diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index 27caa6718..655e60a8a 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -49,11 +49,6 @@ - - org.deeplearning4j - deeplearning4j-tsne - ${project.version} - org.deeplearning4j deeplearning4j-datasets @@ -99,8 +94,12 @@ ${commons-compress.version} - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.vintage + junit-vintage-engine org.deeplearning4j diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java index 59da9f5d6..29041b5f5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java @@ -17,15 +17,16 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.base.MnistFetcher; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.*; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.junit.rules.Timeout; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; @@ -33,69 +34,67 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Conditions; - import java.io.File; +import java.nio.file.Path; import java.util.HashSet; import java.util.Set; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +@DisplayName("Mnist Fetcher Test") +class MnistFetcherTest extends BaseDL4JTest { -public class MnistFetcherTest extends BaseDL4JTest { - @ClassRule - public static TemporaryFolder testDir = new TemporaryFolder(); - @Rule - public Timeout timeout = Timeout.seconds(300); - @BeforeClass - public static void setup() throws Exception { - DL4JResources.setBaseDirectory(testDir.newFolder()); + @BeforeAll + static void setup(@TempDir Path tempPath) throws Exception { + DL4JResources.setBaseDirectory(tempPath.toFile()); } - @AfterClass - public static void after() { + @AfterAll + static void after() { DL4JResources.resetBaseDirectoryLocation(); } @Test - public void testMnist() throws Exception { + @DisplayName("Test Mnist") + void testMnist() throws Exception { DataSetIterator iter = new MnistDataSetIterator(32, 60000, 
false, true, false, -1); int count = 0; - while(iter.hasNext()){ + while (iter.hasNext()) { DataSet ds = iter.next(); INDArray arr = ds.getFeatures().sum(1); int countMatch = Nd4j.getExecutioner().execAndReturn(new MatchCondition(arr, Conditions.equals(0))).z().getInt(0); assertEquals(0, countMatch); count++; } - assertEquals(60000/32, count); - + assertEquals(60000 / 32, count); count = 0; iter = new MnistDataSetIterator(32, false, 12345); - while(iter.hasNext()){ + while (iter.hasNext()) { DataSet ds = iter.next(); INDArray arr = ds.getFeatures().sum(1); int countMatch = Nd4j.getExecutioner().execAndReturn(new MatchCondition(arr, Conditions.equals(0))).z().getInt(0); assertEquals(0, countMatch); count++; } - assertEquals((int)Math.ceil(10000/32.0), count); + assertEquals((int) Math.ceil(10000 / 32.0), count); } @Test - public void testMnistDataFetcher() throws Exception { + @DisplayName("Test Mnist Data Fetcher") + void testMnistDataFetcher() throws Exception { MnistFetcher mnistFetcher = new MnistFetcher(); File mnistDir = mnistFetcher.downloadAndUntar(); - assertTrue(mnistDir.isDirectory()); } -// @Test + // @Test public void testMnistSubset() throws Exception { final int numExamples = 100; - MnistDataSetIterator iter1 = new MnistDataSetIterator(10, numExamples, false, true, true, 123); int examples1 = 0; int itCount1 = 0; @@ -105,7 +104,6 @@ public class MnistFetcherTest extends BaseDL4JTest { } assertEquals(10, itCount1); assertEquals(100, examples1); - MnistDataSetIterator iter2 = new MnistDataSetIterator(10, numExamples, false, true, true, 123); int examples2 = 0; int itCount2 = 0; @@ -116,7 +114,6 @@ public class MnistFetcherTest extends BaseDL4JTest { assertFalse(iter2.hasNext()); assertEquals(10, itCount2); assertEquals(100, examples2); - MnistDataSetIterator iter3 = new MnistDataSetIterator(19, numExamples, false, true, true, 123); int examples3 = 0; int itCount3 = 0; @@ -125,51 +122,45 @@ public class MnistFetcherTest extends BaseDL4JTest { examples3 += iter3.next().numExamples(); } assertEquals(100, examples3); - assertEquals((int)Math.ceil(100/19.0), itCount3); - + assertEquals((int) Math.ceil(100 / 19.0), itCount3); MnistDataSetIterator iter4 = new MnistDataSetIterator(32, true, 12345); int count4 = 0; - while(iter4.hasNext()){ + while (iter4.hasNext()) { count4 += iter4.next().numExamples(); } assertEquals(60000, count4); } @Test - public void testSubsetRepeatability() throws Exception { - + @DisplayName("Test Subset Repeatability") + void testSubsetRepeatability() throws Exception { DataSetIterator it = new MnistDataSetIterator(1, 1, false, false, true, 0); DataSet d1 = it.next(); - for( int i=0; i<10; i++ ) { + for (int i = 0; i < 10; i++) { it.reset(); DataSet d2 = it.next(); assertEquals(d1.get(0).getFeatures(), d2.get(0).getFeatures()); } - - //Check larger number: + // Check larger number: it = new MnistDataSetIterator(8, 32, false, false, true, 12345); Set featureLabelSet = new HashSet<>(); - while(it.hasNext()){ + while (it.hasNext()) { DataSet ds = it.next(); INDArray f = ds.getFeatures(); INDArray l = ds.getLabels(); - - for( int i=0; i flSet2 = new HashSet<>(); - while(it.hasNext()){ + while (it.hasNext()) { DataSet ds = it.next(); INDArray f = ds.getFeatures(); INDArray l = ds.getLabels(); - - for( int j=0; j dsList = new ArrayList<>(); while (iter.hasNext()) { dsList.add(iter.next()); } - - assertEquals(3, dsList.size()); //3 files + // 3 files + assertEquals(3, dsList.size()); for (int i = 0; i < 3; i++) { DataSet ds = dsList.get(i); INDArray features = 
ds.getFeatures(); INDArray labels = ds.getLabels(); - assertEquals(1, features.size(0)); //1 example in mini-batch + // 1 example in mini-batch + assertEquals(1, features.size(0)); assertEquals(1, labels.size(0)); - assertEquals(3, features.size(1)); //3 values per line/time step - assertEquals(4, labels.size(1)); //1 value per line, but 4 possible values -> one-hot vector - assertEquals(4, features.size(2)); //sequence length = 4 + // 3 values per line/time step + assertEquals(3, features.size(1)); + // 1 value per line, but 4 possible values -> one-hot vector + assertEquals(4, labels.size(1)); + // sequence length = 4 + assertEquals(4, features.size(2)); assertEquals(4, labels.size(2)); } - - //Check features vs. expected: + // Check features vs. expected: INDArray expF0 = Nd4j.create(1, 3, 4); - expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 2})); - expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {10, 11, 12})); - expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {20, 21, 22})); - expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {30, 31, 32})); + expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 2 })); + expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 10, 11, 12 })); + expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 20, 21, 22 })); + expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 30, 31, 32 })); assertEquals(dsList.get(0).getFeatures(), expF0); - INDArray expF1 = Nd4j.create(1, 3, 4); - expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {100, 101, 102})); - expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {110, 111, 112})); - expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {120, 121, 122})); - expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {130, 131, 132})); + expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 100, 101, 102 })); + expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 110, 111, 112 })); + expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 120, 121, 122 })); + expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 130, 131, 132 })); assertEquals(dsList.get(1).getFeatures(), expF1); - INDArray expF2 = Nd4j.create(1, 3, 4); - expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {200, 201, 202})); - expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {210, 211, 212})); - expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {220, 221, 222})); - expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {230, 231, 232})); + expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 200, 201, 202 })); + expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 210, 211, 212 })); + expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 220, 221, 222 })); + expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 230, 231, 232 })); assertEquals(dsList.get(2).getFeatures(), expF2); - - //Check labels vs. expected: + // Check labels vs. 
expected: INDArray expL0 = Nd4j.create(1, 4, 4); - expL0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0})); - expL0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0})); - expL0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0})); - expL0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1})); + expL0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 })); + expL0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 })); + expL0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 })); + expL0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 })); assertEquals(dsList.get(0).getLabels(), expL0); - INDArray expL1 = Nd4j.create(1, 4, 4); - expL1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1})); - expL1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0})); - expL1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0})); - expL1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0})); + expL1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 })); + expL1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 })); + expL1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 })); + expL1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 })); assertEquals(dsList.get(1).getLabels(), expL1); - INDArray expL2 = Nd4j.create(1, 4, 4); - expL2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0})); - expL2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0})); - expL2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1})); - expL2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0})); + expL2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 })); + expL2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 })); + expL2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 })); + expL2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 })); assertEquals(dsList.get(2).getLabels(), expL2); } @Test - public void testSequenceRecordReaderMeta() throws Exception { - File rootDir = temporaryFolder.newFolder(); - //need to manually extract + @DisplayName("Test Sequence Record Reader Meta") + void testSequenceRecordReaderMeta() throws Exception { + File rootDir = temporaryFolder.toFile(); + // need to manually extract for (int i = 0; i < 3; i++) { FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i))); FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)), new File(rootDir, String.format("csvsequencelabels_%d.txt", i))); } String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - SequenceRecordReaderDataSetIterator iter = - new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); - + 
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); iter.setCollectMetaData(true); - assertEquals(3, iter.inputColumns()); assertEquals(4, iter.totalOutcomes()); - while (iter.hasNext()) { DataSet ds = iter.next(); List meta = ds.getExampleMetaData(RecordMetaData.class); DataSet fromMeta = iter.loadFromMetaData(meta); - assertEquals(ds, fromMeta); } } @Test - public void testSequenceRecordReaderRegression() throws Exception { - //need to manually extract - File rootDir = temporaryFolder.newFolder(); + @DisplayName("Test Sequence Record Reader Regression") + void testSequenceRecordReaderRegression() throws Exception { + // need to manually extract + File rootDir = temporaryFolder.toFile(); for (int i = 0; i < 3; i++) { FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i))); } String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - SequenceRecordReaderDataSetIterator iter = - new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 0, true); - + SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 0, true); assertEquals(3, iter.inputColumns()); assertEquals(3, iter.totalOutcomes()); - List dsList = new ArrayList<>(); while (iter.hasNext()) { dsList.add(iter.next()); } - - assertEquals(3, dsList.size()); //3 files + // 3 files + assertEquals(3, dsList.size()); for (int i = 0; i < 3; i++) { DataSet ds = dsList.get(i); INDArray features = ds.getFeatures(); INDArray labels = ds.getLabels(); - assertArrayEquals(new long[] {1, 3, 4}, features.shape()); //1 examples, 3 values, 4 time steps - assertArrayEquals(new long[] {1, 3, 4}, labels.shape()); - + // 1 examples, 3 values, 4 time steps + assertArrayEquals(new long[] { 1, 3, 4 }, features.shape()); + assertArrayEquals(new long[] { 1, 3, 4 }, labels.shape()); assertEquals(features, labels); } - - //Also test regression + reset from a single reader: + // Also test regression + reset from a single reader: featureReader.reset(); iter = new SequenceRecordReaderDataSetIterator(featureReader, 1, 0, 2, true); int count = 0; @@ -316,8 +290,6 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { count++; } assertEquals(3, count); - - iter.reset(); count = 0; while (iter.hasNext()) { @@ -328,58 +300,51 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { } @Test - public void testSequenceRecordReaderMultiRegression() throws Exception { - File rootDir = temporaryFolder.newFolder(); - //need to manually extract + @DisplayName("Test Sequence Record Reader Multi Regression") + void testSequenceRecordReaderMultiRegression() throws Exception { + File rootDir = temporaryFolder.toFile(); + // need to manually extract for (int i = 0; i < 3; i++) { FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i))); } String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); - 
SequenceRecordReader reader = new CSVSequenceRecordReader(1, ","); reader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); - - SequenceRecordReaderDataSetIterator iter = - new SequenceRecordReaderDataSetIterator(reader, 1, 2, 1, true); - + SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(reader, 1, 2, 1, true); assertEquals(1, iter.inputColumns()); assertEquals(2, iter.totalOutcomes()); - List dsList = new ArrayList<>(); while (iter.hasNext()) { dsList.add(iter.next()); } - - assertEquals(3, dsList.size()); //3 files + // 3 files + assertEquals(3, dsList.size()); for (int i = 0; i < 3; i++) { DataSet ds = dsList.get(i); INDArray features = ds.getFeatures(); INDArray labels = ds.getLabels(); - assertArrayEquals(new long[] {1, 1, 4}, features.shape()); //1 examples, 1 values, 4 time steps - assertArrayEquals(new long[] {1, 2, 4}, labels.shape()); - + // 1 examples, 1 values, 4 time steps + assertArrayEquals(new long[] { 1, 1, 4 }, features.shape()); + assertArrayEquals(new long[] { 1, 2, 4 }, labels.shape()); INDArray f2d = features.get(point(0), all(), all()).transpose(); INDArray l2d = labels.get(point(0), all(), all()).transpose(); - - switch (i){ + switch(i) { case 0: - assertEquals(Nd4j.create(new double[]{0,10,20,30}, new int[]{4,1}).castTo(DataType.FLOAT), f2d); - assertEquals(Nd4j.create(new double[][]{{1,2}, {11,12}, {21,22}, {31,32}}).castTo(DataType.FLOAT), l2d); + assertEquals(Nd4j.create(new double[] { 0, 10, 20, 30 }, new int[] { 4, 1 }).castTo(DataType.FLOAT), f2d); + assertEquals(Nd4j.create(new double[][] { { 1, 2 }, { 11, 12 }, { 21, 22 }, { 31, 32 } }).castTo(DataType.FLOAT), l2d); break; case 1: - assertEquals(Nd4j.create(new double[]{100,110,120,130}, new int[]{4,1}).castTo(DataType.FLOAT), f2d); - assertEquals(Nd4j.create(new double[][]{{101,102}, {111,112}, {121,122}, {131,132}}).castTo(DataType.FLOAT), l2d); + assertEquals(Nd4j.create(new double[] { 100, 110, 120, 130 }, new int[] { 4, 1 }).castTo(DataType.FLOAT), f2d); + assertEquals(Nd4j.create(new double[][] { { 101, 102 }, { 111, 112 }, { 121, 122 }, { 131, 132 } }).castTo(DataType.FLOAT), l2d); break; case 2: - assertEquals(Nd4j.create(new double[]{200,210,220,230}, new int[]{4,1}).castTo(DataType.FLOAT), f2d); - assertEquals(Nd4j.create(new double[][]{{201,202}, {211,212}, {221,222}, {231,232}}).castTo(DataType.FLOAT), l2d); + assertEquals(Nd4j.create(new double[] { 200, 210, 220, 230 }, new int[] { 4, 1 }).castTo(DataType.FLOAT), f2d); + assertEquals(Nd4j.create(new double[][] { { 201, 202 }, { 211, 212 }, { 221, 222 }, { 231, 232 } }).castTo(DataType.FLOAT), l2d); break; default: throw new RuntimeException(); } } - - iter.reset(); int count = 0; while (iter.hasNext()) { @@ -389,30 +354,24 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { assertEquals(3, count); } - - @Test - public void testSequenceRecordReaderReset() throws Exception { - File rootDir = temporaryFolder.newFolder(); - //need to manually extract + @DisplayName("Test Sequence Record Reader Reset") + void testSequenceRecordReaderReset() throws Exception { + File rootDir = temporaryFolder.toFile(); + // need to manually extract for (int i = 0; i < 3; i++) { FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i))); FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabels_%d.txt", i)), new File(rootDir, String.format("csvsequencelabels_%d.txt", i))); } String featuresPath = 
FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - SequenceRecordReaderDataSetIterator iter = - new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); - + SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); assertEquals(3, iter.inputColumns()); assertEquals(4, iter.totalOutcomes()); - int nResets = 5; for (int i = 0; i < nResets; i++) { iter.reset(); @@ -421,44 +380,39 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { DataSet ds = iter.next(); INDArray features = ds.getFeatures(); INDArray labels = ds.getLabels(); - assertArrayEquals(new long[] {1, 3, 4}, features.shape()); - assertArrayEquals(new long[] {1, 4, 4}, labels.shape()); + assertArrayEquals(new long[] { 1, 3, 4 }, features.shape()); + assertArrayEquals(new long[] { 1, 4, 4 }, labels.shape()); count++; } assertEquals(3, count); } } - - @Test - public void testCSVLoadingRegression() throws Exception { + @DisplayName("Test CSV Loading Regression") + void testCSVLoadingRegression() throws Exception { int nLines = 30; int nFeatures = 5; int miniBatchSize = 10; int labelIdx = 0; - String path = "rr_csv_test_rand.csv"; - Pair p = makeRandomCSV(path, nLines, nFeatures); + Pair p = makeRandomCSV(path, nLines, nFeatures); double[][] data = p.getFirst(); RecordReader testReader = new CSVRecordReader(); testReader.initialize(new FileSplit(p.getSecond())); - DataSetIterator iter = new RecordReaderDataSetIterator(testReader, miniBatchSize, labelIdx, labelIdx, true); int miniBatch = 0; while (iter.hasNext()) { DataSet test = iter.next(); INDArray features = test.getFeatures(); INDArray labels = test.getLabels(); - assertArrayEquals(new long[] {miniBatchSize, nFeatures}, features.shape()); - assertArrayEquals(new long[] {miniBatchSize, 1}, labels.shape()); - + assertArrayEquals(new long[] { miniBatchSize, nFeatures }, features.shape()); + assertArrayEquals(new long[] { miniBatchSize, 1 }, labels.shape()); int startRow = miniBatch * miniBatchSize; for (int i = 0; i < miniBatchSize; i++) { double labelExp = data[startRow + i][labelIdx]; double labelAct = labels.getDouble(i); assertEquals(labelExp, labelAct, 1e-5f); - int featureCount = 0; for (int j = 0; j < nFeatures + 1; j++) { if (j == labelIdx) @@ -468,24 +422,21 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { assertEquals(featureExp, featureAct, 1e-5f); } } - miniBatch++; } assertEquals(nLines / miniBatchSize, miniBatch); } - - public Pair makeRandomCSV(String tempFile, int nLines, int nFeatures) throws IOException { - File temp = temporaryFolder.newFile(tempFile); + public Pair makeRandomCSV(String tempFile, int nLines, int nFeatures) throws IOException { + File temp = temporaryFolder.resolve(tempFile).toFile(); temp.mkdirs(); temp.deleteOnExit(); Random rand = new Random(12345); - double[][] dArr = new double[nLines][nFeatures + 1]; - try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(temp)))) { for (int i = 0; i < nLines; i++) { - dArr[i][0] = rand.nextDouble(); //First column: label + // First column: label + dArr[i][0] 
= rand.nextDouble(); out.print(dArr[i][0]); for (int j = 0; j < nFeatures; j++) { dArr[i][j + 1] = rand.nextDouble(); @@ -494,157 +445,138 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { out.println(); } } catch (IOException e) { - log.error("",e); + log.error("", e); } - - return new Pair<>(dArr,temp); + return new Pair<>(dArr, temp); } @Test - public void testVariableLengthSequence() throws Exception { - File rootDir = temporaryFolder.newFolder(); - //need to manually extract + @DisplayName("Test Variable Length Sequence") + void testVariableLengthSequence() throws Exception { + File rootDir = temporaryFolder.toFile(); + // need to manually extract for (int i = 0; i < 3; i++) { FileUtils.copyFile(Resources.asFile(String.format("csvsequence_%d.txt", i)), new File(rootDir, String.format("csvsequence_%d.txt", i))); FileUtils.copyFile(Resources.asFile(String.format("csvsequencelabelsShort_%d.txt", i)), new File(rootDir, String.format("csvsequencelabelsShort_%d.txt", i))); } String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt"); - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader, - labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START); - - SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2, - labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); - + SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START); + SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2, labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); assertEquals(3, iterAlignStart.inputColumns()); assertEquals(4, iterAlignStart.totalOutcomes()); - assertEquals(3, iterAlignEnd.inputColumns()); assertEquals(4, iterAlignEnd.totalOutcomes()); - List dsListAlignStart = new ArrayList<>(); while (iterAlignStart.hasNext()) { dsListAlignStart.add(iterAlignStart.next()); } - List dsListAlignEnd = new ArrayList<>(); while (iterAlignEnd.hasNext()) { dsListAlignEnd.add(iterAlignEnd.next()); } - - assertEquals(3, dsListAlignStart.size()); //3 files - assertEquals(3, dsListAlignEnd.size()); //3 files - + // 3 files + assertEquals(3, dsListAlignStart.size()); + // 3 files + assertEquals(3, dsListAlignEnd.size()); for (int i = 0; i < 3; i++) { DataSet ds = dsListAlignStart.get(i); INDArray features = ds.getFeatures(); INDArray labels = ds.getLabels(); - assertEquals(1, features.size(0)); //1 example in mini-batch + // 1 example in mini-batch + assertEquals(1, features.size(0)); assertEquals(1, 
labels.size(0)); - assertEquals(3, features.size(1)); //3 values per line/time step - assertEquals(4, labels.size(1)); //1 value per line, but 4 possible values -> one-hot vector - assertEquals(4, features.size(2)); //sequence length = 4 + // 3 values per line/time step + assertEquals(3, features.size(1)); + // 1 value per line, but 4 possible values -> one-hot vector + assertEquals(4, labels.size(1)); + // sequence length = 4 + assertEquals(4, features.size(2)); assertEquals(4, labels.size(2)); - DataSet ds2 = dsListAlignEnd.get(i); features = ds2.getFeatures(); labels = ds2.getLabels(); - assertEquals(1, features.size(0)); //1 example in mini-batch + // 1 example in mini-batch + assertEquals(1, features.size(0)); assertEquals(1, labels.size(0)); - assertEquals(3, features.size(1)); //3 values per line/time step - assertEquals(4, labels.size(1)); //1 value per line, but 4 possible values -> one-hot vector - assertEquals(4, features.size(2)); //sequence length = 4 + // 3 values per line/time step + assertEquals(3, features.size(1)); + // 1 value per line, but 4 possible values -> one-hot vector + assertEquals(4, labels.size(1)); + // sequence length = 4 + assertEquals(4, features.size(2)); assertEquals(4, labels.size(2)); } - - //Check features vs. expected: - //Here: labels always longer than features -> same features for align start and align end + // Check features vs. expected: + // Here: labels always longer than features -> same features for align start and align end INDArray expF0 = Nd4j.create(1, 3, 4); - expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 2})); - expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {10, 11, 12})); - expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {20, 21, 22})); - expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {30, 31, 32})); + expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 2 })); + expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 10, 11, 12 })); + expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 20, 21, 22 })); + expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 30, 31, 32 })); assertEquals(expF0, dsListAlignStart.get(0).getFeatures()); assertEquals(expF0, dsListAlignEnd.get(0).getFeatures()); - INDArray expF1 = Nd4j.create(1, 3, 4); - expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {100, 101, 102})); - expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {110, 111, 112})); - expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {120, 121, 122})); - expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {130, 131, 132})); + expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 100, 101, 102 })); + expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 110, 111, 112 })); + expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 120, 121, 122 })); + expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 130, 131, 132 })); assertEquals(expF1, dsListAlignStart.get(1).getFeatures()); assertEquals(expF1, dsListAlignEnd.get(1).getFeatures()); - INDArray expF2 = Nd4j.create(1, 3, 4); - expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {200, 201, 202})); - expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {210, 211, 212})); - expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {220, 221, 222})); - expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {230, 231, 
232})); + expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 200, 201, 202 })); + expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 210, 211, 212 })); + expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 220, 221, 222 })); + expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 230, 231, 232 })); assertEquals(expF2, dsListAlignStart.get(2).getFeatures()); assertEquals(expF2, dsListAlignEnd.get(2).getFeatures()); - - //Check features mask array: - INDArray featuresMaskExpected = null; //null: equivalent to all 1s (i.e., present for all time steps) + // Check features mask array: + // null: equivalent to all 1s (i.e., present for all time steps) + INDArray featuresMaskExpected = null; for (int i = 0; i < 3; i++) { INDArray featuresMaskStart = dsListAlignStart.get(i).getFeaturesMaskArray(); INDArray featuresMaskEnd = dsListAlignEnd.get(i).getFeaturesMaskArray(); assertEquals(featuresMaskExpected, featuresMaskStart); assertEquals(featuresMaskExpected, featuresMaskEnd); } - - - //Check labels vs. expected: - //First: aligning start + // Check labels vs. expected: + // First: aligning start INDArray expL0 = Nd4j.create(1, 4, 4); - expL0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0})); - expL0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0})); + expL0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 })); + expL0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 })); assertEquals(expL0, dsListAlignStart.get(0).getLabels()); - INDArray expL1 = Nd4j.create(1, 4, 4); - expL1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0})); + expL1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 })); assertEquals(expL1, dsListAlignStart.get(1).getLabels()); - INDArray expL2 = Nd4j.create(1, 4, 4); - expL2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1})); - expL2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0})); - expL2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0})); + expL2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 })); + expL2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 })); + expL2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 })); assertEquals(expL2, dsListAlignStart.get(2).getLabels()); - - //Second: align end + // Second: align end INDArray expL0end = Nd4j.create(1, 4, 4); - expL0end.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {1, 0, 0, 0})); - expL0end.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0})); + expL0end.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 1, 0, 0, 0 })); + expL0end.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 })); assertEquals(expL0end, dsListAlignEnd.get(0).getLabels()); - INDArray expL1end = Nd4j.create(1, 4, 4); - expL1end.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0})); + expL1end.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 })); assertEquals(expL1end, dsListAlignEnd.get(1).getLabels()); - INDArray expL2end = Nd4j.create(1, 4, 4); - expL2end.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 0, 0, 1})); - expL2end.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 0, 1, 0})); - expL2end.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 1, 0, 0})); + 
expL2end.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 0, 0, 1 })); + expL2end.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 0, 1, 0 })); + expL2end.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 1, 0, 0 })); assertEquals(expL2end, dsListAlignEnd.get(2).getLabels()); - - //Check labels mask array - INDArray[] labelsMaskExpectedStart = new INDArray[] {Nd4j.create(new float[] {1, 1, 0, 0}, new int[] {1, 4}), - Nd4j.create(new float[] {1, 0, 0, 0}, new int[] {1, 4}), - Nd4j.create(new float[] {1, 1, 1, 0}, new int[] {1, 4})}; - INDArray[] labelsMaskExpectedEnd = new INDArray[] {Nd4j.create(new float[] {0, 0, 1, 1}, new int[] {1, 4}), - Nd4j.create(new float[] {0, 0, 0, 1}, new int[] {1, 4}), - Nd4j.create(new float[] {0, 1, 1, 1}, new int[] {1, 4})}; - + // Check labels mask array + INDArray[] labelsMaskExpectedStart = new INDArray[] { Nd4j.create(new float[] { 1, 1, 0, 0 }, new int[] { 1, 4 }), Nd4j.create(new float[] { 1, 0, 0, 0 }, new int[] { 1, 4 }), Nd4j.create(new float[] { 1, 1, 1, 0 }, new int[] { 1, 4 }) }; + INDArray[] labelsMaskExpectedEnd = new INDArray[] { Nd4j.create(new float[] { 0, 0, 1, 1 }, new int[] { 1, 4 }), Nd4j.create(new float[] { 0, 0, 0, 1 }, new int[] { 1, 4 }), Nd4j.create(new float[] { 0, 1, 1, 1 }, new int[] { 1, 4 }) }; for (int i = 0; i < 3; i++) { INDArray labelsMaskStart = dsListAlignStart.get(i).getLabelsMaskArray(); INDArray labelsMaskEnd = dsListAlignEnd.get(i).getLabelsMaskArray(); @@ -654,86 +586,71 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { } @Test - public void testSequenceRecordReaderSingleReader() throws Exception { - File rootDir = temporaryFolder.newFolder(); - //need to manually extract + @DisplayName("Test Sequence Record Reader Single Reader") + void testSequenceRecordReaderSingleReader() throws Exception { + File rootDir = temporaryFolder.toFile(); + // need to manually extract for (int i = 0; i < 3; i++) { FileUtils.copyFile(Resources.asFile(String.format("csvsequenceSingle_%d.txt", i)), new File(rootDir, String.format("csvsequenceSingle_%d.txt", i))); } String path = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequenceSingle_%d.txt"); - SequenceRecordReader reader = new CSVSequenceRecordReader(1, ","); reader.initialize(new NumberedFileInputSplit(path, 0, 2)); - SequenceRecordReaderDataSetIterator iteratorClassification = - new SequenceRecordReaderDataSetIterator(reader, 1, 3, 0, false); - + SequenceRecordReaderDataSetIterator iteratorClassification = new SequenceRecordReaderDataSetIterator(reader, 1, 3, 0, false); assertTrue(iteratorClassification.hasNext()); - SequenceRecordReader reader2 = new CSVSequenceRecordReader(1, ","); reader2.initialize(new NumberedFileInputSplit(path, 0, 2)); - SequenceRecordReaderDataSetIterator iteratorRegression = - new SequenceRecordReaderDataSetIterator(reader2, 1, 1, 0, true); - + SequenceRecordReaderDataSetIterator iteratorRegression = new SequenceRecordReaderDataSetIterator(reader2, 1, 1, 0, true); INDArray expF0 = Nd4j.create(1, 2, 4); - expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {1, 2})); - expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {11, 12})); - expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {21, 22})); - expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {31, 32})); - + expF0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1, 2 })); + expF0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 11, 12 })); + 
expF0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 21, 22 })); + expF0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 31, 32 })); INDArray expF1 = Nd4j.create(1, 2, 4); - expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {101, 102})); - expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {111, 112})); - expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {121, 122})); - expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {131, 132})); - + expF1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 101, 102 })); + expF1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 111, 112 })); + expF1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 121, 122 })); + expF1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 131, 132 })); INDArray expF2 = Nd4j.create(1, 2, 4); - expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {201, 202})); - expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {211, 212})); - expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {221, 222})); - expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {231, 232})); - - INDArray[] expF = new INDArray[] {expF0, expF1, expF2}; - - //Expected out for classification: + expF2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 201, 202 })); + expF2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 211, 212 })); + expF2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 221, 222 })); + expF2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 231, 232 })); + INDArray[] expF = new INDArray[] { expF0, expF1, expF2 }; + // Expected out for classification: INDArray expOut0 = Nd4j.create(1, 3, 4); - expOut0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {1, 0, 0})); - expOut0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 1, 0})); - expOut0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 0, 1})); - expOut0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {1, 0, 0})); - + expOut0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1, 0, 0 })); + expOut0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 1, 0 })); + expOut0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 0, 1 })); + expOut0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 1, 0, 0 })); INDArray expOut1 = Nd4j.create(1, 3, 4); - expOut1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 0})); - expOut1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0, 0, 1})); - expOut1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {1, 0, 0})); - expOut1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 0, 1})); - + expOut1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 0 })); + expOut1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0, 0, 1 })); + expOut1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 1, 0, 0 })); + expOut1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 0, 1 })); INDArray expOut2 = Nd4j.create(1, 3, 4); - expOut2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0, 1, 0})); - expOut2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {1, 0, 0})); - expOut2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0, 1, 0})); - expOut2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0, 0, 1})); - INDArray[] 
expOutClassification = new INDArray[] {expOut0, expOut1, expOut2}; - - //Expected out for regression: + expOut2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0, 1, 0 })); + expOut2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 1, 0, 0 })); + expOut2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0, 1, 0 })); + expOut2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0, 0, 1 })); + INDArray[] expOutClassification = new INDArray[] { expOut0, expOut1, expOut2 }; + // Expected out for regression: INDArray expOutR0 = Nd4j.create(1, 1, 4); - expOutR0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {0})); - expOutR0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {1})); - expOutR0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {2})); - expOutR0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {0})); - + expOutR0.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 0 })); + expOutR0.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 1 })); + expOutR0.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 2 })); + expOutR0.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 0 })); INDArray expOutR1 = Nd4j.create(1, 1, 4); - expOutR1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {1})); - expOutR1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {2})); - expOutR1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {0})); - expOutR1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {2})); - + expOutR1.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1 })); + expOutR1.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 2 })); + expOutR1.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 0 })); + expOutR1.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 2 })); INDArray expOutR2 = Nd4j.create(1, 1, 4); - expOutR2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] {1})); - expOutR2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] {0})); - expOutR2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] {1})); - expOutR2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] {2})); - INDArray[] expOutRegression = new INDArray[] {expOutR0, expOutR1, expOutR2}; - + expOutR2.tensorAlongDimension(0, 1).assign(Nd4j.create(new double[] { 1 })); + expOutR2.tensorAlongDimension(1, 1).assign(Nd4j.create(new double[] { 0 })); + expOutR2.tensorAlongDimension(2, 1).assign(Nd4j.create(new double[] { 1 })); + expOutR2.tensorAlongDimension(3, 1).assign(Nd4j.create(new double[] { 2 })); + INDArray[] expOutRegression = new INDArray[] { expOutR0, expOutR1, expOutR2 }; int countC = 0; while (iteratorClassification.hasNext()) { DataSet ds = iteratorClassification.next(); @@ -741,16 +658,14 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { INDArray l = ds.getLabels(); assertNull(ds.getFeaturesMaskArray()); assertNull(ds.getLabelsMaskArray()); - - assertArrayEquals(new long[] {1, 2, 4}, f.shape()); - assertArrayEquals(new long[] {1, 3, 4}, l.shape()); //One-hot representation - + assertArrayEquals(new long[] { 1, 2, 4 }, f.shape()); + // One-hot representation + assertArrayEquals(new long[] { 1, 3, 4 }, l.shape()); assertEquals(expF[countC], f); assertEquals(expOutClassification[countC++], l); } assertEquals(3, countC); assertEquals(3, iteratorClassification.totalOutcomes()); - int countF = 0; while (iteratorRegression.hasNext()) { DataSet ds = iteratorRegression.next(); 
@@ -758,10 +673,9 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { INDArray l = ds.getLabels(); assertNull(ds.getFeaturesMaskArray()); assertNull(ds.getLabelsMaskArray()); - - assertArrayEquals(new long[] {1, 2, 4}, f.shape()); - assertArrayEquals(new long[] {1, 1, 4}, l.shape()); //Regression (single output) - + assertArrayEquals(new long[] { 1, 2, 4 }, f.shape()); + // Regression (single output) + assertArrayEquals(new long[] { 1, 1, 4 }, l.shape()); assertEquals(expF[countF], f); assertEquals(expOutRegression[countF++], l); } @@ -769,66 +683,63 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { assertEquals(1, iteratorRegression.totalOutcomes()); } - @Test(expected = ZeroLengthSequenceException.class) - public void testSequenceRecordReaderSingleReaderWithEmptySequenceThrows() throws Exception { - SequenceRecordReader reader = new CSVSequenceRecordReader(1, ","); - reader.initialize(new FileSplit(Resources.asFile("empty.txt"))); - - new SequenceRecordReaderDataSetIterator(reader, 1, -1, 1, true).next(); - } - - @Test(expected = ZeroLengthSequenceException.class) - public void testSequenceRecordReaderTwoReadersWithEmptyFeatureSequenceThrows() throws Exception { - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); - SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); - - featureReader.initialize(new FileSplit(Resources.asFile("empty.txt"))); - labelReader.initialize( - new FileSplit(Resources.asFile("csvsequencelabels_0.txt"))); - - new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true).next(); - } - - @Test(expected = ZeroLengthSequenceException.class) - public void testSequenceRecordReaderTwoReadersWithEmptyLabelSequenceThrows() throws Exception { - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); - SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); - - File f = Resources.asFile("csvsequence_0.txt"); - featureReader.initialize(new FileSplit(f)); - labelReader.initialize(new FileSplit(Resources.asFile("empty.txt"))); - - new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true).next(); + @Test + @DisplayName("Test Sequence Record Reader Single Reader With Empty Sequence Throws") + void testSequenceRecordReaderSingleReaderWithEmptySequenceThrows() { + assertThrows(ZeroLengthSequenceException.class, () -> { + SequenceRecordReader reader = new CSVSequenceRecordReader(1, ","); + reader.initialize(new FileSplit(Resources.asFile("empty.txt"))); + new SequenceRecordReaderDataSetIterator(reader, 1, -1, 1, true).next(); + }); } @Test - public void testSequenceRecordReaderSingleReaderMetaData() throws Exception { - File rootDir = temporaryFolder.newFolder(); - //need to manually extract + @DisplayName("Test Sequence Record Reader Two Readers With Empty Feature Sequence Throws") + void testSequenceRecordReaderTwoReadersWithEmptyFeatureSequenceThrows() { + assertThrows(ZeroLengthSequenceException.class, () -> { + SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); + SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); + featureReader.initialize(new FileSplit(Resources.asFile("empty.txt"))); + labelReader.initialize(new FileSplit(Resources.asFile("csvsequencelabels_0.txt"))); + new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true).next(); + }); + } + + @Test + @DisplayName("Test Sequence Record Reader Two Readers With Empty Label Sequence Throws") + void 
testSequenceRecordReaderTwoReadersWithEmptyLabelSequenceThrows() { + assertThrows(ZeroLengthSequenceException.class, () -> { + SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); + SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); + File f = Resources.asFile("csvsequence_0.txt"); + featureReader.initialize(new FileSplit(f)); + labelReader.initialize(new FileSplit(Resources.asFile("empty.txt"))); + new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, -1, true).next(); + }); + } + + @Test + @DisplayName("Test Sequence Record Reader Single Reader Meta Data") + void testSequenceRecordReaderSingleReaderMetaData() throws Exception { + File rootDir = temporaryFolder.toFile(); + // need to manually extract for (int i = 0; i < 3; i++) { FileUtils.copyFile(Resources.asFile(String.format("csvsequenceSingle_%d.txt", i)), new File(rootDir, String.format("csvsequenceSingle_%d.txt", i))); } String path = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequenceSingle_%d.txt"); - SequenceRecordReader reader = new CSVSequenceRecordReader(1, ","); reader.initialize(new NumberedFileInputSplit(path, 0, 2)); - SequenceRecordReaderDataSetIterator iteratorClassification = - new SequenceRecordReaderDataSetIterator(reader, 1, 3, 0, false); - + SequenceRecordReaderDataSetIterator iteratorClassification = new SequenceRecordReaderDataSetIterator(reader, 1, 3, 0, false); SequenceRecordReader reader2 = new CSVSequenceRecordReader(1, ","); reader2.initialize(new NumberedFileInputSplit(path, 0, 2)); - SequenceRecordReaderDataSetIterator iteratorRegression = - new SequenceRecordReaderDataSetIterator(reader2, 1, 1, 0, true); - + SequenceRecordReaderDataSetIterator iteratorRegression = new SequenceRecordReaderDataSetIterator(reader2, 1, 1, 0, true); iteratorClassification.setCollectMetaData(true); iteratorRegression.setCollectMetaData(true); - while (iteratorClassification.hasNext()) { DataSet ds = iteratorClassification.next(); DataSet fromMeta = iteratorClassification.loadFromMetaData(ds.getExampleMetaData(RecordMetaData.class)); assertEquals(ds, fromMeta); } - while (iteratorRegression.hasNext()) { DataSet ds = iteratorRegression.next(); DataSet fromMeta = iteratorRegression.loadFromMetaData(ds.getExampleMetaData(RecordMetaData.class)); @@ -836,170 +747,117 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { } } - @Test - public void testSeqRRDSIArrayWritableOneReader() { - + @DisplayName("Test Seq RRDSI Array Writable One Reader") + void testSeqRRDSIArrayWritableOneReader() { List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1,3})), - new IntWritable(0))); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1,3})), - new IntWritable(1))); + sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new IntWritable(0))); + sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new IntWritable(1))); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1,3})), - new IntWritable(2))); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1,3})), - new IntWritable(3))); - - + sequence2.add(Arrays.asList((Writable) new 
NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), new IntWritable(2))); + sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new IntWritable(3))); SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); - SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr, 2, 4, 1, false); - DataSet ds = iter.next(); - - INDArray expFeatures = Nd4j.create(2, 3, 2); //2 examples, 3 values per time step, 2 time steps - expFeatures.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] {{1, 4}, {2, 5}, {3, 6}})); - expFeatures.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] {{7, 10}, {8, 11}, {9, 12}})); - + // 2 examples, 3 values per time step, 2 time steps + INDArray expFeatures = Nd4j.create(2, 3, 2); + expFeatures.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] { { 1, 4 }, { 2, 5 }, { 3, 6 } })); + expFeatures.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] { { 7, 10 }, { 8, 11 }, { 9, 12 } })); INDArray expLabels = Nd4j.create(2, 4, 2); - expLabels.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] {{1, 0}, {0, 1}, {0, 0}, {0, 0}})); - expLabels.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] {{0, 0}, {0, 0}, {1, 0}, {0, 1}})); - + expLabels.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] { { 1, 0 }, { 0, 1 }, { 0, 0 }, { 0, 0 } })); + expLabels.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] { { 0, 0 }, { 0, 0 }, { 1, 0 }, { 0, 1 } })); assertEquals(expFeatures, ds.getFeatures()); assertEquals(expLabels, ds.getLabels()); } @Test - public void testSeqRRDSIArrayWritableOneReaderRegression() { - //Regression, where the output is an array writable + @DisplayName("Test Seq RRDSI Array Writable One Reader Regression") + void testSeqRRDSIArrayWritableOneReaderRegression() { + // Regression, where the output is an array writable List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, new long[]{1,3})))); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {400, 500, 600}, new long[]{1,3})))); + sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 100, 200, 300 }, new long[] { 1, 3 })))); + sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 400, 500, 600 }, new long[] { 1, 3 })))); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, new long[]{1,3})))); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, new long[]{1,3})))); - - + sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 700, 800, 900 }, new long[] { 1, 3 })))); 
+ sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 1000, 1100, 1200 }, new long[] { 1, 3 })))); SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); - SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr, 2, -1, 1, true); - DataSet ds = iter.next(); - - INDArray expFeatures = Nd4j.create(2, 3, 2); //2 examples, 3 values per time step, 2 time steps - expFeatures.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] {{1, 4}, {2, 5}, {3, 6}})); - expFeatures.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] {{7, 10}, {8, 11}, {9, 12}})); - + // 2 examples, 3 values per time step, 2 time steps + INDArray expFeatures = Nd4j.create(2, 3, 2); + expFeatures.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] { { 1, 4 }, { 2, 5 }, { 3, 6 } })); + expFeatures.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] { { 7, 10 }, { 8, 11 }, { 9, 12 } })); INDArray expLabels = Nd4j.create(2, 3, 2); - expLabels.tensorAlongDimension(0, 1, 2) - .assign(Nd4j.create(new double[][] {{100, 400}, {200, 500}, {300, 600}})); - expLabels.tensorAlongDimension(1, 1, 2) - .assign(Nd4j.create(new double[][] {{700, 1000}, {800, 1100}, {900, 1200}})); - + expLabels.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] { { 100, 400 }, { 200, 500 }, { 300, 600 } })); + expLabels.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] { { 700, 1000 }, { 800, 1100 }, { 900, 1200 } })); assertEquals(expFeatures, ds.getFeatures()); assertEquals(expLabels, ds.getLabels()); } @Test - public void testSeqRRDSIMultipleArrayWritablesOneReader() { - //Input with multiple array writables: - + @DisplayName("Test Seq RRDSI Multiple Array Writables One Reader") + void testSeqRRDSIMultipleArrayWritablesOneReader() { + // Input with multiple array writables: List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, new long[]{1,3})), new IntWritable(0))); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {400, 500, 600}, new long[]{1,3})), new IntWritable(1))); + sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 100, 200, 300 }, new long[] { 1, 3 })), new IntWritable(0))); + sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 400, 500, 600 }, new long[] { 1, 3 })), new IntWritable(1))); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, new long[]{1,3})), new IntWritable(2))); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, new long[]{1,3})), new IntWritable(3))); - - + sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), 
new NDArrayWritable(Nd4j.create(new double[] { 700, 800, 900 }, new long[] { 1, 3 })), new IntWritable(2))); + sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 1000, 1100, 1200 }, new long[] { 1, 3 })), new IntWritable(3))); SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); - SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr, 2, 4, 2, false); - DataSet ds = iter.next(); - - INDArray expFeatures = Nd4j.create(2, 6, 2); //2 examples, 6 values per time step, 2 time steps - expFeatures.tensorAlongDimension(0, 1, 2).assign( - Nd4j.create(new double[][] {{1, 4}, {2, 5}, {3, 6}, {100, 400}, {200, 500}, {300, 600}})); - expFeatures.tensorAlongDimension(1, 1, 2).assign( - Nd4j.create(new double[][] {{7, 10}, {8, 11}, {9, 12}, {700, 1000}, {800, 1100}, {900, 1200}})); - + // 2 examples, 6 values per time step, 2 time steps + INDArray expFeatures = Nd4j.create(2, 6, 2); + expFeatures.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] { { 1, 4 }, { 2, 5 }, { 3, 6 }, { 100, 400 }, { 200, 500 }, { 300, 600 } })); + expFeatures.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] { { 7, 10 }, { 8, 11 }, { 9, 12 }, { 700, 1000 }, { 800, 1100 }, { 900, 1200 } })); INDArray expLabels = Nd4j.create(2, 4, 2); - expLabels.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] {{1, 0}, {0, 1}, {0, 0}, {0, 0}})); - expLabels.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] {{0, 0}, {0, 0}, {1, 0}, {0, 1}})); - + expLabels.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] { { 1, 0 }, { 0, 1 }, { 0, 0 }, { 0, 0 } })); + expLabels.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] { { 0, 0 }, { 0, 0 }, { 1, 0 }, { 0, 1 } })); assertEquals(expFeatures, ds.getFeatures()); assertEquals(expLabels, ds.getLabels()); } @Test - public void testSeqRRDSIArrayWritableTwoReaders() { + @DisplayName("Test Seq RRDSI Array Writable Two Readers") + void testSeqRRDSIArrayWritableTwoReaders() { List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1,3})), - new IntWritable(100))); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1,3})), - new IntWritable(200))); + sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new IntWritable(100))); + sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new IntWritable(200))); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1,3})), - new IntWritable(300))); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1,3})), - new IntWritable(400))); + sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), new IntWritable(300))); + sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new IntWritable(400))); SequenceRecordReader rrFeatures = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); - List> sequence1L = new ArrayList<>(); - 
sequence1L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, new long[]{1,3})), - new IntWritable(101))); - sequence1L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {400, 500, 600}, new long[]{1,3})), - new IntWritable(201))); + sequence1L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 100, 200, 300 }, new long[] { 1, 3 })), new IntWritable(101))); + sequence1L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 400, 500, 600 }, new long[] { 1, 3 })), new IntWritable(201))); List> sequence2L = new ArrayList<>(); - sequence2L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, new long[]{1,3})), - new IntWritable(301))); - sequence2L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, new long[]{1,3})), - new IntWritable(401))); + sequence2L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 700, 800, 900 }, new long[] { 1, 3 })), new IntWritable(301))); + sequence2L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1000, 1100, 1200 }, new long[] { 1, 3 })), new IntWritable(401))); SequenceRecordReader rrLabels = new CollectionSequenceRecordReader(Arrays.asList(sequence1L, sequence2L)); - - SequenceRecordReaderDataSetIterator iter = - new SequenceRecordReaderDataSetIterator(rrFeatures, rrLabels, 2, -1, true); - - INDArray expFeatures = Nd4j.create(2, 4, 2); //2 examples, 4 values per time step, 2 time steps - expFeatures.tensorAlongDimension(0, 1, 2) - .assign(Nd4j.create(new double[][] {{1, 4}, {2, 5}, {3, 6}, {100, 200}})); - expFeatures.tensorAlongDimension(1, 1, 2) - .assign(Nd4j.create(new double[][] {{7, 10}, {8, 11}, {9, 12}, {300, 400}})); - + SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rrFeatures, rrLabels, 2, -1, true); + // 2 examples, 4 values per time step, 2 time steps + INDArray expFeatures = Nd4j.create(2, 4, 2); + expFeatures.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] { { 1, 4 }, { 2, 5 }, { 3, 6 }, { 100, 200 } })); + expFeatures.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] { { 7, 10 }, { 8, 11 }, { 9, 12 }, { 300, 400 } })); INDArray expLabels = Nd4j.create(2, 4, 2); - expLabels.tensorAlongDimension(0, 1, 2) - .assign(Nd4j.create(new double[][] {{100, 400}, {200, 500}, {300, 600}, {101, 201}})); - expLabels.tensorAlongDimension(1, 1, 2) - .assign(Nd4j.create(new double[][] {{700, 1000}, {800, 1100}, {900, 1200}, {301, 401}})); - + expLabels.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] { { 100, 400 }, { 200, 500 }, { 300, 600 }, { 101, 201 } })); + expLabels.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] { { 700, 1000 }, { 800, 1100 }, { 900, 1200 }, { 301, 401 } })); DataSet ds = iter.next(); assertEquals(expFeatures, ds.getFeatures()); assertEquals(expLabels, ds.getLabels()); } @Test - public void testRecordReaderMetaData() throws Exception { - + @DisplayName("Test Record Reader Meta Data") + void testRecordReaderMetaData() throws Exception { RecordReader csv = new CSVRecordReader(); csv.initialize(new FileSplit(Resources.asFile("iris.txt"))); - int batchSize = 10; int labelIdx = 4; int numClasses = 3; - RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses); rrdsi.setCollectMetaData(true); - while (rrdsi.hasNext()) { DataSet ds = rrdsi.next(); List meta = 
ds.getExampleMetaData(RecordMetaData.class); @@ -1007,98 +865,75 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { for (RecordMetaData m : meta) { Record r = csv.loadFromMetaData(m); INDArray row = ds.getFeatures().getRow(i); -// if(i <= 3) { -// System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row); -// } - + // if(i <= 3) { + // System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row); + // } for (int j = 0; j < 4; j++) { double exp = r.getRecord().get(j).toDouble(); double act = row.getDouble(j); - assertEquals("Failed on idx: " + j, exp, act, 1e-6); + assertEquals(exp, act, 1e-6, "Failed on idx: " + j); } i++; } -// System.out.println(); - + // System.out.println(); DataSet fromMeta = rrdsi.loadFromMetaData(meta); assertEquals(ds, fromMeta); } } @Test - public void testRRDSIwithAsync() throws Exception { + @DisplayName("Test RRDSI With Async") + void testRRDSIwithAsync() throws Exception { RecordReader csv = new CSVRecordReader(); csv.initialize(new FileSplit(Resources.asFile("iris.txt"))); - int batchSize = 10; int labelIdx = 4; int numClasses = 3; - RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses); AsyncDataSetIterator adsi = new AsyncDataSetIterator(rrdsi, 8, true); while (adsi.hasNext()) { DataSet ds = adsi.next(); - } - } - - @Test - public void testRecordReaderDataSetIteratorNDArrayWritableLabels() { - + @DisplayName("Test Record Reader Data Set Iterator ND Array Writable Labels") + void testRecordReaderDataSetIteratorNDArrayWritableLabels() { Collection<Collection<Writable>> data = new ArrayList<>(); - - data.add(Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), - new NDArrayWritable(Nd4j.create(new double[] {1.1, 2.1, 3.1}, new long[]{1,3})))); - data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(3), - new NDArrayWritable(Nd4j.create(new double[] {4.1, 5.1, 6.1}, new long[]{1,3})))); - data.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(5), - new NDArrayWritable(Nd4j.create(new double[] {7.1, 8.1, 9.1}, new long[]{1,3})))); - + data.add(Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }, new long[] { 1, 3 })))); + data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(3), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }, new long[] { 1, 3 })))); + data.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }, new long[] { 1, 3 })))); RecordReader rr = new CollectionRecordReader(data); int batchSize = 3; int labelIndexFrom = 2; int labelIndexTo = 2; boolean regression = true; - DataSetIterator rrdsi = - new RecordReaderDataSetIterator(rr, batchSize, labelIndexFrom, labelIndexTo, regression); - + DataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, batchSize, labelIndexFrom, labelIndexTo, regression); DataSet ds = rrdsi.next(); - INDArray expFeatures = Nd4j.create(new float[][] {{0, 1}, {2, 3}, {4, 5}}); - INDArray expLabels = Nd4j.create(new float[][] {{1.1f, 2.1f, 3.1f}, {4.1f, 5.1f, 6.1f}, {7.1f, 8.1f, 9.1f}}); - + INDArray expFeatures = Nd4j.create(new float[][] { { 0, 1 }, { 2, 3 }, { 4, 5 } }); + INDArray expLabels = Nd4j.create(new float[][] { { 1.1f, 2.1f, 3.1f }, { 4.1f, 5.1f, 6.1f }, { 7.1f, 8.1f, 9.1f } }); assertEquals(expFeatures, ds.getFeatures()); assertEquals(expLabels, ds.getLabels()); - - //ALSO: test if we have NDArrayWritables for BOTH the features and the labels + 
// ALSO: test if we have NDArrayWritables for BOTH the features and the labels data = new ArrayList<>(); - - data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {0, 1}, new long[]{1,2})), - new NDArrayWritable(Nd4j.create(new double[] {1.1, 2.1, 3.1}, new long[]{1,3})))); - data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {2, 3}, new long[]{1,2})), - new NDArrayWritable(Nd4j.create(new double[] {4.1, 5.1, 6.1}, new long[]{1,3})))); - data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {4, 5}, new long[]{1,2})), - new NDArrayWritable(Nd4j.create(new double[] {7.1, 8.1, 9.1}, new long[]{1,3})))); + data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 0, 1 }, new long[] { 1, 2 })), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }, new long[] { 1, 3 })))); + data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 2, 3 }, new long[] { 1, 2 })), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }, new long[] { 1, 3 })))); + data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 4, 5 }, new long[] { 1, 2 })), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }, new long[] { 1, 3 })))); labelIndexFrom = 1; labelIndexTo = 1; - rr = new CollectionRecordReader(data); rrdsi = new RecordReaderDataSetIterator(rr, batchSize, labelIndexFrom, labelIndexTo, regression); - DataSet ds2 = rrdsi.next(); assertEquals(expFeatures, ds2.getFeatures()); assertEquals(expLabels, ds2.getLabels()); } - @Test - @Ignore - public void specialRRTest4() throws Exception { + @Disabled + @DisplayName("Special RR Test 4") + void specialRRTest4() throws Exception { RecordReader rr = new SpecialImageRecordReader(25000, 10, 3, 224, 224); RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 128); - int cnt = 0; int examples = 0; while (rrdsi.hasNext()) { @@ -1106,14 +941,12 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { assertEquals(128, ds.numExamples()); for (int i = 0; i < ds.numExamples(); i++) { INDArray example = ds.getFeatures().tensorAlongDimension(i, 1, 2, 3).dup(); - // assertEquals("Failed on DataSet [" + cnt + "], example [" + i + "]", (double) examples, example.meanNumber().doubleValue(), 0.01); - - // assertEquals("Failed on DataSet [" + cnt + "], example [" + i + "]", (double) examples, ds.getLabels().getRow(i).meanNumber().doubleValue(), 0.01); + // assertEquals("Failed on DataSet [" + cnt + "], example [" + i + "]", (double) examples, example.meanNumber().doubleValue(), 0.01); + // assertEquals("Failed on DataSet [" + cnt + "], example [" + i + "]", (double) examples, ds.getLabels().getRow(i).meanNumber().doubleValue(), 0.01); examples++; } cnt++; } - } /* @@ -1196,82 +1029,61 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { } */ - - @Test - public void testRecordReaderDataSetIteratorConcat() { - - //[DoubleWritable, DoubleWritable, NDArrayWritable([1,10]), IntWritable] -> concatenate to a [1,13] feature vector automatically. - - List l = Arrays.asList(new DoubleWritable(1), - new NDArrayWritable(Nd4j.create(new double[] {2, 3, 4})), new DoubleWritable(5), - new NDArrayWritable(Nd4j.create(new double[] {6, 7, 8})), new IntWritable(9), - new IntWritable(1)); - + @DisplayName("Test Record Reader Data Set Iterator Concat") + void testRecordReaderDataSetIteratorConcat() { + // [DoubleWritable, DoubleWritable, NDArrayWritable([1,10]), IntWritable] -> concatenate to a [1,13] feature vector automatically. 
+ List l = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 })), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 })), new IntWritable(9), new IntWritable(1)); RecordReader rr = new CollectionRecordReader(Collections.singletonList(l)); - DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 5, 3); - DataSet ds = iter.next(); - INDArray expF = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,9}); - INDArray expL = Nd4j.create(new float[] {0, 1, 0}, new int[]{1,3}); - + INDArray expF = Nd4j.create(new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }, new int[] { 1, 9 }); + INDArray expL = Nd4j.create(new float[] { 0, 1, 0 }, new int[] { 1, 3 }); assertEquals(expF, ds.getFeatures()); assertEquals(expL, ds.getLabels()); } @Test - public void testRecordReaderDataSetIteratorConcat2() { + @DisplayName("Test Record Reader Data Set Iterator Concat 2") + void testRecordReaderDataSetIteratorConcat2() { List l = new ArrayList<>(); l.add(new IntWritable(0)); l.add(new NDArrayWritable(Nd4j.arange(1, 9))); l.add(new IntWritable(9)); - RecordReader rr = new CollectionRecordReader(Collections.singletonList(l)); DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1); - DataSet ds = iter.next(); - INDArray expF = Nd4j.create(new float[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,10}); - + INDArray expF = Nd4j.create(new float[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }, new int[] { 1, 10 }); assertEquals(expF, ds.getFeatures()); } @Test - public void testRecordReaderDataSetIteratorDisjointFeatures() { - - //Idea: input vector is like [f,f,f,f,l,l,f,f] or similar - i.e., label writables aren't start/end - - List l = Arrays.asList(new DoubleWritable(1), - new NDArrayWritable(Nd4j.create(new float[] {2, 3, 4}, new long[]{1,3})), new DoubleWritable(5), - new NDArrayWritable(Nd4j.create(new float[] {6, 7, 8}, new long[]{1,3}))); - - INDArray expF = Nd4j.create(new float[] {1, 6, 7, 8}, new long[]{1,4}); - INDArray expL = Nd4j.create(new float[] {2, 3, 4, 5}, new long[]{1,4}); - + @DisplayName("Test Record Reader Data Set Iterator Disjoint Features") + void testRecordReaderDataSetIteratorDisjointFeatures() { + // Idea: input vector is like [f,f,f,f,l,l,f,f] or similar - i.e., label writables aren't start/end + List l = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new float[] { 2, 3, 4 }, new long[] { 1, 3 })), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new float[] { 6, 7, 8 }, new long[] { 1, 3 }))); + INDArray expF = Nd4j.create(new float[] { 1, 6, 7, 8 }, new long[] { 1, 4 }); + INDArray expL = Nd4j.create(new float[] { 2, 3, 4, 5 }, new long[] { 1, 4 }); RecordReader rr = new CollectionRecordReader(Collections.singletonList(l)); - DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 1, 2, true); - DataSet ds = iter.next(); assertEquals(expF, ds.getFeatures()); assertEquals(expL, ds.getLabels()); } @Test - public void testNormalizerPrefetchReset() throws Exception { - //Check NPE fix for: https://github.com/eclipse/deeplearning4j/issues/4214 + @DisplayName("Test Normalizer Prefetch Reset") + void testNormalizerPrefetchReset() throws Exception { + // Check NPE fix for: https://github.com/eclipse/deeplearning4j/issues/4214 RecordReader csv = new CSVRecordReader(); csv.initialize(new FileSplit(Resources.asFile("iris.txt"))); - int batchSize = 3; - DataSetIterator iter = new RecordReaderDataSetIterator(csv, batchSize, 4, 4, true); - DataNormalization normalizer = new 
NormalizerMinMaxScaler(0, 1); normalizer.fit(iter); iter.setPreProcessor(normalizer); - - iter.inputColumns(); //Prefetch + // Prefetch + iter.inputColumns(); iter.totalOutcomes(); iter.hasNext(); iter.reset(); @@ -1279,94 +1091,71 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { } @Test - public void testReadingFromStream() throws Exception { - - for(boolean b : new boolean[]{false, true}) { + @DisplayName("Test Reading From Stream") + void testReadingFromStream() throws Exception { + for (boolean b : new boolean[] { false, true }) { int batchSize = 1; int labelIndex = 4; int numClasses = 3; InputStream dataFile = Resources.asStream("iris.txt"); RecordReader recordReader = new CSVRecordReader(0, ','); recordReader.initialize(new InputStreamInputSplit(dataFile)); - assertTrue(recordReader.hasNext()); assertFalse(recordReader.resetSupported()); - DataSetIterator iterator; - if(b){ - iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize) - .classification(labelIndex, numClasses) - .build(); + if (b) { + iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize).classification(labelIndex, numClasses).build(); } else { iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses); } assertFalse(iterator.resetSupported()); - int count = 0; while (iterator.hasNext()) { assertNotNull(iterator.next()); count++; } - assertEquals(150, count); - try { iterator.reset(); fail("Expected exception"); } catch (Exception e) { - //expected + // expected } } } - @Test - public void testImagesRRDSI() throws Exception { - File parentDir = temporaryFolder.newFolder(); + @DisplayName("Test Images RRDSI") + void testImagesRRDSI() throws Exception { + File parentDir = temporaryFolder.toFile(); parentDir.deleteOnExit(); String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/"); String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/"); - File f2 = new File(str2); File f1 = new File(str1); f1.mkdirs(); f2.mkdirs(); - - TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), - new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream()); - TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), - new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream()); - - + TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream()); + TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream()); Random r = new Random(12345); ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); - ImageRecordReader rr1 = new ImageRecordReader(28, 28, 3, labelMaker); rr1.initialize(new FileSplit(parentDir)); - - - RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr1,2); + RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr1, 2); DataSet ds = rrdsi.next(); - assertArrayEquals(new long[]{2, 3, 28, 28}, ds.getFeatures().shape()); - assertArrayEquals(new long[]{2, 2}, ds.getLabels().shape()); - - - //Check the same thing via the builder: + assertArrayEquals(new long[] { 2, 3, 28, 28 }, ds.getFeatures().shape()); + assertArrayEquals(new long[] { 2, 2 }, ds.getLabels().shape()); + // Check the same thing via the builder: rr1.reset(); - rrdsi = new 
RecordReaderDataSetIterator.Builder(rr1, 2) - .classification(1,2) - .build(); - - + rrdsi = new RecordReaderDataSetIterator.Builder(rr1, 2).classification(1, 2).build(); ds = rrdsi.next(); - assertArrayEquals(new long[]{2, 3, 28, 28}, ds.getFeatures().shape()); - assertArrayEquals(new long[]{2, 2}, ds.getLabels().shape()); + assertArrayEquals(new long[] { 2, 3, 28, 28 }, ds.getFeatures().shape()); + assertArrayEquals(new long[] { 2, 2 }, ds.getLabels().shape()); } - - @Test - public void testSeqRRDSINoLabels(){ + @DisplayName("Test Seq RRDSI No Labels") + void testSeqRRDSINoLabels() { List> sequence1 = new ArrayList<>(); sequence1.add(Arrays.asList((Writable) new DoubleWritable(1), new DoubleWritable(2))); sequence1.add(Arrays.asList((Writable) new DoubleWritable(3), new DoubleWritable(4))); @@ -1375,20 +1164,16 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { sequence2.add(Arrays.asList((Writable) new DoubleWritable(10), new DoubleWritable(20))); sequence2.add(Arrays.asList((Writable) new DoubleWritable(30), new DoubleWritable(40))); SequenceRecordReader rrFeatures = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); - SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rrFeatures, 2, -1, -1); - DataSet ds = iter.next(); assertNotNull(ds.getFeatures()); assertNull(ds.getLabels()); } - @Test - public void testCollectMetaData(){ - RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator.Builder(new CollectionRecordReader(Collections.>emptyList()), 1) - .collectMetaData(true) - .build(); + @DisplayName("Test Collect Meta Data") + void testCollectMetaData() { + RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator.Builder(new CollectionRecordReader(Collections.>emptyList()), 1).collectMetaData(true).build(); assertTrue(trainIter.isCollectMetaData()); trainIter.setCollectMetaData(false); assertFalse(trainIter.isCollectMetaData()); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index 7901ba71f..507d80e9e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -17,10 +17,8 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.datavec; - import org.junit.rules.Timeout; import org.nd4j.shade.guava.io.Files; import org.apache.commons.io.FileUtils; @@ -47,8 +45,8 @@ import org.datavec.image.recordreader.ImageRecordReader; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -58,42 +56,40 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.resources.Resources; - import java.io.*; import java.net.URI; import java.util.*; - -import 
static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.point; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { +@DisplayName("Record Reader Multi Data Set Iterator Test") +class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { - @Rule - public TemporaryFolder temporaryFolder = new TemporaryFolder(); + @TempDir + public Path temporaryFolder; @Rule public Timeout timeout = Timeout.seconds(300); @Test - public void testsBasic() throws Exception { - //Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator + @DisplayName("Tests Basic") + void testsBasic() throws Exception { + // Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator RecordReader rr = new CSVRecordReader(0, ','); rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3); - RecordReader rr2 = new CSVRecordReader(0, ','); rr2.initialize(new FileSplit(Resources.asFile("iris.txt"))); - - MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2) - .addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build(); - + MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build(); while (rrdsi.hasNext()) { DataSet ds = rrdsi.next(); INDArray fds = ds.getFeatures(); INDArray lds = ds.getLabels(); - MultiDataSet mds = rrmdsi.next(); assertEquals(1, mds.getFeatures().length); assertEquals(1, mds.getLabels().length); @@ -101,49 +97,36 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { assertNull(mds.getLabelsMaskArrays()); INDArray fmds = mds.getFeatures(0); INDArray lmds = mds.getLabels(0); - assertNotNull(fmds); assertNotNull(lmds); - assertEquals(fds, fmds); assertEquals(lds, lmds); } assertFalse(rrmdsi.hasNext()); - - //need to manually extract - File rootDir = temporaryFolder.newFolder(); + // need to manually extract + File rootDir = temporaryFolder.toFile(); for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); } - - //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator + // Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - SequenceRecordReaderDataSetIterator iter = - new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 
1, 4, false); - + SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1) - .addSequenceReader("in", featureReader2).addSequenceReader("out", labelReader2).addInput("in") - .addOutputOneHot("out", 0, 4).build(); - + MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader2).addSequenceReader("out", labelReader2).addInput("in").addOutputOneHot("out", 0, 4).build(); while (iter.hasNext()) { DataSet ds = iter.next(); INDArray fds = ds.getFeatures(); INDArray lds = ds.getLabels(); - MultiDataSet mds = srrmdsi.next(); assertEquals(1, mds.getFeatures().length); assertEquals(1, mds.getLabels().length); @@ -151,10 +134,8 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { assertNull(mds.getLabelsMaskArrays()); INDArray fmds = mds.getFeatures(0); INDArray lmds = mds.getLabels(0); - assertNotNull(fmds); assertNotNull(lmds); - assertEquals(fds, fmds); assertEquals(lds, lmds); } @@ -162,16 +143,13 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testsBasicMeta() throws Exception { - //As per testBasic - but also loading metadata + @DisplayName("Tests Basic Meta") + void testsBasicMeta() throws Exception { + // As per testBasic - but also loading metadata RecordReader rr2 = new CSVRecordReader(0, ','); rr2.initialize(new FileSplit(Resources.asFile("iris.txt"))); - - RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10) - .addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build(); - + RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build(); rrmdsi.setCollectMetaData(true); - int count = 0; while (rrmdsi.hasNext()) { MultiDataSet mds = rrmdsi.next(); @@ -183,27 +161,22 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testSplittingCSV() throws Exception { - //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays - //Inputs: columns 0 and 1-2 - //Outputs: columns 3, and 4->OneHot - //need to manually extract + @DisplayName("Test Splitting CSV") + void testSplittingCSV() throws Exception { + // Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays + // Inputs: columns 0 and 1-2 + // Outputs: columns 3, and 4->OneHot + // need to manually extract RecordReader rr = new CSVRecordReader(0, ','); rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3); - RecordReader rr2 = new CSVRecordReader(0, ','); rr2.initialize(new FileSplit(Resources.asFile("iris.txt"))); - - MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2) - .addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3) - .addOutputOneHot("reader", 4, 3).build(); - + MultiDataSetIterator rrmdsi = new 
RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build(); while (rrdsi.hasNext()) { DataSet ds = rrdsi.next(); INDArray fds = ds.getFeatures(); INDArray lds = ds.getLabels(); - MultiDataSet mds = rrmdsi.next(); assertEquals(2, mds.getFeatures().length); assertEquals(2, mds.getLabels().length); @@ -211,20 +184,15 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { assertNull(mds.getLabelsMaskArrays()); INDArray[] fmds = mds.getFeatures(); INDArray[] lmds = mds.getLabels(); - assertNotNull(fmds); assertNotNull(lmds); - for (int i = 0; i < fmds.length; i++) - assertNotNull(fmds[i]); - for (int i = 0; i < lmds.length; i++) - assertNotNull(lmds[i]); - - //Get the subsets of the original iris data - INDArray expIn1 = fds.get(all(), interval(0,0,true)); + for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]); + for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]); + // Get the subsets of the original iris data + INDArray expIn1 = fds.get(all(), interval(0, 0, true)); INDArray expIn2 = fds.get(all(), interval(1, 2, true)); - INDArray expOut1 = fds.get(all(), interval(3,3,true)); + INDArray expOut1 = fds.get(all(), interval(3, 3, true)); INDArray expOut2 = lds; - assertEquals(expIn1, fmds[0]); assertEquals(expIn2, fmds[1]); assertEquals(expOut1, lmds[0]); @@ -234,18 +202,15 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testSplittingCSVMeta() throws Exception { - //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays - //Inputs: columns 0 and 1-2 - //Outputs: columns 3, and 4->OneHot + @DisplayName("Test Splitting CSV Meta") + void testSplittingCSVMeta() throws Exception { + // Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays + // Inputs: columns 0 and 1-2 + // Outputs: columns 3, and 4->OneHot RecordReader rr2 = new CSVRecordReader(0, ','); rr2.initialize(new FileSplit(Resources.asFile("iris.txt"))); - - RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10) - .addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2) - .addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build(); + RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build(); rrmdsi.setCollectMetaData(true); - int count = 0; while (rrmdsi.hasNext()) { MultiDataSet mds = rrmdsi.next(); @@ -257,42 +222,33 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testSplittingCSVSequence() throws Exception { - //Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt" + @DisplayName("Test Splitting CSV Sequence") + void testSplittingCSVSequence() throws Exception { + // Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt" // as standard one-hot output - //need to manually extract - File rootDir = temporaryFolder.newFolder(); + // need to manually extract + File rootDir = temporaryFolder.toFile(); for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", 
i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); } - String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - SequenceRecordReaderDataSetIterator iter = - new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); - + SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false); SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1) - .addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2) - .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build(); - + MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2).addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build(); while (iter.hasNext()) { DataSet ds = iter.next(); INDArray fds = ds.getFeatures(); INDArray lds = ds.getLabels(); - MultiDataSet mds = srrmdsi.next(); assertEquals(2, mds.getFeatures().length); assertEquals(1, mds.getLabels().length); @@ -300,17 +256,12 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { assertNull(mds.getLabelsMaskArrays()); INDArray[] fmds = mds.getFeatures(); INDArray[] lmds = mds.getLabels(); - assertNotNull(fmds); assertNotNull(lmds); - for (int i = 0; i < fmds.length; i++) - assertNotNull(fmds[i]); - for (int i = 0; i < lmds.length; i++) - assertNotNull(lmds[i]); - + for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]); + for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]); INDArray expIn1 = fds.get(all(), NDArrayIndex.interval(0, 1, true), all()); INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(2, 2, true), all()); - assertEquals(expIn1, fmds[0]); assertEquals(expIn2, fmds[1]); assertEquals(lds, lmds[0]); @@ -319,36 +270,29 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testSplittingCSVSequenceMeta() throws Exception { - //Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt" + @DisplayName("Test Splitting CSV Sequence Meta") + void testSplittingCSVSequenceMeta() throws Exception { + // Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt" // as standard one-hot output - //need to manually extract - File rootDir = temporaryFolder.newFolder(); + // need to manually extract + File rootDir = temporaryFolder.toFile(); for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new 
ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); } - String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt"); - SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1) - .addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2) - .addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build(); - + RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2).addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build(); srrmdsi.setCollectMetaData(true); - int count = 0; while (srrmdsi.hasNext()) { MultiDataSet mds = srrmdsi.next(); @@ -359,34 +303,27 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { assertEquals(3, count); } - @Test - public void testInputValidation() { - - //Test: no readers + @DisplayName("Test Input Validation") + void testInputValidation() { + // Test: no readers try { - MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addInput("something") - .addOutput("something").build(); + MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addInput("something").addOutput("something").build(); fail("Should have thrown exception"); } catch (Exception e) { } - - //Test: reference to reader that doesn't exist + // Test: reference to reader that doesn't exist try { RecordReader rr = new CSVRecordReader(0, ','); rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); - - MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr) - .addInput("thisDoesntExist", 0, 3).addOutputOneHot("iris", 4, 3).build(); + MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).addInput("thisDoesntExist", 0, 3).addOutputOneHot("iris", 4, 3).build(); fail("Should have thrown exception"); } catch (Exception e) { } - - //Test: no inputs or outputs + // Test: no inputs or outputs try { RecordReader rr = new CSVRecordReader(0, ','); rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); - MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).build(); fail("Should have thrown exception"); } catch (Exception e) { @@ -394,81 +331,55 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testVariableLengthTS() throws Exception { - //need to manually extract - File rootDir = temporaryFolder.newFolder(); + @DisplayName("Test Variable Length TS") + void testVariableLengthTS() throws Exception { + // need to 
manually extract + File rootDir = temporaryFolder.toFile(); for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); } - String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt"); - - //Set up SequenceRecordReaderDataSetIterators for comparison - + // Set up SequenceRecordReaderDataSetIterators for comparison SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ","); featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader, - labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START); - - SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2, - labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); - - - //Set up + SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START); + SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2, labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); + // Set up SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ","); featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ","); featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1) - .addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in") - .addOutputOneHot("out", 0, 4) - .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build(); - - RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1) - .addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in") - .addOutputOneHot("out", 0, 4) - .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build(); - - + RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader3).addSequenceReader("out", 
labelReader3).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build(); + RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build(); while (iterAlignStart.hasNext()) { DataSet dsStart = iterAlignStart.next(); DataSet dsEnd = iterAlignEnd.next(); - MultiDataSet mdsStart = rrmdsiStart.next(); MultiDataSet mdsEnd = rrmdsiEnd.next(); - assertEquals(1, mdsStart.getFeatures().length); assertEquals(1, mdsStart.getLabels().length); - //assertEquals(1, mdsStart.getFeaturesMaskArrays().length); //Features data is always longer -> don't need mask arrays for it + // assertEquals(1, mdsStart.getFeaturesMaskArrays().length); //Features data is always longer -> don't need mask arrays for it assertEquals(1, mdsStart.getLabelsMaskArrays().length); - assertEquals(1, mdsEnd.getFeatures().length); assertEquals(1, mdsEnd.getLabels().length); - //assertEquals(1, mdsEnd.getFeaturesMaskArrays().length); + // assertEquals(1, mdsEnd.getFeaturesMaskArrays().length); assertEquals(1, mdsEnd.getLabelsMaskArrays().length); - - assertEquals(dsStart.getFeatures(), mdsStart.getFeatures(0)); assertEquals(dsStart.getLabels(), mdsStart.getLabels(0)); assertEquals(dsStart.getLabelsMaskArray(), mdsStart.getLabelsMaskArray(0)); - assertEquals(dsEnd.getFeatures(), mdsEnd.getFeatures(0)); assertEquals(dsEnd.getLabels(), mdsEnd.getLabels(0)); assertEquals(dsEnd.getLabelsMaskArray(), mdsEnd.getLabelsMaskArray(0)); @@ -477,57 +388,40 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { assertFalse(rrmdsiEnd.hasNext()); } - @Test - public void testVariableLengthTSMeta() throws Exception { - //need to manually extract - File rootDir = temporaryFolder.newFolder(); + @DisplayName("Test Variable Length TS Meta") + void testVariableLengthTSMeta() throws Exception { + // need to manually extract + File rootDir = temporaryFolder.toFile(); for (int i = 0; i < 3; i++) { new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir); new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir); } - //Set up SequenceRecordReaderDataSetIterators for comparison - + // Set up SequenceRecordReaderDataSetIterators for comparison String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt"); String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt"); - - //Set up + // Set up SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ","); featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ","); featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2)); labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2)); - - RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1) - 
.addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in") - .addOutputOneHot("out", 0, 4) - .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build(); - - RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1) - .addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in") - .addOutputOneHot("out", 0, 4) - .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build(); - + RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build(); + RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build(); rrmdsiStart.setCollectMetaData(true); rrmdsiEnd.setCollectMetaData(true); - int count = 0; while (rrmdsiStart.hasNext()) { MultiDataSet mdsStart = rrmdsiStart.next(); MultiDataSet mdsEnd = rrmdsiEnd.next(); - - MultiDataSet mdsStartFromMeta = - rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class)); + MultiDataSet mdsStartFromMeta = rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class)); MultiDataSet mdsEndFromMeta = rrmdsiEnd.loadFromMetaData(mdsEnd.getExampleMetaData(RecordMetaData.class)); - assertEquals(mdsStart, mdsStartFromMeta); assertEquals(mdsEnd, mdsEndFromMeta); - count++; } assertFalse(rrmdsiStart.hasNext()); @@ -536,53 +430,37 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testImagesRRDMSI() throws Exception { - File parentDir = temporaryFolder.newFolder(); + @DisplayName("Test Images RRDMSI") + void testImagesRRDMSI() throws Exception { + File parentDir = temporaryFolder.toFile(); parentDir.deleteOnExit(); String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/"); String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/"); - File f1 = new File(str1); File f2 = new File(str2); f1.mkdirs(); f2.mkdirs(); - - TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), - new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream()); - TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), - new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream()); - - + TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream()); + TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream()); int outputNum = 2; Random r = new Random(12345); ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); - ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker); ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker); - rr1.initialize(new FileSplit(parentDir)); rr1s.initialize(new FileSplit(parentDir)); - - - MultiDataSetIterator trainDataIterator = new 
RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1) - .addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0) - .addOutputOneHot("rr1s", 1, outputNum).build(); - - //Now, do the same thing with ImageRecordReader, and check we get the same results: + MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build(); + // Now, do the same thing with ImageRecordReader, and check we get the same results: ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker); ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker); rr1_b.initialize(new FileSplit(parentDir)); rr1s_b.initialize(new FileSplit(parentDir)); - DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2); DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2); - for (int i = 0; i < 2; i++) { MultiDataSet mds = trainDataIterator.next(); - DataSet d1 = dsi1.next(); DataSet d2 = dsi2.next(); - assertEquals(d1.getFeatures(), mds.getFeatures(0)); assertEquals(d2.getFeatures(), mds.getFeatures(1)); assertEquals(d1.getLabels(), mds.getLabels(0)); @@ -590,261 +468,180 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testImagesRRDMSI_Batched() throws Exception { - File parentDir = temporaryFolder.newFolder(); + @DisplayName("Test Images RRDMSI _ Batched") + void testImagesRRDMSI_Batched() throws Exception { + File parentDir = temporaryFolder.toFile(); parentDir.deleteOnExit(); String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/"); String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/"); - File f1 = new File(str1); File f2 = new File(str2); f1.mkdirs(); f2.mkdirs(); - - TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), - new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream()); - TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), - new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream()); - + TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream()); + TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream()); int outputNum = 2; ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); - ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker); ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker); - URI[] uris = new FileSplit(parentDir).locations(); - rr1.initialize(new CollectionInputSplit(uris)); rr1s.initialize(new CollectionInputSplit(uris)); - - MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2).addReader("rr1", rr1) - .addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0) - .addOutputOneHot("rr1s", 1, outputNum).build(); - - //Now, do the same thing with ImageRecordReader, and check we get the same results: + MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build(); + // Now, do the same thing with ImageRecordReader, 
and check we get the same results: ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker); ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker); rr1_b.initialize(new FileSplit(parentDir)); rr1s_b.initialize(new FileSplit(parentDir)); - DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 2, 1, 2); DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 2, 1, 2); - MultiDataSet mds = trainDataIterator.next(); - DataSet d1 = dsi1.next(); DataSet d2 = dsi2.next(); - assertEquals(d1.getFeatures(), mds.getFeatures(0)); assertEquals(d2.getFeatures(), mds.getFeatures(1)); assertEquals(d1.getLabels(), mds.getLabels(0)); - - //Check label assignment: - + // Check label assignment: File currentFile = rr1_b.getCurrentFile(); INDArray expLabels; - if(currentFile.getAbsolutePath().contains("Zico")){ - expLabels = Nd4j.create(new double[][] {{0, 1}, {1, 0}}); + if (currentFile.getAbsolutePath().contains("Zico")) { + expLabels = Nd4j.create(new double[][] { { 0, 1 }, { 1, 0 } }); } else { - expLabels = Nd4j.create(new double[][] {{1, 0}, {0, 1}}); + expLabels = Nd4j.create(new double[][] { { 1, 0 }, { 0, 1 } }); } - assertEquals(expLabels, d1.getLabels()); assertEquals(expLabels, d2.getLabels()); } - - - @Test - public void testTimeSeriesRandomOffset() { - //2 in, 2 out, 3 total sequences of length [1,3,5] - - List> seq1 = - Arrays.asList(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(2.0))); - List> seq2 = - Arrays.asList(Arrays.asList(new DoubleWritable(10.0), new DoubleWritable(11.0)), - Arrays.asList(new DoubleWritable(20.0), new DoubleWritable(21.0)), - Arrays.asList(new DoubleWritable(30.0), new DoubleWritable(31.0))); - List> seq3 = - Arrays.asList(Arrays.asList(new DoubleWritable(100.0), new DoubleWritable(101.0)), - Arrays.asList(new DoubleWritable(200.0), new DoubleWritable(201.0)), - Arrays.asList(new DoubleWritable(300.0), new DoubleWritable(301.0)), - Arrays.asList(new DoubleWritable(400.0), new DoubleWritable(401.0)), - Arrays.asList(new DoubleWritable(500.0), new DoubleWritable(501.0))); - + @DisplayName("Test Time Series Random Offset") + void testTimeSeriesRandomOffset() { + // 2 in, 2 out, 3 total sequences of length [1,3,5] + List> seq1 = Arrays.asList(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(2.0))); + List> seq2 = Arrays.asList(Arrays.asList(new DoubleWritable(10.0), new DoubleWritable(11.0)), Arrays.asList(new DoubleWritable(20.0), new DoubleWritable(21.0)), Arrays.asList(new DoubleWritable(30.0), new DoubleWritable(31.0))); + List> seq3 = Arrays.asList(Arrays.asList(new DoubleWritable(100.0), new DoubleWritable(101.0)), Arrays.asList(new DoubleWritable(200.0), new DoubleWritable(201.0)), Arrays.asList(new DoubleWritable(300.0), new DoubleWritable(301.0)), Arrays.asList(new DoubleWritable(400.0), new DoubleWritable(401.0)), Arrays.asList(new DoubleWritable(500.0), new DoubleWritable(501.0))); Collection>> seqs = Arrays.asList(seq1, seq2, seq3); - SequenceRecordReader rr = new CollectionSequenceRecordReader(seqs); - - RecordReaderMultiDataSetIterator rrmdsi = - new RecordReaderMultiDataSetIterator.Builder(3).addSequenceReader("rr", rr).addInput("rr", 0, 0) - .addOutput("rr", 1, 1).timeSeriesRandomOffset(true, 1234L).build(); - - - Random r = new Random(1234); //Provides seed for each minibatch + RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(3).addSequenceReader("rr", rr).addInput("rr", 0, 0).addOutput("rr", 1, 1).timeSeriesRandomOffset(true, 1234L).build(); + // 
Provides seed for each minibatch + Random r = new Random(1234); long seed = r.nextLong(); - Random r2 = new Random(seed); //Use same RNG seed in new RNG for each minibatch - int expOffsetSeq1 = r2.nextInt(5 - 1 + 1); //0 to 4 inclusive + // Use same RNG seed in new RNG for each minibatch + Random r2 = new Random(seed); + // 0 to 4 inclusive + int expOffsetSeq1 = r2.nextInt(5 - 1 + 1); int expOffsetSeq2 = r2.nextInt(5 - 3 + 1); - int expOffsetSeq3 = 0; //Longest TS, always 0 - //With current seed: 3, 1, 0 - // System.out.println(expOffsetSeq1 + "\t" + expOffsetSeq2 + "\t" + expOffsetSeq3); - + // Longest TS, always 0 + int expOffsetSeq3 = 0; + // With current seed: 3, 1, 0 + // System.out.println(expOffsetSeq1 + "\t" + expOffsetSeq2 + "\t" + expOffsetSeq3); MultiDataSet mds = rrmdsi.next(); - - INDArray expMask = Nd4j.create(new double[][] {{0, 0, 0, 1, 0}, {0, 1, 1, 1, 0}, {1, 1, 1, 1, 1}}); - + INDArray expMask = Nd4j.create(new double[][] { { 0, 0, 0, 1, 0 }, { 0, 1, 1, 1, 0 }, { 1, 1, 1, 1, 1 } }); assertEquals(expMask, mds.getFeaturesMaskArray(0)); assertEquals(expMask, mds.getLabelsMaskArray(0)); - INDArray f = mds.getFeatures(0); INDArray l = mds.getLabels(0); - - INDArray expF1 = Nd4j.create(new double[] {1.0}, new int[]{1,1}); - INDArray expL1 = Nd4j.create(new double[] {2.0}, new int[]{1,1}); - - INDArray expF2 = Nd4j.create(new double[] {10, 20, 30}, new int[]{1,3}); - INDArray expL2 = Nd4j.create(new double[] {11, 21, 31}, new int[]{1,3}); - - INDArray expF3 = Nd4j.create(new double[] {100, 200, 300, 400, 500}, new int[]{1,5}); - INDArray expL3 = Nd4j.create(new double[] {101, 201, 301, 401, 501}, new int[]{1,5}); - - assertEquals(expF1, f.get(point(0), all(), - NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1))); - assertEquals(expL1, l.get(point(0), all(), - NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1))); - - assertEquals(expF2, f.get(point(1), all(), - NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3))); - assertEquals(expL2, l.get(point(1), all(), - NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3))); - - assertEquals(expF3, f.get(point(2), all(), - NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5))); - assertEquals(expL3, l.get(point(2), all(), - NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5))); + INDArray expF1 = Nd4j.create(new double[] { 1.0 }, new int[] { 1, 1 }); + INDArray expL1 = Nd4j.create(new double[] { 2.0 }, new int[] { 1, 1 }); + INDArray expF2 = Nd4j.create(new double[] { 10, 20, 30 }, new int[] { 1, 3 }); + INDArray expL2 = Nd4j.create(new double[] { 11, 21, 31 }, new int[] { 1, 3 }); + INDArray expF3 = Nd4j.create(new double[] { 100, 200, 300, 400, 500 }, new int[] { 1, 5 }); + INDArray expL3 = Nd4j.create(new double[] { 101, 201, 301, 401, 501 }, new int[] { 1, 5 }); + assertEquals(expF1, f.get(point(0), all(), NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1))); + assertEquals(expL1, l.get(point(0), all(), NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1))); + assertEquals(expF2, f.get(point(1), all(), NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3))); + assertEquals(expL2, l.get(point(1), all(), NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3))); + assertEquals(expF3, f.get(point(2), all(), NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5))); + assertEquals(expL3, l.get(point(2), all(), NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5))); } - @Test - public void testSeqRRDSIMasking(){ - //This also tests RecordReaderMultiDataSetIterator, by virtue of + @DisplayName("Test 
Seq RRDSI Masking") + void testSeqRRDSIMasking() { + // This also tests RecordReaderMultiDataSetIterator, by virtue of List>> features = new ArrayList<>(); List>> labels = new ArrayList<>(); - features.add(Arrays.asList(l(new DoubleWritable(1)), l(new DoubleWritable(2)), l(new DoubleWritable(3)))); features.add(Arrays.asList(l(new DoubleWritable(4)), l(new DoubleWritable(5)))); - labels.add(Arrays.asList(l(new IntWritable(0)))); labels.add(Arrays.asList(l(new IntWritable(1)))); - CollectionSequenceRecordReader fR = new CollectionSequenceRecordReader(features); CollectionSequenceRecordReader lR = new CollectionSequenceRecordReader(labels); - - SequenceRecordReaderDataSetIterator seqRRDSI = new SequenceRecordReaderDataSetIterator( - fR, lR, 2, 2, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); - + SequenceRecordReaderDataSetIterator seqRRDSI = new SequenceRecordReaderDataSetIterator(fR, lR, 2, 2, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); DataSet ds = seqRRDSI.next(); - - INDArray fMask = Nd4j.create(new double[][]{ - {1,1,1}, - {1,1,0}}); - - INDArray lMask = Nd4j.create(new double[][]{ - {0,0,1}, - {0,1,0}}); - + INDArray fMask = Nd4j.create(new double[][] { { 1, 1, 1 }, { 1, 1, 0 } }); + INDArray lMask = Nd4j.create(new double[][] { { 0, 0, 1 }, { 0, 1, 0 } }); assertEquals(fMask, ds.getFeaturesMaskArray()); assertEquals(lMask, ds.getLabelsMaskArray()); - - INDArray f = Nd4j.create(new double[][]{ - {1,2,3}, - {4,5,0}}); - - INDArray l = Nd4j.create(2,2,3); - l.putScalar(0,0,2, 1.0); - l.putScalar(1,1,1, 1.0); - + INDArray f = Nd4j.create(new double[][] { { 1, 2, 3 }, { 4, 5, 0 } }); + INDArray l = Nd4j.create(2, 2, 3); + l.putScalar(0, 0, 2, 1.0); + l.putScalar(1, 1, 1, 1.0); assertEquals(f, ds.getFeatures().get(all(), point(0), all())); assertEquals(l, ds.getLabels()); } - private static List l(Writable... in){ + private static List l(Writable... 
in) { return Arrays.asList(in); } - - @Test - public void testExcludeStringColCSV() throws Exception { - File csvFile = temporaryFolder.newFile(); - + @DisplayName("Test Exclude String Col CSV") + void testExcludeStringColCSV() throws Exception { + File csvFile = new File(temporaryFolder.toFile(), "temp.csv"); StringBuilder sb = new StringBuilder(); - for(int i=1; i<=10; i++ ){ - if(i > 1){ + for (int i = 1; i <= 10; i++) { + if (i > 1) { sb.append("\n"); } sb.append("skip_").append(i).append(",").append(i).append(",").append(i + 0.5); } FileUtils.writeStringToFile(csvFile, sb.toString()); - RecordReader rr = new CSVRecordReader(); rr.initialize(new FileSplit(csvFile)); - - RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10) - .addReader("rr", rr) - .addInput("rr", 1, 1) - .addOutput("rr", 2, 2) - .build(); - - INDArray expFeatures = Nd4j.linspace(1,10,10).reshape(1,10).transpose(); - INDArray expLabels = Nd4j.linspace(1,10,10).addi(0.5).reshape(1,10).transpose(); - + RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("rr", rr).addInput("rr", 1, 1).addOutput("rr", 2, 2).build(); + INDArray expFeatures = Nd4j.linspace(1, 10, 10).reshape(1, 10).transpose(); + INDArray expLabels = Nd4j.linspace(1, 10, 10).addi(0.5).reshape(1, 10).transpose(); MultiDataSet mds = rrmdsi.next(); assertFalse(rrmdsi.hasNext()); - assertEquals(expFeatures, mds.getFeatures(0).castTo(expFeatures.dataType())); assertEquals(expLabels, mds.getLabels(0).castTo(expLabels.dataType())); } - private static final int nX = 32; + private static final int nY = 32; + private static final int nZ = 28; - @Test - public void testRRMDSI5D() { + @DisplayName("Test RRMDSI 5 D") + void testRRMDSI5D() { int batchSize = 5; - CustomRecordReader recordReader = new CustomRecordReader(); - DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, - 1, /* Index of label in records */ - 2 /* number of different labels */); - + DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, /* Index of label in records */ + 2 /* number of different labels */); int count = 0; - while(dataIter.hasNext()){ + while (dataIter.hasNext()) { DataSet ds = dataIter.next(); - - int offset = 5*count; - for( int i=0; i<5; i++ ){ - INDArray act = ds.getFeatures().get(interval(i,i,true), all(), all(), all(), all()); - INDArray exp = Nd4j.valueArrayOf(new int[]{1, 1, nZ, nX, nY}, i + offset ); + int offset = 5 * count; + for (int i = 0; i < 5; i++) { + INDArray act = ds.getFeatures().get(interval(i, i, true), all(), all(), all(), all()); + INDArray exp = Nd4j.valueArrayOf(new int[] { 1, 1, nZ, nX, nY }, i + offset); assertEquals(exp, act); } count++; } - assertEquals(2, count); } - + @DisplayName("Custom Record Reader") static class CustomRecordReader extends BaseRecordReader { int n = 0; - CustomRecordReader() { } + CustomRecordReader() { + } @Override public boolean batchesSupported() { @@ -858,8 +655,8 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { @Override public List<Writable> next() { - INDArray nd = Nd4j.create(new float[nZ*nY*nX], new int[] {1, 1, nZ, nY, nX }, 'c').assign(n); - final List<Writable>res = RecordConverter.toRecord(nd); + INDArray nd = Nd4j.create(new float[nZ * nY * nX], new int[] { 1, 1, nZ, nY, nX }, 'c').assign(n); + final List<Writable> res = RecordConverter.toRecord(nd); res.add(new IntWritable(0)); n++; return res; @@ -867,14 +664,16 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { @Override public boolean hasNext() { - 
return n<10; + return n < 10; } final static ArrayList labels = new ArrayList<>(2); + static { labels.add("lbl0"); labels.add("lbl1"); } + @Override public List getLabels() { return labels; @@ -928,6 +727,7 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { public void initialize(InputSplit split) { n = 0; } + @Override public void initialize(Configuration conf, InputSplit split) { n = 0; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java index 617e0d1ff..7a59ae012 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java @@ -17,38 +17,39 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.fetchers; import org.deeplearning4j.BaseDL4JTest; import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.rules.Timeout; - import java.io.File; - -import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author saudet */ -public class SvhnDataFetcherTest extends BaseDL4JTest { +@DisplayName("Svhn Data Fetcher Test") +class SvhnDataFetcherTest extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 480_000_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time. + // Shouldn't take this long but slow download or drive access on CI machines may need extra time. 
+ return 480_000_000L; } @Test - public void testSvhnDataFetcher() throws Exception { - assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access - + @DisplayName("Test Svhn Data Fetcher") + void testSvhnDataFetcher() throws Exception { + // Ignore unless integration tests - CI can get caught up on slow disk access + assumeTrue(isIntegrationTests()); SvhnDataFetcher fetch = new SvhnDataFetcher(); File path = fetch.getDataSetPath(DataSetType.TRAIN); File path2 = fetch.getDataSetPath(DataSetType.TEST); File path3 = fetch.getDataSetPath(DataSetType.VALIDATION); - assertTrue(path.isDirectory()); assertTrue(path2.isDirectory()); assertTrue(path3.isDirectory()); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIteratorTest.java index 4a6eac144..af42f61ea 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIteratorTest.java @@ -17,52 +17,50 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.common.primitives.Pair; - import java.util.Iterator; import java.util.concurrent.atomic.AtomicInteger; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +@DisplayName("Abstract Data Set Iterator Test") +class AbstractDataSetIteratorTest extends BaseDL4JTest { -public class AbstractDataSetIteratorTest extends BaseDL4JTest { @Test - public void next() throws Exception { + @DisplayName("Next") + void next() throws Exception { int numFeatures = 128; int batchSize = 10; int numRows = 1000; AtomicInteger cnt = new AtomicInteger(0); FloatsDataSetIterator iterator = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize); - assertTrue(iterator.hasNext()); - while (iterator.hasNext()) { DataSet dataSet = iterator.next(); - INDArray features = dataSet.getFeatures(); - assertEquals(batchSize, features.rows()); assertEquals(numFeatures, features.columns()); cnt.incrementAndGet(); } - assertEquals(numRows / batchSize, cnt.get()); } - protected static Iterable<Pair<float[], float[]>> floatIterable(final int totalRows, final int numColumns) { return new Iterable<Pair<float[], float[]>>() { + @Override public Iterator<Pair<float[], float[]>> iterator() { return new Iterator<Pair<float[], float[]>>() { + private AtomicInteger cnt = new AtomicInteger(0); @Override @@ -72,8 +70,8 @@ public class AbstractDataSetIteratorTest extends BaseDL4JTest { @Override public Pair<float[], float[]> next() { - float features[] = new float[numColumns]; - float labels[] = new float[numColumns]; + float[] features = new float[numColumns]; + float[] labels = new float[numColumns]; for (int i = 0; i < numColumns; i++) { features[i] = (float) i; labels[i] = RandomUtils.nextFloat(0, 5); diff --git 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java index 5a9c71595..3c29cfe10 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import lombok.extern.slf4j.Slf4j; @@ -25,117 +24,118 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback; import org.deeplearning4j.datasets.iterator.tools.VariableTimeseriesGenerator; import org.deeplearning4j.nn.util.TestDataSetConsumer; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; - import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicLong; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; @Slf4j -public class AsyncDataSetIteratorTest extends BaseDL4JTest { +@DisplayName("Async Data Set Iterator Test") +class AsyncDataSetIteratorTest extends BaseDL4JTest { + private ExistingDataSetIterator backIterator; + private static final int TEST_SIZE = 100; + private static final int ITERATIONS = 10; // time spent in consumer thread, milliseconds private static final long EXECUTION_TIME = 5; + private static final long EXECUTION_SMALL = 1; - @Before - public void setUp() throws Exception { + @BeforeEach + void setUp() throws Exception { List iterable = new ArrayList<>(); for (int i = 0; i < TEST_SIZE; i++) { iterable.add(new DataSet(Nd4j.create(new float[100]), Nd4j.create(new float[10]))); } - backIterator = new ExistingDataSetIterator(iterable); } @Test - public void hasNext1() throws Exception { + @DisplayName("Has Next 1") + void hasNext1() throws Exception { for (int iter = 0; iter < ITERATIONS; iter++) { for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) { AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize); int cnt = 0; while (iterator.hasNext()) { DataSet ds = iterator.next(); - assertNotEquals(null, ds); cnt++; } - - assertEquals("Failed on iteration: " + iter + ", prefetchSize: " + prefetchSize, TEST_SIZE, cnt); + assertEquals( TEST_SIZE, cnt,"Failed on iteration: " + iter + ", prefetchSize: " + prefetchSize); iterator.shutdown(); } } } @Test - public void hasNextWithResetAndLoad() throws Exception { + @DisplayName("Has Next With Reset And Load") + void hasNextWithResetAndLoad() throws Exception { int[] prefetchSizes; - if(isIntegrationTests()){ - prefetchSizes = new int[]{2, 3, 4, 5, 6, 7, 8}; + if (isIntegrationTests()) { + prefetchSizes = new int[] { 2, 3, 4, 5, 6, 7, 8 }; } else { - prefetchSizes = new int[]{2, 3, 8}; + prefetchSizes 
= new int[] { 2, 3, 8 }; } - - for (int iter = 0; iter < ITERATIONS; iter++) { - for(int prefetchSize : prefetchSizes){ + for (int prefetchSize : prefetchSizes) { AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize); TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL); int cnt = 0; while (iterator.hasNext()) { DataSet ds = iterator.next(); consumer.consumeOnce(ds, false); - cnt++; if (cnt == TEST_SIZE / 2) iterator.reset(); } - assertEquals(TEST_SIZE + (TEST_SIZE / 2), cnt); iterator.shutdown(); } } } - @Test - public void testWithLoad() { - + @DisplayName("Test With Load") + void testWithLoad() { for (int iter = 0; iter < ITERATIONS; iter++) { AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, 8); TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, EXECUTION_TIME); - consumer.consumeWhileHasNext(true); - assertEquals(TEST_SIZE, consumer.getCount()); iterator.shutdown(); } } - @Test(expected = ArrayIndexOutOfBoundsException.class) - public void testWithException() { - ExistingDataSetIterator crashingIterator = new ExistingDataSetIterator(new IterableWithException(100)); - AsyncDataSetIterator iterator = new AsyncDataSetIterator(crashingIterator, 8); - - TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, EXECUTION_SMALL); - consumer.consumeWhileHasNext(true); - iterator.shutdown(); + @Test + @DisplayName("Test With Exception") + void testWithException() { + assertThrows(ArrayIndexOutOfBoundsException.class, () -> { + ExistingDataSetIterator crashingIterator = new ExistingDataSetIterator(new IterableWithException(100)); + AsyncDataSetIterator iterator = new AsyncDataSetIterator(crashingIterator, 8); + TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, EXECUTION_SMALL); + consumer.consumeWhileHasNext(true); + iterator.shutdown(); + }); } - - + @DisplayName("Iterable With Exception") private class IterableWithException implements Iterable { + private final AtomicLong counter = new AtomicLong(0); + private final int crashIteration; public IterableWithException(int iteration) { @@ -146,6 +146,7 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { public Iterator iterator() { counter.set(0); return new Iterator() { + @Override public boolean hasNext() { return true; @@ -155,82 +156,59 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { public DataSet next() { if (counter.incrementAndGet() >= crashIteration) throw new ArrayIndexOutOfBoundsException("Thrown as expected"); - return new DataSet(Nd4j.create(10), Nd4j.create(10)); } @Override public void remove() { - } }; } } - @Test - public void testVariableTimeSeries1() throws Exception { + @DisplayName("Test Variable Time Series 1") + void testVariableTimeSeries1() throws Exception { int numBatches = isIntegrationTests() ? 1000 : 100; int batchSize = isIntegrationTests() ? 32 : 8; int timeStepsMin = 10; int timeStepsMax = isIntegrationTests() ? 500 : 100; int valuesPerTimestep = isIntegrationTests() ? 
128 : 16; - - AsyncDataSetIterator adsi = new AsyncDataSetIterator( - new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true); - + AsyncDataSetIterator adsi = new AsyncDataSetIterator(new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true); for (int e = 0; e < 10; e++) { int cnt = 0; while (adsi.hasNext()) { DataSet ds = adsi.next(); - - //log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt, - ds.getFeatures().meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.25, - ds.getLabels().meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.5, - ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.75, - ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10); - + // log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); + assertEquals( (double) cnt, ds.getFeatures().meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals( (double) cnt + 0.25, ds.getLabels().meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals( (double) cnt + 0.5, ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals( (double) cnt + 0.75, ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); cnt++; } - adsi.reset(); -// log.info("Epoch {} finished...", e); + // log.info("Epoch {} finished...", e); } } @Test - public void testVariableTimeSeries2() throws Exception { - AsyncDataSetIterator adsi = - new AsyncDataSetIterator(new VariableTimeseriesGenerator(1192, 100, 32, 128, 100, 100, 100), 2, - true, new InterleavedDataSetCallback(2 * 2)); - - + @DisplayName("Test Variable Time Series 2") + void testVariableTimeSeries2() throws Exception { + AsyncDataSetIterator adsi = new AsyncDataSetIterator(new VariableTimeseriesGenerator(1192, 100, 32, 128, 100, 100, 100), 2, true, new InterleavedDataSetCallback(2 * 2)); for (int e = 0; e < 5; e++) { int cnt = 0; while (adsi.hasNext()) { - DataSet ds = adsi.next(); ds.detach(); - - //log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt, - ds.getFeatures().meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.25, - ds.getLabels().meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.5, - ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.75, - ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10); - + // log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); + assertEquals((double) cnt, ds.getFeatures().meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + 
";"); + assertEquals((double) cnt + 0.25, ds.getLabels().meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals( (double) cnt + 0.5, ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals((double) cnt + 0.75, ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); cnt++; } - adsi.reset(); -// log.info("Epoch {} finished...", e); + // log.info("Epoch {} finished...", e); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java index 4747beed8..523e8fdcd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java @@ -17,98 +17,19 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.MultiDataSet; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest { - - /** - * THIS TEST SHOULD BE ALWAYS RUN WITH DOUBLE PRECISION, WITHOUT ANY EXCLUSIONS - * - * @throws Exception - */ - @Test - public void testVariableTimeSeries1() throws Exception { - int numBatches = isIntegrationTests() ? 1000 : 100; - int batchSize = isIntegrationTests() ? 32 : 8; - int timeStepsMin = 10; - int timeStepsMax = isIntegrationTests() ? 500 : 100; - int valuesPerTimestep = isIntegrationTests() ? 
128 : 16; - - val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10); - iterator.reset(); - iterator.hasNext(); - val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true); - - for (int e = 0; e < 10; e++) { - int cnt = 0; - while (amdsi.hasNext()) { - MultiDataSet mds = amdsi.next(); - - - //log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt, - mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.25, - mds.getLabels()[0].meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.5, - mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.75, - mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10); - - cnt++; - } - - amdsi.reset(); - log.info("Epoch {} finished...", e); - } - } - - - @Test - public void testVariableTimeSeries2() throws Exception { - int numBatches = isIntegrationTests() ? 1000 : 100; - int batchSize = isIntegrationTests() ? 32 : 8; - int timeStepsMin = 10; - int timeStepsMax = isIntegrationTests() ? 500 : 100; - int valuesPerTimestep = isIntegrationTests() ? 128 : 16; - - val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10); - - for (int e = 0; e < 10; e++) { - iterator.reset(); - iterator.hasNext(); - val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true); - - int cnt = 0; - while (amdsi.hasNext()) { - MultiDataSet mds = amdsi.next(); - - - //log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt, - mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.25, - mds.getLabels()[0].meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.5, - mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10); - assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.75, - mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10); - - cnt++; - } - } - } /* @Test public void testResetBug() throws Exception { @@ -134,6 +55,120 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest { trainData.reset(); + SequenceRecordReader testFeatures = new CSVSequenceRecordReader(); + testFeatures.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/test/features" + "/%d.csv", 0, 149)); + RecordReader testLabels = new CSVRecordReader(); + testLabels.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/test/labels" + "/%d.csv", 0, 149)); + + MultiDataSetIterator testData = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize) + .addSequenceReader("features", testFeatures) + .addReader("labels", testLabels) + .addInput("features") + .addOutputOneHot("labels", 0, numLabelClasses) + .build(); + + System.out.println("-------------- HASH 1----------------"); + testData.reset(); + while(testData.hasNext()){ + 
System.out.println(Arrays.hashCode(testData.next().getFeatures(0).data().asFloat())); + } + + System.out.println("-------------- HASH 2 ----------------"); + testData.reset(); + testData.hasNext(); //***** Remove this (or move to after async creation), and we get expected results ***** + val adsi = new AsyncMultiDataSetIterator(testData, 4, true); //OR remove this (keeping hasNext) and we get expected results + //val adsi = new AsyncShieldMultiDataSetIterator(testData); + while(adsi.hasNext()){ + System.out.println(Arrays.hashCode(adsi.next().getFeatures(0).data().asFloat())); + } + } + */ +@DisplayName("Async Multi Data Set Iterator Test") +class AsyncMultiDataSetIteratorTest extends BaseDL4JTest { + + /** + * THIS TEST SHOULD BE ALWAYS RUN WITH DOUBLE PRECISION, WITHOUT ANY EXCLUSIONS + * + * @throws Exception + */ + @Test + @DisplayName("Test Variable Time Series 1") + void testVariableTimeSeries1() throws Exception { + int numBatches = isIntegrationTests() ? 1000 : 100; + int batchSize = isIntegrationTests() ? 32 : 8; + int timeStepsMin = 10; + int timeStepsMax = isIntegrationTests() ? 500 : 100; + int valuesPerTimestep = isIntegrationTests() ? 128 : 16; + val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10); + iterator.reset(); + iterator.hasNext(); + val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true); + for (int e = 0; e < 10; e++) { + int cnt = 0; + while (amdsi.hasNext()) { + MultiDataSet mds = amdsi.next(); + // log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); + assertEquals( (double) cnt, mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals( (double) cnt + 0.25, mds.getLabels()[0].meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals((double) cnt + 0.5, mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals((double) cnt + 0.75, mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + cnt++; + } + amdsi.reset(); + log.info("Epoch {} finished...", e); + } + } + + @Test + @DisplayName("Test Variable Time Series 2") + void testVariableTimeSeries2() throws Exception { + int numBatches = isIntegrationTests() ? 1000 : 100; + int batchSize = isIntegrationTests() ? 32 : 8; + int timeStepsMin = 10; + int timeStepsMax = isIntegrationTests() ? 500 : 100; + int valuesPerTimestep = isIntegrationTests() ? 
128 : 16; + val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10); + for (int e = 0; e < 10; e++) { + iterator.reset(); + iterator.hasNext(); + val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true); + int cnt = 0; + while (amdsi.hasNext()) { + MultiDataSet mds = amdsi.next(); + // log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); + assertEquals( (double) cnt, mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals((double) cnt + 0.25, mds.getLabels()[0].meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals( (double) cnt + 0.5, mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals( (double) cnt + 0.75, mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); + cnt++; + } + } + } + /* + @Test + public void testResetBug() throws Exception { + // /home/raver119/develop/dl4j-examples/src/main/resources/uci/train/features + + SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(); + trainFeatures.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/train/features" + "/%d.csv", 0, 449)); + RecordReader trainLabels = new CSVRecordReader(); + trainLabels.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/train/labels" + "/%d.csv", 0, 449)); + + int miniBatchSize = 10; + int numLabelClasses = 6; + MultiDataSetIterator trainData = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize) + .addSequenceReader("features", trainFeatures) + .addReader("labels", trainLabels) + .addInput("features") + .addOutputOneHot("labels", 0, numLabelClasses) + .build(); + + //Normalize the training data + MultiDataNormalization normalizer = new MultiNormalizerStandardize(); + normalizer.fit(trainData); //Collect training data statistics + trainData.reset(); + + SequenceRecordReader testFeatures = new CSVSequenceRecordReader(); testFeatures.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/test/features" + "/%d.csv", 0, 149)); RecordReader testLabels = new CSVRecordReader(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java index 11a151988..9e8114712 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; @@ -41,8 +40,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Ignore; -import org.junit.Test; +import 
org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -50,26 +49,28 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.ArrayList; import java.util.List; import java.util.Random; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - -public class DataSetIteratorTest extends BaseDL4JTest { +@DisplayName("Data Set Iterator Test") +class DataSetIteratorTest extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 360000; //Should run quickly; increased to large timeout due to occasonal slow CI downloads + // Should run quickly; increased to large timeout due to occasional slow CI downloads + return 360000; } @Test - public void testBatchSizeOfOneIris() throws Exception { - //Test for (a) iterators returning correct number of examples, and - //(b) Labels are a proper one-hot vector (i.e., sum is 1.0) - - //Iris: + @DisplayName("Test Batch Size Of One Iris") + void testBatchSizeOfOneIris() throws Exception { + // Test for (a) iterators returning correct number of examples, and + // (b) Labels are a proper one-hot vector (i.e., sum is 1.0) + // Iris: DataSetIterator iris = new IrisDataSetIterator(1, 5); int irisC = 0; while (iris.hasNext()) { @@ -81,9 +82,9 @@ public class DataSetIteratorTest extends BaseDL4JTest { } @Test - public void testBatchSizeOfOneMnist() throws Exception { - - //MNIST: + @DisplayName("Test Batch Size Of One Mnist") + void testBatchSizeOfOneMnist() throws Exception { + // MNIST: DataSetIterator mnist = new MnistDataSetIterator(1, 5); int mnistC = 0; while (mnist.hasNext()) { @@ -95,25 +96,21 @@ public class DataSetIteratorTest extends BaseDL4JTest { } @Test - public void testMnist() throws Exception { + @DisplayName("Test Mnist") + void testMnist() throws Exception { ClassPathResource cpr = new ClassPathResource("mnist_first_200.txt"); CSVRecordReader rr = new CSVRecordReader(0, ','); rr.initialize(new FileSplit(cpr.getTempFileFromArchive())); RecordReaderDataSetIterator dsi = new RecordReaderDataSetIterator(rr, 10, 0, 10); - MnistDataSetIterator iter = new MnistDataSetIterator(10, 200, false, true, false, 0); - while (dsi.hasNext()) { DataSet dsExp = dsi.next(); DataSet dsAct = iter.next(); - INDArray fExp = dsExp.getFeatures(); fExp.divi(255); INDArray lExp = dsExp.getLabels(); - INDArray fAct = dsAct.getFeatures(); INDArray lAct = dsAct.getLabels(); - assertEquals(fExp, fAct.castTo(fExp.dataType())); assertEquals(lExp, lAct.castTo(lExp.dataType())); } @@ -121,12 +118,13 @@ public class DataSetIteratorTest extends BaseDL4JTest { } @Test - public void testLfwIterator() throws Exception { + @DisplayName("Test Lfw Iterator") + void testLfwIterator() throws Exception { int numExamples = 1; int row = 28; int col = 28; int channels = 1; - LFWDataSetIterator iter = new LFWDataSetIterator(numExamples, new int[] {row, col, channels}, true); + LFWDataSetIterator iter = new LFWDataSetIterator(numExamples, new int[] { row, col, channels }, true); assertTrue(iter.hasNext()); DataSet data = iter.next(); assertEquals(numExamples, data.getLabels().size(0)); @@ -134,7 +132,8 @@ public class DataSetIteratorTest extends
BaseDL4JTest { } @Test - public void testTinyImageNetIterator() throws Exception { + @DisplayName("Test Tiny Image Net Iterator") + void testTinyImageNetIterator() throws Exception { int numClasses = 200; int row = 64; int col = 64; @@ -143,24 +142,26 @@ public class DataSetIteratorTest extends BaseDL4JTest { assertTrue(iter.hasNext()); DataSet data = iter.next(); assertEquals(numClasses, data.getLabels().size(1)); - assertArrayEquals(new long[]{1, channels, row, col}, data.getFeatures().shape()); + assertArrayEquals(new long[] { 1, channels, row, col }, data.getFeatures().shape()); } @Test - public void testTinyImageNetIterator2() throws Exception { + @DisplayName("Test Tiny Image Net Iterator 2") + void testTinyImageNetIterator2() throws Exception { int numClasses = 200; int row = 224; int col = 224; int channels = 3; - TinyImageNetDataSetIterator iter = new TinyImageNetDataSetIterator(1, new int[]{row, col}, DataSetType.TEST); + TinyImageNetDataSetIterator iter = new TinyImageNetDataSetIterator(1, new int[] { row, col }, DataSetType.TEST); assertTrue(iter.hasNext()); DataSet data = iter.next(); assertEquals(numClasses, data.getLabels().size(1)); - assertArrayEquals(new long[]{1, channels, row, col}, data.getFeatures().shape()); + assertArrayEquals(new long[] { 1, channels, row, col }, data.getFeatures().shape()); } @Test - public void testLfwModel() throws Exception { + @DisplayName("Test Lfw Model") + void testLfwModel() throws Exception { final int numRows = 28; final int numColumns = 28; int numChannels = 3; @@ -169,39 +170,22 @@ public class DataSetIteratorTest extends BaseDL4JTest { int batchSize = 2; int seed = 123; int listenerFreq = 1; - - LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples, - new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed)); - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() - .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(numChannels).nOut(6) - .weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .stride(1, 1).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - .setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels)) - ; - + LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples, new int[] { numRows, numColumns, numChannels }, outputNum, false, true, 1.0, new Random(seed)); + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list().layer(0, new ConvolutionLayer.Builder(5, 5).nIn(numChannels).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).stride(1, 1).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels)); MultiLayerNetwork model = new 
MultiLayerNetwork(builder.build()); model.init(); - model.setListeners(new ScoreIterationListener(listenerFreq)); - model.fit(lfw.next()); - DataSet dataTest = lfw.next(); INDArray output = model.output(dataTest.getFeatures()); Evaluation eval = new Evaluation(outputNum); eval.eval(dataTest.getLabels(), output); -// System.out.println(eval.stats()); + // System.out.println(eval.stats()); } @Test - public void testCifar10Iterator() throws Exception { + @DisplayName("Test Cifar 10 Iterator") + void testCifar10Iterator() throws Exception { int numExamples = 1; int row = 32; int col = 32; @@ -213,12 +197,13 @@ public class DataSetIteratorTest extends BaseDL4JTest { assertEquals(channels * row * col, data.getFeatures().ravel().length()); } - - @Test @Ignore //Ignored for now - CIFAR iterator needs work - https://github.com/eclipse/deeplearning4j/issues/4673 - public void testCifarModel() throws Exception { + // Ignored for now - CIFAR iterator needs work - https://github.com/eclipse/deeplearning4j/issues/4673 + @Test + @Disabled + @DisplayName("Test Cifar Model") + void testCifarModel() throws Exception { // Streaming runCifar(false); - // Preprocess runCifar(true); } @@ -231,32 +216,14 @@ public class DataSetIteratorTest extends BaseDL4JTest { int batchSize = 5; int seed = 123; int listenerFreq = 1; - Cifar10DataSetIterator cifar = new Cifar10DataSetIterator(batchSize); - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() - .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(channels).nOut(6).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - - .setInputType(InputType.convolutionalFlat(height, width, channels)); - + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list().layer(0, new ConvolutionLayer.Builder(5, 5).nIn(channels).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(height, width, channels)); MultiLayerNetwork model = new MultiLayerNetwork(builder.build()); model.init(); - - //model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq))); - + // model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq))); CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq); model.setListeners(listener); - model.fit(cifar); - cifar = new Cifar10DataSetIterator(batchSize); Evaluation eval = new Evaluation(cifar.getLabels()); while (cifar.hasNext()) { @@ -264,37 +231,31 @@ public class DataSetIteratorTest extends BaseDL4JTest { INDArray output = model.output(testDS.getFeatures()); eval.eval(testDS.getLabels(), 
output); } -// System.out.println(eval.stats(true)); + // System.out.println(eval.stats(true)); listener.exportScores(System.out); } - @Test - public void testIteratorDataSetIteratorCombining() { - //Test combining of a bunch of small (size 1) data sets together - + @DisplayName("Test Iterator Data Set Iterator Combining") + void testIteratorDataSetIteratorCombining() { + // Test combining of a bunch of small (size 1) data sets together int batchSize = 3; int numBatches = 4; - int featureSize = 5; int labelSize = 6; - Nd4j.getRandom().setSeed(12345); - List orig = new ArrayList<>(); for (int i = 0; i < batchSize * numBatches; i++) { INDArray features = Nd4j.rand(1, featureSize); INDArray labels = Nd4j.rand(1, labelSize); orig.add(new DataSet(features, labels)); } - DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize); int count = 0; while (iter.hasNext()) { DataSet ds = iter.next(); - assertArrayEquals(new long[] {batchSize, featureSize}, ds.getFeatures().shape()); - assertArrayEquals(new long[] {batchSize, labelSize}, ds.getLabels().shape()); - + assertArrayEquals(new long[] { batchSize, featureSize }, ds.getFeatures().shape()); + assertArrayEquals(new long[] { batchSize, labelSize }, ds.getLabels().shape()); List fList = new ArrayList<>(); List lList = new ArrayList<>(); for (int i = 0; i < batchSize; i++) { @@ -302,66 +263,44 @@ public class DataSetIteratorTest extends BaseDL4JTest { fList.add(dsOrig.getFeatures()); lList.add(dsOrig.getLabels()); } - INDArray fExp = Nd4j.vstack(fList); INDArray lExp = Nd4j.vstack(lList); - assertEquals(fExp, ds.getFeatures()); assertEquals(lExp, ds.getLabels()); - count++; } - assertEquals(count, numBatches); } @Test - public void testIteratorDataSetIteratorSplitting() { - //Test splitting large data sets into smaller ones - + @DisplayName("Test Iterator Data Set Iterator Splitting") + void testIteratorDataSetIteratorSplitting() { + // Test splitting large data sets into smaller ones int origBatchSize = 4; int origNumDSs = 3; - int batchSize = 3; int numBatches = 4; - int featureSize = 5; int labelSize = 6; - Nd4j.getRandom().setSeed(12345); - List orig = new ArrayList<>(); for (int i = 0; i < origNumDSs; i++) { INDArray features = Nd4j.rand(origBatchSize, featureSize); INDArray labels = Nd4j.rand(origBatchSize, labelSize); orig.add(new DataSet(features, labels)); } - - List expected = new ArrayList<>(); - expected.add(new DataSet(orig.get(0).getFeatures().getRows(0, 1, 2), - orig.get(0).getLabels().getRows(0, 1, 2))); - expected.add(new DataSet( - Nd4j.vstack(orig.get(0).getFeatures().getRows(3), - orig.get(1).getFeatures().getRows(0, 1)), - Nd4j.vstack(orig.get(0).getLabels().getRows(3), orig.get(1).getLabels().getRows(0, 1)))); - expected.add(new DataSet( - Nd4j.vstack(orig.get(1).getFeatures().getRows(2, 3), - orig.get(2).getFeatures().getRows(0)), - Nd4j.vstack(orig.get(1).getLabels().getRows(2, 3), orig.get(2).getLabels().getRows(0)))); - expected.add(new DataSet(orig.get(2).getFeatures().getRows(1, 2, 3), - orig.get(2).getLabels().getRows(1, 2, 3))); - - + expected.add(new DataSet(orig.get(0).getFeatures().getRows(0, 1, 2), orig.get(0).getLabels().getRows(0, 1, 2))); + expected.add(new DataSet(Nd4j.vstack(orig.get(0).getFeatures().getRows(3), orig.get(1).getFeatures().getRows(0, 1)), Nd4j.vstack(orig.get(0).getLabels().getRows(3), orig.get(1).getLabels().getRows(0, 1)))); + expected.add(new DataSet(Nd4j.vstack(orig.get(1).getFeatures().getRows(2, 3), orig.get(2).getFeatures().getRows(0)), 
Nd4j.vstack(orig.get(1).getLabels().getRows(2, 3), orig.get(2).getLabels().getRows(0)))); + expected.add(new DataSet(orig.get(2).getFeatures().getRows(1, 2, 3), orig.get(2).getLabels().getRows(1, 2, 3))); DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize); int count = 0; while (iter.hasNext()) { DataSet ds = iter.next(); assertEquals(expected.get(count), ds); - count++; } - assertEquals(count, numBatches); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java index 3221386f5..40f2d8abe 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java @@ -17,13 +17,12 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.rules.ExpectedException; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -32,23 +31,27 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; + +@DisplayName("Early Termination Data Set Iterator Test") +class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { int minibatchSize = 10; + int numExamples = 105; + @Rule public final ExpectedException exception = ExpectedException.none(); @Test - public void testNextAndReset() throws Exception { - + @DisplayName("Test Next And Reset") + void testNextAndReset() throws Exception { int terminateAfter = 2; - DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples); EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); - assertTrue(earlyEndIter.hasNext()); int batchesSeen = 0; List seenData = new ArrayList<>(); @@ -59,8 +62,7 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { batchesSeen++; } assertEquals(batchesSeen, terminateAfter); - - //check data is repeated after reset + // check data is repeated after reset earlyEndIter.reset(); batchesSeen = 0; while (earlyEndIter.hasNext()) { @@ -72,27 +74,23 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testNextNum() throws IOException { + @DisplayName("Test Next Num") + void testNextNum() throws IOException { int terminateAfter = 1; - DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples); EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); - earlyEndIter.next(10); assertEquals(false, earlyEndIter.hasNext()); - earlyEndIter.reset(); assertEquals(true, earlyEndIter.hasNext()); - } @Test - public void testCallstoNextNotAllowed() throws IOException { + 
@DisplayName("Test Callsto Next Not Allowed") + void testCallstoNextNotAllowed() throws IOException { int terminateAfter = 1; - DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples); EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); - earlyEndIter.next(10); iter.reset(); exception.expect(RuntimeException.class); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java index 51f7cd949..06b55bfcb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java @@ -17,40 +17,39 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.rules.ExpectedException; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { +@DisplayName("Early Termination Multi Data Set Iterator Test") +class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { int minibatchSize = 5; + int numExamples = 105; + @Rule public final ExpectedException exception = ExpectedException.none(); @Test - public void testNextAndReset() throws Exception { - + @DisplayName("Test Next And Reset") + void testNextAndReset() throws Exception { int terminateAfter = 2; - - MultiDataSetIterator iter = - new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples)); - + MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples)); int count = 0; List seenMDS = new ArrayList<>(); while (count < terminateAfter) { @@ -58,10 +57,7 @@ public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { count++; } iter.reset(); - - EarlyTerminationMultiDataSetIterator earlyEndIter = - new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); - + EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); assertTrue(earlyEndIter.hasNext()); count = 0; while (earlyEndIter.hasNext()) { @@ -71,8 +67,7 @@ public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { count++; } assertEquals(count, terminateAfter); - - //check data is repeated + // check data is repeated earlyEndIter.reset(); count = 0; while (earlyEndIter.hasNext()) { @@ -84,34 +79,26 @@ public class 
EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testNextNum() throws IOException { + @DisplayName("Test Next Num") + void testNextNum() throws IOException { int terminateAfter = 1; - - MultiDataSetIterator iter = - new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples)); - EarlyTerminationMultiDataSetIterator earlyEndIter = - new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); - + MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples)); + EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); earlyEndIter.next(10); assertEquals(false, earlyEndIter.hasNext()); - earlyEndIter.reset(); assertEquals(true, earlyEndIter.hasNext()); } @Test - public void testCallstoNextNotAllowed() throws IOException { + @DisplayName("Test Callsto Next Not Allowed") + void testCallstoNextNotAllowed() throws IOException { int terminateAfter = 1; - - MultiDataSetIterator iter = - new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples)); - EarlyTerminationMultiDataSetIterator earlyEndIter = - new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); - + MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples)); + EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); earlyEndIter.next(10); iter.reset(); exception.expect(RuntimeException.class); earlyEndIter.next(10); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointParallelDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointParallelDataSetIteratorTest.java index 23c2da124..de5573c56 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointParallelDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointParallelDataSetIteratorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import lombok.extern.slf4j.Slf4j; @@ -25,90 +24,75 @@ import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.parallel.JointParallelDataSetIterator; import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.enums.InequalityHandling; import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class JointParallelDataSetIteratorTest extends BaseDL4JTest { +@DisplayName("Joint Parallel Data Set Iterator Test") +class JointParallelDataSetIteratorTest extends BaseDL4JTest { /** * Simple test, checking datasets alignment. 
They all should have the same data for the same cycle * - * * @throws Exception */ @Test - public void testJointIterator1() throws Exception { + @DisplayName("Test Joint Iterator 1") + void testJointIterator1() throws Exception { DataSetIterator iteratorA = new SimpleVariableGenerator(119, 100, 32, 100, 10); DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10); - - JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.STOP_EVERYONE) - .addSourceIterator(iteratorA).addSourceIterator(iteratorB).build(); - + JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.STOP_EVERYONE).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build(); int cnt = 0; int example = 0; while (jpdsi.hasNext()) { DataSet ds = jpdsi.next(); - assertNotNull("Failed on iteration " + cnt, ds); - -// ds.detach(); - //ds.migrate(); - - assertEquals("Failed on iteration " + cnt, (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001); - assertEquals("Failed on iteration " + cnt, (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001); - + assertNotNull(ds,"Failed on iteration " + cnt); + // ds.detach(); + // ds.migrate(); + assertEquals( (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); + assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); cnt++; if (cnt % 2 == 0) example++; } - assertEquals(100, example); assertEquals(200, cnt); } - /** * This test checks for pass_null scenario, so in total we should have 300 real datasets + 100 nulls * @throws Exception */ @Test - public void testJointIterator2() throws Exception { + @DisplayName("Test Joint Iterator 2") + void testJointIterator2() throws Exception { DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10); DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10); - - JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.PASS_NULL) - .addSourceIterator(iteratorA).addSourceIterator(iteratorB).build(); - + JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.PASS_NULL).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build(); int cnt = 0; int example = 0; int nulls = 0; while (jpdsi.hasNext()) { DataSet ds = jpdsi.next(); if (cnt < 200) - assertNotNull("Failed on iteration " + cnt, ds); - + assertNotNull(ds,"Failed on iteration " + cnt); if (ds == null) nulls++; - if (cnt % 2 == 2) { - assertEquals("Failed on iteration " + cnt, (double) example, - ds.getFeatures().meanNumber().doubleValue(), 0.001); - assertEquals("Failed on iteration " + cnt, (double) example + 0.5, - ds.getLabels().meanNumber().doubleValue(), 0.001); + assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); + assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); } - - cnt++; if (cnt % 2 == 0) example++; } - assertEquals(100, nulls); assertEquals(200, example); assertEquals(400, cnt); @@ -120,25 +104,18 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest { * @throws Exception */ @Test - public void testJointIterator3() throws Exception { + @DisplayName("Test Joint Iterator 3") + void testJointIterator3() throws Exception { DataSetIterator iteratorA = new 
SimpleVariableGenerator(119, 200, 32, 100, 10); DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10); - - JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RELOCATE) - .addSourceIterator(iteratorA).addSourceIterator(iteratorB).build(); - + JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RELOCATE).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build(); int cnt = 0; int example = 0; while (jpdsi.hasNext()) { DataSet ds = jpdsi.next(); - assertNotNull("Failed on iteration " + cnt, ds); - - assertEquals("Failed on iteration " + cnt, (double) example, ds.getFeatures().meanNumber().doubleValue(), - 0.001); - assertEquals("Failed on iteration " + cnt, (double) example + 0.5, - ds.getLabels().meanNumber().doubleValue(), 0.001); - - + assertNotNull(ds,"Failed on iteration " + cnt); + assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); + assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); cnt++; if (cnt < 200) { if (cnt % 2 == 0) @@ -146,8 +123,6 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest { } else example++; } - - assertEquals(300, cnt); assertEquals(200, example); } @@ -158,52 +133,38 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest { * @throws Exception */ @Test - public void testJointIterator4() throws Exception { + @DisplayName("Test Joint Iterator 4") + void testJointIterator4() throws Exception { DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10); DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10); - - JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RESET) - .addSourceIterator(iteratorA).addSourceIterator(iteratorB).build(); - + JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RESET).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build(); int cnt = 0; int cnt_sec = 0; int example_sec = 0; int example = 0; while (jpdsi.hasNext()) { DataSet ds = jpdsi.next(); - assertNotNull("Failed on iteration " + cnt, ds); - + assertNotNull(ds,"Failed on iteration " + cnt); if (cnt % 2 == 0) { - assertEquals("Failed on iteration " + cnt, (double) example, - ds.getFeatures().meanNumber().doubleValue(), 0.001); - assertEquals("Failed on iteration " + cnt, (double) example + 0.5, - ds.getLabels().meanNumber().doubleValue(), 0.001); + assertEquals( (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); + assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); } else { if (cnt <= 200) { - assertEquals("Failed on iteration " + cnt, (double) example, - ds.getFeatures().meanNumber().doubleValue(), 0.001); - assertEquals("Failed on iteration " + cnt, (double) example + 0.5, - ds.getLabels().meanNumber().doubleValue(), 0.001); + assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); + assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt); } else { - assertEquals("Failed on iteration " + cnt + ", second iteration " + cnt_sec, (double) example_sec, - ds.getFeatures().meanNumber().doubleValue(), 0.001); - assertEquals("Failed on 
iteration " + cnt + ", second iteration " + cnt_sec, - (double) example_sec + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001); + assertEquals((double) example_sec, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt + ", second iteration " + cnt_sec); + assertEquals((double) example_sec + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt + ", second iteration " + cnt_sec); } } - cnt++; if (cnt % 2 == 0) example++; - if (cnt > 201 && cnt % 2 == 1) { cnt_sec++; example_sec++; } - } - - assertEquals(400, cnt); assertEquals(200, example); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java index a013781ac..97a4f491b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import org.datavec.api.records.reader.RecordReader; @@ -27,34 +26,33 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.util.TestDataSetConsumer; import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.rules.Timeout; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.resources.Resources; - import java.util.Iterator; import java.util.concurrent.atomic.AtomicLong; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - - -public class MultipleEpochsIteratorTest extends BaseDL4JTest { +@DisplayName("Multiple Epochs Iterator Test") +class MultipleEpochsIteratorTest extends BaseDL4JTest { @Rule public Timeout timeout = Timeout.seconds(300); @Test - public void testNextAndReset() throws Exception { + @DisplayName("Test Next And Reset") + void testNextAndReset() throws Exception { int epochs = 3; - RecordReader rr = new CSVRecordReader(); rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150); MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter); - assertTrue(multiIter.hasNext()); while (multiIter.hasNext()) { DataSet path = multiIter.next(); @@ -64,18 +62,15 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest { } @Test - public void testLoadFullDataSet() throws Exception { + @DisplayName("Test Load Full Data Set") + void testLoadFullDataSet() throws Exception { int epochs = 3; - RecordReader rr = new CSVRecordReader(); rr.initialize(new FileSplit(Resources.asFile("iris.txt"))); DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150); DataSet ds = iter.next(50); - assertEquals(50, ds.getFeatures().size(0)); - MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds); - assertTrue(multiIter.hasNext()); int count = 0; while (multiIter.hasNext()) { @@ 
-89,28 +84,26 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest { } @Test - public void testLoadBatchDataSet() throws Exception { + @DisplayName("Test Load Batch Data Set") + void testLoadBatchDataSet() throws Exception { int epochs = 2; - RecordReader rr = new CSVRecordReader(); rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile())); DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3); DataSet ds = iter.next(20); assertEquals(20, ds.getFeatures().size(0)); MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds); - while (multiIter.hasNext()) { DataSet path = multiIter.next(10); assertNotNull(path); assertEquals(10, path.numExamples(), 0.0); } - assertEquals(epochs, multiIter.epochs); } - @Test - public void testMEDIWithLoad1() throws Exception { + @DisplayName("Test MEDI With Load 1") + void testMEDIWithLoad1() throws Exception { ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100)); MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24); TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 1); @@ -119,38 +112,39 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest { } @Test - public void testMEDIWithLoad2() throws Exception { + @DisplayName("Test MEDI With Load 2") + void testMEDIWithLoad2() throws Exception { ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100)); MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24); TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2); long num1 = 0; - for (; num1 < 150; num1++) { consumer.consumeOnce(iterator.next(), true); } iterator.reset(); - long num2 = consumer.consumeWhileHasNext(true); assertEquals((10 * 100) + 150, num1 + num2); } @Test - public void testMEDIWithLoad3() throws Exception { + @DisplayName("Test MEDI With Load 3") + void testMEDIWithLoad3() throws Exception { ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(10000)); MultipleEpochsIterator iterator = new MultipleEpochsIterator(iter, 24, 136); TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2); long num1 = 0; - while (iterator.hasNext()) { consumer.consumeOnce(iterator.next(), true); num1++; } - assertEquals(136, num1); } + @DisplayName("Iterable Without Exception") private class IterableWithoutException implements Iterable { + private final AtomicLong counter = new AtomicLong(0); + private final int datasets; public IterableWithoutException(int datasets) { @@ -161,6 +155,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest { public Iterator iterator() { counter.set(0); return new Iterator() { + @Override public boolean hasNext() { return counter.get() < datasets; @@ -174,7 +169,6 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest { @Override public void remove() { - } }; } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/RandomDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/RandomDataSetIteratorTest.java index 47a155f01..3bd2a9770 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/RandomDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/RandomDataSetIteratorTest.java @@ -17,36 +17,34 @@ * * SPDX-License-Identifier: Apache-2.0 * 
***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class RandomDataSetIteratorTest extends BaseDL4JTest { +@DisplayName("Random Data Set Iterator Test") +class RandomDataSetIteratorTest extends BaseDL4JTest { @Test - public void testDSI(){ - DataSetIterator iter = new RandomDataSetIterator(5, new long[]{3,4}, new long[]{3,5}, RandomDataSetIterator.Values.RANDOM_UNIFORM, - RandomDataSetIterator.Values.ONE_HOT); - + @DisplayName("Test DSI") + void testDSI() { + DataSetIterator iter = new RandomDataSetIterator(5, new long[] { 3, 4 }, new long[] { 3, 5 }, RandomDataSetIterator.Values.RANDOM_UNIFORM, RandomDataSetIterator.Values.ONE_HOT); int count = 0; - while(iter.hasNext()){ + while (iter.hasNext()) { count++; DataSet ds = iter.next(); - - assertArrayEquals(new long[]{3,4}, ds.getFeatures().shape()); - assertArrayEquals(new long[]{3,5}, ds.getLabels().shape()); - + assertArrayEquals(new long[] { 3, 4 }, ds.getFeatures().shape()); + assertArrayEquals(new long[] { 3, 5 }, ds.getLabels().shape()); assertTrue(ds.getFeatures().minNumber().doubleValue() >= 0.0 && ds.getFeatures().maxNumber().doubleValue() <= 1.0); assertEquals(Nd4j.ones(3), ds.getLabels().sum(1)); } @@ -54,31 +52,23 @@ public class RandomDataSetIteratorTest extends BaseDL4JTest { } @Test - public void testMDSI(){ + @DisplayName("Test MDSI") + void testMDSI() { Nd4j.getRandom().setSeed(12345); - MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5) - .addFeatures(new long[]{3,4}, RandomMultiDataSetIterator.Values.INTEGER_0_100) - .addFeatures(new long[]{3,5}, RandomMultiDataSetIterator.Values.BINARY) - .addLabels(new long[]{3,6}, RandomMultiDataSetIterator.Values.ZEROS) - .build(); - + MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5).addFeatures(new long[] { 3, 4 }, RandomMultiDataSetIterator.Values.INTEGER_0_100).addFeatures(new long[] { 3, 5 }, RandomMultiDataSetIterator.Values.BINARY).addLabels(new long[] { 3, 6 }, RandomMultiDataSetIterator.Values.ZEROS).build(); int count = 0; - while(iter.hasNext()){ + while (iter.hasNext()) { count++; MultiDataSet mds = iter.next(); - assertEquals(2, mds.numFeatureArrays()); assertEquals(1, mds.numLabelsArrays()); - assertArrayEquals(new long[]{3,4}, mds.getFeatures(0).shape()); - assertArrayEquals(new long[]{3,5}, mds.getFeatures(1).shape()); - assertArrayEquals(new long[]{3,6}, mds.getLabels(0).shape()); - - assertTrue(mds.getFeatures(0).minNumber().doubleValue() >= 0 && mds.getFeatures(0).maxNumber().doubleValue() <= 100.0 - && mds.getFeatures(0).maxNumber().doubleValue() > 2.0); + assertArrayEquals(new long[] { 3, 4 }, mds.getFeatures(0).shape()); + assertArrayEquals(new long[] { 3, 5 }, 
mds.getFeatures(1).shape()); + assertArrayEquals(new long[] { 3, 6 }, mds.getLabels(0).shape()); + assertTrue(mds.getFeatures(0).minNumber().doubleValue() >= 0 && mds.getFeatures(0).maxNumber().doubleValue() <= 100.0 && mds.getFeatures(0).maxNumber().doubleValue() > 2.0); assertTrue(mds.getFeatures(1).minNumber().doubleValue() == 0.0 && mds.getFeatures(1).maxNumber().doubleValue() == 1.0); assertEquals(0.0, mds.getLabels(0).sumNumber().doubleValue(), 0.0); } assertEquals(5, count); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/SamplingTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/SamplingTest.java index 81c6e1575..fc256c51c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/SamplingTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/SamplingTest.java @@ -17,27 +17,28 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.datasets.iterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Adam Gibson */ -public class SamplingTest extends BaseDL4JTest { +@DisplayName("Sampling Test") +class SamplingTest extends BaseDL4JTest { @Test - public void testSample() throws Exception { + @DisplayName("Test Sample") + void testSample() throws Exception { DataSetIterator iter = new MnistDataSetIterator(10, 10); - //batch size and total + // batch size and total DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10); assertEquals(10, sampling.next().numExamples()); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java index 797bbd8f7..9d235ac92 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalJsonTest.java @@ -17,50 +17,46 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.eval; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.curves.Histogram; import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; - import static junit.framework.TestCase.assertNull; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; - -public class EvalJsonTest extends BaseDL4JTest { +@DisplayName("Eval Json Test") 
+class EvalJsonTest extends BaseDL4JTest { @Test - public void testSerdeEmpty() { + @DisplayName("Test Serde Empty") + void testSerdeEmpty() { boolean print = false; - - org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10), - new ROCMultiClass(10), new RegressionEvaluation(3), new RegressionEvaluation(), - new EvaluationCalibration()}; - + org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] { new Evaluation(), new EvaluationBinary(), new ROCBinary(10), new ROCMultiClass(10), new RegressionEvaluation(3), new RegressionEvaluation(), new EvaluationCalibration() }; for (org.nd4j.evaluation.IEvaluation e : arr) { String json = e.toJson(); String stats = e.stats(); if (print) { System.out.println(e.getClass() + "\n" + json + "\n\n"); } - IEvaluation fromJson = (IEvaluation) org.nd4j.evaluation.BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class); assertEquals(e.toJson(), fromJson.toJson()); } } @Test - public void testSerde() { + @DisplayName("Test Serde") + void testSerde() { boolean print = false; Nd4j.getRandom().setSeed(12345); - Evaluation evaluation = new Evaluation(); EvaluationBinary evaluationBinary = new EvaluationBinary(); ROC roc = new ROC(2); @@ -68,56 +64,43 @@ public class EvalJsonTest extends BaseDL4JTest { ROCMultiClass roc3 = new ROCMultiClass(2); RegressionEvaluation regressionEvaluation = new RegressionEvaluation(); EvaluationCalibration ec = new EvaluationCalibration(); - - - org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec}; - + org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] { evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec }; INDArray evalLabel = Nd4j.create(10, 3); for (int i = 0; i < 10; i++) { evalLabel.putScalar(i, i % 3, 1.0); } INDArray evalProb = Nd4j.rand(10, 3); - evalProb.diviColumnVector(evalProb.sum(true,1)); + evalProb.diviColumnVector(evalProb.sum(true, 1)); evaluation.eval(evalLabel, evalProb); roc3.eval(evalLabel, evalProb); ec.eval(evalLabel, evalProb); - evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5)); evalProb = Nd4j.rand(10, 3); evaluationBinary.eval(evalLabel, evalProb); roc2.eval(evalLabel, evalProb); - evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5)); evalProb = Nd4j.rand(10, 1); roc.eval(evalLabel, evalProb); - regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3)); - - - for (org.nd4j.evaluation.IEvaluation e : arr) { String json = e.toJson(); if (print) { System.out.println(e.getClass() + "\n" + json + "\n\n"); } - IEvaluation fromJson = (IEvaluation) BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class); assertEquals(e.toJson(), fromJson.toJson()); } } @Test - public void testSerdeExactRoc() { + @DisplayName("Test Serde Exact Roc") + void testSerdeExactRoc() { Nd4j.getRandom().setSeed(12345); boolean print = false; - ROC roc = new ROC(0); ROCBinary roc2 = new ROCBinary(0); ROCMultiClass roc3 = new ROCMultiClass(0); - - - org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {roc, roc2, roc3}; - + org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] { roc, roc2, roc3 }; INDArray evalLabel = Nd4j.create(100, 3); for (int i = 0; i < 100; i++) { evalLabel.putScalar(i, i % 3, 1.0); @@ 
-125,15 +108,12 @@ public class EvalJsonTest extends BaseDL4JTest { INDArray evalProb = Nd4j.rand(100, 3); evalProb.diviColumnVector(evalProb.sum(1)); roc3.eval(evalLabel, evalProb); - evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 3), 0.5)); evalProb = Nd4j.rand(100, 3); roc2.eval(evalLabel, evalProb); - evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5)); evalProb = Nd4j.rand(100, 1); roc.eval(evalLabel, evalProb); - for (org.nd4j.evaluation.IEvaluation e : arr) { System.out.println(e.getClass()); String json = e.toJson(); @@ -143,37 +123,34 @@ public class EvalJsonTest extends BaseDL4JTest { } org.nd4j.evaluation.IEvaluation fromJson = BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class); assertEquals(e, fromJson); - if (fromJson instanceof ROC) { - //Shouldn't have probAndLabel, but should have stored AUC and AUPRC + // Shouldn't have probAndLabel, but should have stored AUC and AUPRC assertNull(((ROC) fromJson).getProbAndLabel()); assertTrue(((ROC) fromJson).calculateAUC() > 0.0); assertTrue(((ROC) fromJson).calculateAUCPR() > 0.0); - assertEquals(((ROC) e).getRocCurve(), ((ROC) fromJson).getRocCurve()); assertEquals(((ROC) e).getPrecisionRecallCurve(), ((ROC) fromJson).getPrecisionRecallCurve()); } else if (e instanceof ROCBinary) { org.nd4j.evaluation.classification.ROC[] rocs = ((ROCBinary) fromJson).getUnderlying(); org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCBinary) e).getUnderlying(); - // for(ROC r : rocs ){ + // for(ROC r : rocs ){ for (int i = 0; i < origRocs.length; i++) { org.nd4j.evaluation.classification.ROC r = rocs[i]; org.nd4j.evaluation.classification.ROC origR = origRocs[i]; - //Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves + // Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves assertNull(r.getProbAndLabel()); assertEquals(origR.calculateAUC(), origR.calculateAUC(), 1e-6); assertEquals(origR.calculateAUCPR(), origR.calculateAUCPR(), 1e-6); assertEquals(origR.getRocCurve(), origR.getRocCurve()); assertEquals(origR.getPrecisionRecallCurve(), origR.getPrecisionRecallCurve()); } - } else if (e instanceof ROCMultiClass) { org.nd4j.evaluation.classification.ROC[] rocs = ((ROCMultiClass) fromJson).getUnderlying(); org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCMultiClass) e).getUnderlying(); for (int i = 0; i < origRocs.length; i++) { org.nd4j.evaluation.classification.ROC r = rocs[i]; org.nd4j.evaluation.classification.ROC origR = origRocs[i]; - //Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves + // Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves assertNull(r.getProbAndLabel()); assertEquals(origR.calculateAUC(), origR.calculateAUC(), 1e-6); assertEquals(origR.calculateAUCPR(), origR.calculateAUCPR(), 1e-6); @@ -185,32 +162,23 @@ public class EvalJsonTest extends BaseDL4JTest { } @Test - public void testJsonYamlCurves() { + @DisplayName("Test Json Yaml Curves") + void testJsonYamlCurves() { ROC roc = new ROC(0); - - INDArray evalLabel = - Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5)); + INDArray evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5)); INDArray evalProb = Nd4j.rand(100, 1); roc.eval(evalLabel, evalProb); - RocCurve c = roc.getRocCurve(); PrecisionRecallCurve prc = 
roc.getPrecisionRecallCurve(); - String json1 = c.toJson(); String json2 = prc.toJson(); - RocCurve c2 = RocCurve.fromJson(json1); PrecisionRecallCurve prc2 = PrecisionRecallCurve.fromJson(json2); - assertEquals(c, c2); assertEquals(prc, prc2); - - // System.out.println(json1); - - //Also test: histograms - + // System.out.println(json1); + // Also test: histograms EvaluationCalibration ec = new EvaluationCalibration(); - evalLabel = Nd4j.create(10, 3); for (int i = 0; i < 10; i++) { evalLabel.putScalar(i, i % 3, 1.0); @@ -218,67 +186,45 @@ public class EvalJsonTest extends BaseDL4JTest { evalProb = Nd4j.rand(10, 3); evalProb.diviColumnVector(evalProb.sum(1)); ec.eval(evalLabel, evalProb); - - Histogram[] histograms = new Histogram[] {ec.getResidualPlotAllClasses(), ec.getResidualPlot(0), - ec.getResidualPlot(1), ec.getProbabilityHistogramAllClasses(), ec.getProbabilityHistogram(0), - ec.getProbabilityHistogram(1)}; - + Histogram[] histograms = new Histogram[] { ec.getResidualPlotAllClasses(), ec.getResidualPlot(0), ec.getResidualPlot(1), ec.getProbabilityHistogramAllClasses(), ec.getProbabilityHistogram(0), ec.getProbabilityHistogram(1) }; for (Histogram h : histograms) { String json = h.toJson(); String yaml = h.toYaml(); - Histogram h2 = Histogram.fromJson(json); Histogram h3 = Histogram.fromYaml(yaml); - assertEquals(h, h2); assertEquals(h2, h3); } - } @Test - public void testJsonWithCustomThreshold() { - - //Evaluation - binary threshold + @DisplayName("Test Json With Custom Threshold") + void testJsonWithCustomThreshold() { + // Evaluation - binary threshold Evaluation e = new Evaluation(0.25); String json = e.toJson(); String yaml = e.toYaml(); - Evaluation eFromJson = Evaluation.fromJson(json); Evaluation eFromYaml = Evaluation.fromYaml(yaml); - assertEquals(0.25, eFromJson.getBinaryDecisionThreshold(), 1e-6); assertEquals(0.25, eFromYaml.getBinaryDecisionThreshold(), 1e-6); - - - //Evaluation: custom cost array - INDArray costArray = Nd4j.create(new double[] {1.0, 2.0, 3.0}); + // Evaluation: custom cost array + INDArray costArray = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); Evaluation e2 = new Evaluation(costArray); - json = e2.toJson(); yaml = e2.toYaml(); - eFromJson = Evaluation.fromJson(json); eFromYaml = Evaluation.fromYaml(yaml); - assertEquals(e2.getCostArray(), eFromJson.getCostArray()); assertEquals(e2.getCostArray(), eFromYaml.getCostArray()); - - - - //EvaluationBinary - per-output binary threshold - INDArray threshold = Nd4j.create(new double[] {1.0, 0.5, 0.25}); + // EvaluationBinary - per-output binary threshold + INDArray threshold = Nd4j.create(new double[] { 1.0, 0.5, 0.25 }); EvaluationBinary eb = new EvaluationBinary(threshold); - json = eb.toJson(); yaml = eb.toYaml(); - EvaluationBinary ebFromJson = EvaluationBinary.fromJson(json); EvaluationBinary ebFromYaml = EvaluationBinary.fromYaml(yaml); - assertEquals(threshold, ebFromJson.getDecisionThreshold()); assertEquals(threshold, ebFromYaml.getDecisionThreshold()); - } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index 886d6645e..cb74ab199 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** 
*/ - package org.deeplearning4j.eval; import org.datavec.api.records.metadata.RecordMetaData; @@ -45,7 +44,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.EvaluativeListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; @@ -58,78 +57,60 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.resources.Resources; - import java.util.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - -public class EvalTest extends BaseDL4JTest { +@DisplayName("Eval Test") +class EvalTest extends BaseDL4JTest { @Test - public void testIris() { - + @DisplayName("Test Iris") + void testIris() { // Network config - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42) - .updater(new Sgd(1e-6)).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42).updater(new Sgd(1e-6)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).build(); // Instantiate model MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); model.addListeners(new ScoreIterationListener(1)); - // Train-test split DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSet next = iter.next(); next.shuffle(); SplitTestAndTrain trainTest = next.splitTestAndTrain(5, new Random(42)); - // Train DataSet train = trainTest.getTrain(); train.normalizeZeroMeanZeroUnitVariance(); - // Test DataSet test = trainTest.getTest(); test.normalizeZeroMeanZeroUnitVariance(); INDArray testFeature = test.getFeatures(); INDArray testLabel = test.getLabels(); - // Fitting model model.fit(train); // Get predictions from test feature INDArray testPredictedLabel = model.output(testFeature); - // Eval with class number - org.nd4j.evaluation.classification.Evaluation eval = new org.nd4j.evaluation.classification.Evaluation(3); //// Specify class num here + // // Specify class num here + org.nd4j.evaluation.classification.Evaluation eval = new org.nd4j.evaluation.classification.Evaluation(3); eval.eval(testLabel, testPredictedLabel); double eval1F1 = eval.f1(); double eval1Acc = eval.accuracy(); - // Eval without class number - org.nd4j.evaluation.classification.Evaluation eval2 = new org.nd4j.evaluation.classification.Evaluation(); //// No class num + // // No class num + org.nd4j.evaluation.classification.Evaluation 
eval2 = new org.nd4j.evaluation.classification.Evaluation(); eval2.eval(testLabel, testPredictedLabel); double eval2F1 = eval2.f1(); double eval2Acc = eval2.accuracy(); - - //Assert the two implementations give same f1 and accuracy (since one batch) + // Assert the two implementations give same f1 and accuracy (since one batch) assertTrue(eval1F1 == eval2F1 && eval1Acc == eval2Acc); - org.nd4j.evaluation.classification.Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test))); checkEvaluationEquality(eval, evalViaMethod); - -// System.out.println(eval.getConfusionMatrix().toString()); -// System.out.println(eval.getConfusionMatrix().toCSV()); -// System.out.println(eval.getConfusionMatrix().toHTML()); -// System.out.println(eval.confusionToString()); - + // System.out.println(eval.getConfusionMatrix().toString()); + // System.out.println(eval.getConfusionMatrix().toCSV()); + // System.out.println(eval.getConfusionMatrix().toHTML()); + // System.out.println(eval.confusionToString()); eval.getConfusionMatrix().toString(); eval.getConfusionMatrix().toCSV(); eval.getConfusionMatrix().toHTML(); @@ -160,99 +141,79 @@ public class EvalTest extends BaseDL4JTest { } @Test - public void testEvaluationWithMetaData() throws Exception { - + @DisplayName("Test Evaluation With Meta Data") + void testEvaluationWithMetaData() throws Exception { RecordReader csv = new CSVRecordReader(); csv.initialize(new FileSplit(Resources.asFile("iris.txt"))); - int batchSize = 10; int labelIdx = 4; int numClasses = 3; - RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses); - NormalizerStandardize ns = new NormalizerStandardize(); ns.fit(rrdsi); rrdsi.setPreProcessor(ns); rrdsi.reset(); - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)) - .list() - .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(4).nOut(3).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).list().layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(3).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - for (int i = 0; i < 4; i++) { net.fit(rrdsi); rrdsi.reset(); } - org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation(); - rrdsi.setCollectMetaData(true); //*** New: Enable collection of metadata (stored in the DataSets) *** - + // *** New: Enable collection of metadata (stored in the DataSets) *** + rrdsi.setCollectMetaData(true); while (rrdsi.hasNext()) { DataSet ds = rrdsi.next(); - List meta = ds.getExampleMetaData(RecordMetaData.class); //*** New - cross dependencies here make types difficult, usid Object internally in DataSet for this*** - + // *** New - cross dependencies here make types difficult, using Object internally in DataSet for this*** + List meta = ds.getExampleMetaData(RecordMetaData.class); INDArray out = net.output(ds.getFeatures()); - e.eval(ds.getLabels(), out, meta); //*** New - evaluate and also store metadata *** + // *** New - evaluate and also store metadata *** + e.eval(ds.getLabels(), out, meta); } - -// System.out.println(e.stats()); + // 
System.out.println(e.stats()); e.stats(); - -// System.out.println("\n\n*** Prediction Errors: ***"); - - List errors = e.getPredictionErrors(); //*** New - get list of prediction errors from evaluation *** + // System.out.println("\n\n*** Prediction Errors: ***"); + // *** New - get list of prediction errors from evaluation *** + List errors = e.getPredictionErrors(); List metaForErrors = new ArrayList<>(); for (org.nd4j.evaluation.meta.Prediction p : errors) { metaForErrors.add((RecordMetaData) p.getRecordMetaData()); } - DataSet ds = rrdsi.loadFromMetaData(metaForErrors); //*** New - dynamically load a subset of the data, just for prediction errors *** + // *** New - dynamically load a subset of the data, just for prediction errors *** + DataSet ds = rrdsi.loadFromMetaData(metaForErrors); INDArray output = net.output(ds.getFeatures()); - int count = 0; for (org.nd4j.evaluation.meta.Prediction t : errors) { - String s = t + "\t\tRaw Data: " - + csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() //*** New - load subset of data from MetaData object (usually batched for efficiency) *** - + "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: " - + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count); -// System.out.println(s); + String s = t + "\t\tRaw Data: " + // *** New - load subset of data from MetaData object (usually batched for efficiency) *** + csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() + "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: " + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count); + // System.out.println(s); count++; } - int errorCount = errors.size(); double expAcc = 1.0 - errorCount / 150.0; assertEquals(expAcc, e.accuracy(), 1e-5); - org.nd4j.evaluation.classification.ConfusionMatrix confusion = e.getConfusionMatrix(); int[] actualCounts = new int[3]; int[] predictedCounts = new int[3]; for (int i = 0; i < 3; i++) { for (int j = 0; j < 3; j++) { - int entry = confusion.getCount(i, j); //(actual,predicted) + // (actual,predicted) + int entry = confusion.getCount(i, j); List list = e.getPredictions(i, j); assertEquals(entry, list.size()); - actualCounts[i] += entry; predictedCounts[j] += entry; } } - for (int i = 0; i < 3; i++) { List actualClassI = e.getPredictionsByActualClass(i); List predictedClassI = e.getPredictionByPredictedClass(i); assertEquals(actualCounts[i], actualClassI.size()); assertEquals(predictedCounts[i], predictedClassI.size()); } - - - //Finally: test doEvaluation methods + // Finally: test doEvaluation methods rrdsi.reset(); org.nd4j.evaluation.classification.Evaluation e2 = new org.nd4j.evaluation.classification.Evaluation(); net.doEvaluation(rrdsi, e2); @@ -262,7 +223,6 @@ public class EvalTest extends BaseDL4JTest { assertEquals(actualCounts[i], actualClassI.size()); assertEquals(predictedCounts[i], predictedClassI.size()); } - ComputationGraph cg = net.toComputationGraph(); rrdsi.reset(); e2 = new org.nd4j.evaluation.classification.Evaluation(); @@ -273,7 +233,6 @@ public class EvalTest extends BaseDL4JTest { assertEquals(actualCounts[i], actualClassI.size()); assertEquals(predictedCounts[i], predictedClassI.size()); } - } private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) { @@ -283,138 +242,28 @@ public class EvalTest extends BaseDL4JTest { } @Test - public void testEvalSplitting(){ - //Test for "tbptt-like" functionality - - 
for(WorkspaceMode ws : WorkspaceMode.values()) { + @DisplayName("Test Eval Splitting") + void testEvalSplitting() { + // Test for "tbptt-like" functionality + for (WorkspaceMode ws : WorkspaceMode.values()) { System.out.println("Starting test for workspace mode: " + ws); - int nIn = 4; int layerSize = 5; int nOut = 6; int tbpttLength = 10; int tsLength = 5 * tbpttLength + tbpttLength / 2; - - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() - .seed(12345) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .list() - .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build()) - .layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut) - .activation(Activation.SOFTMAX) - .build()) - .build(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .seed(12345) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .list() - .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build()) - .layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut) - .activation(Activation.SOFTMAX).build()) - .tBPTTLength(10) - .backpropType(BackpropType.TruncatedBPTT) - .build(); - + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build()).layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build()).layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build()).tBPTTLength(10).backpropType(BackpropType.TruncatedBPTT).build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); - MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - net2.setParams(net1.params()); - - for(boolean useMask : new boolean[]{false, true}) { - - INDArray in1 = Nd4j.rand(new int[]{3, nIn, tsLength}); + for (boolean useMask : new boolean[] { false, true }) { + INDArray in1 = Nd4j.rand(new int[] { 3, nIn, tsLength }); INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength); - - INDArray in2 = Nd4j.rand(new int[]{5, nIn, tsLength}); + INDArray in2 = Nd4j.rand(new int[] { 5, nIn, tsLength }); INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength); - - INDArray lMask1 = null; - INDArray lMask2 = null; - if(useMask){ - lMask1 = Nd4j.create(3, tsLength); - lMask2 = Nd4j.create(5, tsLength); - Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5)); - Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5)); - } - - List l = Arrays.asList(new DataSet(in1, out1, null, lMask1), new DataSet(in2, out2, null, lMask2)); - DataSetIterator iter = new ExistingDataSetIterator(l); - -// System.out.println("Net 1 eval"); - org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); -// System.out.println("Net 2 eval"); - org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); - - assertEquals(e1[0], e2[0]); - assertEquals(e1[1], e2[1]); - 
assertEquals(e1[2], e2[2]); - } - } - } - - @Test - public void testEvalSplittingCompGraph(){ - //Test for "tbptt-like" functionality - - for(WorkspaceMode ws : WorkspaceMode.values()) { - System.out.println("Starting test for workspace mode: " + ws); - - int nIn = 4; - int layerSize = 5; - int nOut = 6; - int tbpttLength = 10; - int tsLength = 5 * tbpttLength + tbpttLength / 2; - - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() - .seed(12345) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .graphBuilder() - .addInputs("in") - .addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in") - .addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut) - .activation(Activation.SOFTMAX) - .build(), "0") - .setOutputs("1") - .build(); - - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() - .seed(12345) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .graphBuilder() - .addInputs("in") - .addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in") - .addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut) - .activation(Activation.SOFTMAX) - .build(), "0") - .setOutputs("1") - .tBPTTLength(10) - .backpropType(BackpropType.TruncatedBPTT) - .build(); - - ComputationGraph net1 = new ComputationGraph(conf1); - net1.init(); - - ComputationGraph net2 = new ComputationGraph(conf2); - net2.init(); - - net2.setParams(net1.params()); - - for (boolean useMask : new boolean[]{false, true}) { - - INDArray in1 = Nd4j.rand(new int[]{3, nIn, tsLength}); - INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength); - - INDArray in2 = Nd4j.rand(new int[]{5, nIn, tsLength}); - INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength); - INDArray lMask1 = null; INDArray lMask2 = null; if (useMask) { @@ -423,15 +272,12 @@ public class EvalTest extends BaseDL4JTest { Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5)); Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5)); } - - List l = Arrays.asList(new DataSet(in1, out1), new DataSet(in2, out2)); + List l = Arrays.asList(new DataSet(in1, out1, null, lMask1), new DataSet(in2, out2, null, lMask2)); DataSetIterator iter = new ExistingDataSetIterator(l); - -// System.out.println("Eval net 1"); + // System.out.println("Net 1 eval"); org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); -// System.out.println("Eval net 2"); + // System.out.println("Net 2 eval"); org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); - assertEquals(e1[0], e2[0]); assertEquals(e1[1], e2[1]); assertEquals(e1[2], e2[2]); @@ -440,192 +286,170 @@ public class EvalTest extends BaseDL4JTest { } @Test - public void testEvalSplitting2(){ + @DisplayName("Test Eval Splitting Comp Graph") + void testEvalSplittingCompGraph() { + // Test for "tbptt-like" functionality + for (WorkspaceMode ws : WorkspaceMode.values()) { + System.out.println("Starting test for workspace mode: " + ws); + int nIn = 4; + int layerSize = 5; + int nOut = 6; + int tbpttLength = 10; + int tsLength = 5 * tbpttLength + tbpttLength / 2; + ComputationGraphConfiguration conf1 = new 
NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in").addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build(), "0").setOutputs("1").build(); + ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in").addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build(), "0").setOutputs("1").tBPTTLength(10).backpropType(BackpropType.TruncatedBPTT).build(); + ComputationGraph net1 = new ComputationGraph(conf1); + net1.init(); + ComputationGraph net2 = new ComputationGraph(conf2); + net2.init(); + net2.setParams(net1.params()); + for (boolean useMask : new boolean[] { false, true }) { + INDArray in1 = Nd4j.rand(new int[] { 3, nIn, tsLength }); + INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength); + INDArray in2 = Nd4j.rand(new int[] { 5, nIn, tsLength }); + INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength); + INDArray lMask1 = null; + INDArray lMask2 = null; + if (useMask) { + lMask1 = Nd4j.create(3, tsLength); + lMask2 = Nd4j.create(5, tsLength); + Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5)); + Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5)); + } + List l = Arrays.asList(new DataSet(in1, out1), new DataSet(in2, out2)); + DataSetIterator iter = new ExistingDataSetIterator(l); + // System.out.println("Eval net 1"); + org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); + // System.out.println("Eval net 2"); + org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation()); + assertEquals(e1[0], e2[0]); + assertEquals(e1[1], e2[1]); + assertEquals(e1[2], e2[2]); + } + } + } + + @Test + @DisplayName("Test Eval Splitting 2") + void testEvalSplitting2() { List> seqFeatures = new ArrayList<>(); List step = Arrays.asList(new FloatWritable(0), new FloatWritable(0), new FloatWritable(0)); - for( int i=0; i<30; i++ ){ + for (int i = 0; i < 30; i++) { seqFeatures.add(step); } List> seqLabels = Collections.singletonList(Collections.singletonList(new FloatWritable(0))); - SequenceRecordReader fsr = new CollectionSequenceRecordReader(Collections.singletonList(seqFeatures)); SequenceRecordReader lsr = new CollectionSequenceRecordReader(Collections.singletonList(seqLabels)); - - - DataSetIterator testData = new SequenceRecordReaderDataSetIterator(fsr, lsr, 1, -1, true, - SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) - .list() - .layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()) - .layer(1, new RnnOutputLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.XENT) - .nIn(3).nOut(1).build()) - .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(10).tBPTTBackwardLength(10) - .build(); + DataSetIterator testData = 
new SequenceRecordReaderDataSetIterator(fsr, lsr, 1, -1, true, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()).layer(1, new RnnOutputLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.XENT).nIn(3).nOut(1).build()).backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(10).tBPTTBackwardLength(10).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.evaluate(testData); } @Test - public void testEvaluativeListenerSimple(){ - //Sanity check: https://github.com/eclipse/deeplearning4j/issues/5351 - + @DisplayName("Test Evaluative Listener Simple") + void testEvaluativeListenerSimple() { + // Sanity check: https://github.com/eclipse/deeplearning4j/issues/5351 // Network config - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42) - .updater(new Sgd(1e-6)).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42).updater(new Sgd(1e-6)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).build(); // Instantiate model MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - // Train-test split DataSetIterator iter = new IrisDataSetIterator(30, 150); DataSetIterator iterTest = new IrisDataSetIterator(30, 150); - net.setListeners(new EvaluativeListener(iterTest, 3)); - - for( int i=0; i<3; i++ ){ + for (int i = 0; i < 3; i++) { net.fit(iter); } } @Test - public void testMultiOutputEvalSimple(){ + @DisplayName("Test Multi Output Eval Simple") + void testMultiOutputEvalSimple() { Nd4j.getRandom().setSeed(12345); - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345) - .graphBuilder() - .addInputs("in") - .addLayer("out1", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in") - .addLayer("out2", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in") - .setOutputs("out1", "out2") - .build(); - + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("out1", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in").addLayer("out2", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in").setOutputs("out1", "out2").build(); ComputationGraph cg = new ComputationGraph(conf); cg.init(); - List list = new ArrayList<>(); DataSetIterator iter = new IrisDataSetIterator(30, 150); - while(iter.hasNext()){ + while (iter.hasNext()) { DataSet ds = iter.next(); - list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), 
ds.getLabels()})); + list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { ds.getFeatures() }, new INDArray[] { ds.getLabels(), ds.getLabels() })); } - org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation(); org.nd4j.evaluation.regression.RegressionEvaluation e2 = new org.nd4j.evaluation.regression.RegressionEvaluation(); - Map evals = new HashMap<>(); - evals.put(0, new org.nd4j.evaluation.IEvaluation[]{e}); - evals.put(1, new org.nd4j.evaluation.IEvaluation[]{e2}); - + Map evals = new HashMap<>(); + evals.put(0, new org.nd4j.evaluation.IEvaluation[] { e }); + evals.put(1, new org.nd4j.evaluation.IEvaluation[] { e2 }); cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals); - assertEquals(150, e.getNumRowCounter()); assertEquals(150, e2.getExampleCountPerColumn().getInt(0)); } @Test - public void testMultiOutputEvalCG(){ - //Simple sanity check on evaluation - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - .graphBuilder() - .addInputs("in") - .layer("0", new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build(), "in") - .layer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0") - .layer("2", new LSTM.Builder().nIn(10).nOut(10).build(), "0") - .layer("out1", new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1") - .layer("out2", new RnnOutputLayer.Builder().nIn(10).nOut(20).activation(Activation.SOFTMAX).build(), "2") - .setOutputs("out1", "out2") - .build(); - + @DisplayName("Test Multi Output Eval CG") + void testMultiOutputEvalCG() { + // Simple sanity check on evaluation + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").layer("0", new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build(), "in").layer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0").layer("2", new LSTM.Builder().nIn(10).nOut(10).build(), "0").layer("out1", new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1").layer("out2", new RnnOutputLayer.Builder().nIn(10).nOut(20).activation(Activation.SOFTMAX).build(), "2").setOutputs("out1", "out2").build(); ComputationGraph cg = new ComputationGraph(conf); cg.init(); - - org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet( - new INDArray[]{Nd4j.create(10, 1, 10)}, - new INDArray[]{Nd4j.create(10, 10, 10), Nd4j.create(10, 20, 10)}); - - Map m = new HashMap<>(); - m.put(0, new org.nd4j.evaluation.IEvaluation[]{new org.nd4j.evaluation.classification.Evaluation()}); - m.put(1, new org.nd4j.evaluation.IEvaluation[]{new org.nd4j.evaluation.classification.Evaluation()}); - + org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { Nd4j.create(10, 1, 10) }, new INDArray[] { Nd4j.create(10, 10, 10), Nd4j.create(10, 20, 10) }); + Map m = new HashMap<>(); + m.put(0, new org.nd4j.evaluation.IEvaluation[] { new org.nd4j.evaluation.classification.Evaluation() }); + m.put(1, new org.nd4j.evaluation.IEvaluation[] { new org.nd4j.evaluation.classification.Evaluation() }); cg.evaluate(new SingletonMultiDataSetIterator(mds), m); } @Test - public void testInvalidEvaluation(){ - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - - .list() - .layer(new DenseLayer.Builder().nIn(4).nOut(10).build()) - .layer(new OutputLayer.Builder().nIn(10).nOut(3).lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.RELU).build()) - .build(); - + 
@DisplayName("Test Invalid Evaluation") + void testInvalidEvaluation() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new DenseLayer.Builder().nIn(4).nOut(10).build()).layer(new OutputLayer.Builder().nIn(10).nOut(3).lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.RELU).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSetIterator iter = new IrisDataSetIterator(150, 150); try { net.evaluate(iter); fail("Expected exception"); - } catch (IllegalStateException e){ + } catch (IllegalStateException e) { assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("Evaluation")); } - try { net.evaluateROC(iter, 0); fail("Expected exception"); - } catch (IllegalStateException e){ + } catch (IllegalStateException e) { assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC")); } - try { net.evaluateROCMultiClass(iter, 0); fail("Expected exception"); - } catch (IllegalStateException e){ + } catch (IllegalStateException e) { assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass")); } - ComputationGraph cg = net.toComputationGraph(); try { cg.evaluate(iter); fail("Expected exception"); - } catch (IllegalStateException e){ + } catch (IllegalStateException e) { assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("Evaluation")); } - try { cg.evaluateROC(iter, 0); fail("Expected exception"); - } catch (IllegalStateException e){ + } catch (IllegalStateException e) { assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC")); } - try { cg.evaluateROCMultiClass(iter, 0); fail("Expected exception"); - } catch (IllegalStateException e){ + } catch (IllegalStateException e) { assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass")); } - - - //Disable validation, and check same thing: + // Disable validation, and check same thing: net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false); net.evaluate(iter); net.evaluateROCMultiClass(iter, 0); - cg.getConfiguration().setValidateOutputLayerConfig(false); cg.evaluate(iter); cg.evaluateROCMultiClass(iter, 0); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/ROCTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/ROCTest.java index 09586699d..adf9aa54e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/ROCTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/ROCTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.eval; import org.deeplearning4j.BaseDL4JTest; @@ -28,48 +27,53 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; -import org.nd4j.evaluation.curves.PrecisionRecallCurve; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.dataset.api.DataSet; import 
org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.util.*; +import java.util.HashMap; +import java.util.Map; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.assertEquals; -public class ROCTest extends BaseDL4JTest { +@DisplayName("Roc Test") +class ROCTest extends BaseDL4JTest { private static Map expTPR; + private static Map expFPR; static { expTPR = new HashMap<>(); double totalPositives = 5.0; - expTPR.put(0 / 10.0, 5.0 / totalPositives); //All 10 predicted as class 1, of which 5 of 5 are correct + // All 10 predicted as class 1, of which 5 of 5 are correct + expTPR.put(0 / 10.0, 5.0 / totalPositives); expTPR.put(1 / 10.0, 5.0 / totalPositives); expTPR.put(2 / 10.0, 5.0 / totalPositives); expTPR.put(3 / 10.0, 5.0 / totalPositives); expTPR.put(4 / 10.0, 5.0 / totalPositives); expTPR.put(5 / 10.0, 5.0 / totalPositives); - expTPR.put(6 / 10.0, 4.0 / totalPositives); //Threshold: 0.4 -> last 4 predicted; last 5 actual + // Threshold: 0.4 -> last 4 predicted; last 5 actual + expTPR.put(6 / 10.0, 4.0 / totalPositives); expTPR.put(7 / 10.0, 3.0 / totalPositives); expTPR.put(8 / 10.0, 2.0 / totalPositives); expTPR.put(9 / 10.0, 1.0 / totalPositives); expTPR.put(10 / 10.0, 0.0 / totalPositives); - expFPR = new HashMap<>(); double totalNegatives = 5.0; - expFPR.put(0 / 10.0, 5.0 / totalNegatives); //All 10 predicted as class 1, but all 5 true negatives are predicted positive - expFPR.put(1 / 10.0, 4.0 / totalNegatives); //1 true negative is predicted as negative; 4 false positives - expFPR.put(2 / 10.0, 3.0 / totalNegatives); //2 true negatives are predicted as negative; 3 false positives + // All 10 predicted as class 1, but all 5 true negatives are predicted positive + expFPR.put(0 / 10.0, 5.0 / totalNegatives); + // 1 true negative is predicted as negative; 4 false positives + expFPR.put(1 / 10.0, 4.0 / totalNegatives); + // 2 true negatives are predicted as negative; 3 false positives + expFPR.put(2 / 10.0, 3.0 / totalNegatives); expFPR.put(3 / 10.0, 2.0 / totalNegatives); expFPR.put(4 / 10.0, 1.0 / totalNegatives); expFPR.put(5 / 10.0, 0.0 / totalNegatives); @@ -81,56 +85,41 @@ public class ROCTest extends BaseDL4JTest { } @Test - public void RocEvalSanityCheck() { - + @DisplayName("Roc Eval Sanity Check") + void RocEvalSanityCheck() { DataSetIterator iter = new IrisDataSetIterator(150, 150); - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).seed(12345) - .list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1, - new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).seed(12345).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - NormalizerStandardize ns = new NormalizerStandardize(); DataSet ds = iter.next(); ns.fit(ds); ns.transform(ds); - iter.setPreProcessor(ns); - 
for (int i = 0; i < 10; i++) { net.fit(ds); } - - for (int steps : new int[] {32, 0}) { //Steps = 0: exact + for (int steps : new int[] { 32, 0 }) { + // Steps = 0: exact System.out.println("steps: " + steps); - iter.reset(); ds = iter.next(); INDArray f = ds.getFeatures(); INDArray l = ds.getLabels(); INDArray out = net.output(f); - // System.out.println(f); - // System.out.println(out); + // System.out.println(f); + // System.out.println(out); ROCMultiClass manual = new ROCMultiClass(steps); manual.eval(l, out); - iter.reset(); ROCMultiClass roc = net.evaluateROCMultiClass(iter, steps); - - for (int i = 0; i < 3; i++) { double rocExp = manual.calculateAUC(i); double rocAct = roc.calculateAUC(i); assertEquals(rocExp, rocAct, 1e-6); - RocCurve rc = roc.getRocCurve(i); RocCurve rm = manual.getRocCurve(i); - assertEquals(rc, rm); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java index e5ac052ab..db5e7d7fa 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.eval; import org.deeplearning4j.BaseDL4JTest; @@ -29,59 +28,43 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.Collections; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -public class RegressionEvalTest extends BaseDL4JTest { +@DisplayName("Regression Eval Test") +class RegressionEvalTest extends BaseDL4JTest { @Test - public void testRegressionEvalMethods() { - - //Basic sanity check - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).list() - .layer(0, new OutputLayer.Builder().activation(Activation.TANH) - .lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(5).build()) - .build(); - + @DisplayName("Test Regression Eval Methods") + void testRegressionEvalMethods() { + // Basic sanity check + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).list().layer(0, new OutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(5).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray f = Nd4j.zeros(4, 10); INDArray l = Nd4j.ones(4, 5); - DataSet ds = new DataSet(f, l); DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds)); 
org.nd4j.evaluation.regression.RegressionEvaluation re = net.evaluateRegression(iter); - for (int i = 0; i < 5; i++) { assertEquals(1.0, re.meanSquaredError(i), 1e-6); assertEquals(1.0, re.meanAbsoluteError(i), 1e-6); } - - - ComputationGraphConfiguration graphConf = - new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).graphBuilder() - .addInputs("in").addLayer("0", new OutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE) - .activation(Activation.TANH).nIn(10).nOut(5).build(), "in") - .setOutputs("0").build(); - + ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).graphBuilder().addInputs("in").addLayer("0", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(10).nOut(5).build(), "in").setOutputs("0").build(); ComputationGraph cg = new ComputationGraph(graphConf); cg.init(); - RegressionEvaluation re2 = cg.evaluateRegression(iter); - for (int i = 0; i < 5; i++) { assertEquals(1.0, re2.meanSquaredError(i), 1e-6); assertEquals(1.0, re2.meanAbsoluteError(i), 1e-6); @@ -89,25 +72,16 @@ public class RegressionEvalTest extends BaseDL4JTest { } @Test - public void testRegressionEvalPerOutputMasking() { - - INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}}); - + @DisplayName("Test Regression Eval Per Output Masking") + void testRegressionEvalPerOutputMasking() { + INDArray l = Nd4j.create(new double[][] { { 1, 2, 3 }, { 10, 20, 30 }, { -5, -10, -20 } }); INDArray predictions = Nd4j.zeros(l.shape()); - - INDArray mask = Nd4j.create(new double[][] {{0, 1, 1}, {1, 1, 0}, {0, 1, 0}}); - - + INDArray mask = Nd4j.create(new double[][] { { 0, 1, 1 }, { 1, 1, 0 }, { 0, 1, 0 } }); RegressionEvaluation re = new RegressionEvaluation(); - re.eval(l, predictions, mask); - - double[] mse = new double[] {(10 * 10) / 1.0, (2 * 2 + 20 * 20 + 10 * 10) / 3, (3 * 3) / 1.0}; - - double[] mae = new double[] {10.0, (2 + 20 + 10) / 3.0, 3.0}; - - double[] rmse = new double[] {10.0, Math.sqrt((2 * 2 + 20 * 20 + 10 * 10) / 3.0), 3.0}; - + double[] mse = new double[] { (10 * 10) / 1.0, (2 * 2 + 20 * 20 + 10 * 10) / 3, (3 * 3) / 1.0 }; + double[] mae = new double[] { 10.0, (2 + 20 + 10) / 3.0, 3.0 }; + double[] rmse = new double[] { 10.0, Math.sqrt((2 * 2 + 20 * 20 + 10 * 10) / 3.0), 3.0 }; for (int i = 0; i < 3; i++) { assertEquals(mse[i], re.meanSquaredError(i), 1e-6); assertEquals(mae[i], re.meanAbsoluteError(i), 1e-6); @@ -116,24 +90,19 @@ public class RegressionEvalTest extends BaseDL4JTest { } @Test - public void testRegressionEvalTimeSeriesSplit(){ - - INDArray out1 = Nd4j.rand(new int[]{3, 5, 20}); - INDArray outSub1 = out1.get(all(), all(), interval(0,10)); + @DisplayName("Test Regression Eval Time Series Split") + void testRegressionEvalTimeSeriesSplit() { + INDArray out1 = Nd4j.rand(new int[] { 3, 5, 20 }); + INDArray outSub1 = out1.get(all(), all(), interval(0, 10)); INDArray outSub2 = out1.get(all(), all(), interval(10, 20)); - - INDArray label1 = Nd4j.rand(new int[]{3, 5, 20}); - INDArray labelSub1 = label1.get(all(), all(), interval(0,10)); + INDArray label1 = Nd4j.rand(new int[] { 3, 5, 20 }); + INDArray labelSub1 = label1.get(all(), all(), interval(0, 10)); INDArray labelSub2 = label1.get(all(), all(), interval(10, 20)); - RegressionEvaluation e1 = new RegressionEvaluation(); RegressionEvaluation e2 = new RegressionEvaluation(); - e1.eval(label1, out1); - e2.eval(labelSub1, outSub1); e2.eval(labelSub2, outSub2); - assertEquals(e1, e2); } } 
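The file diffs above and below all repeat the same JUnit 4 to JUnit 5 (Jupiter) migration. For orientation, a minimal sketch of the recurring substitutions follows; it is illustrative only, not part of the patch, and the class and method names are placeholders:

import java.nio.file.Path;
import org.junit.jupiter.api.Disabled;             // replaces org.junit.Ignore
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;                 // replaces org.junit.Test
import org.junit.jupiter.api.io.TempDir;           // replaces org.junit.rules.TemporaryFolder
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;   // replaces static org.junit.Assert.*

@DisplayName("Example Migrated Test")
class ExampleMigratedTest {

    @TempDir
    Path tempDir;   // replaces @Rule public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    @DisplayName("Example Case")
    void exampleCase() {
        int sum = 1 + 1;
        // JUnit 5 assertions take the failure message as the last argument,
        // which is why the diffs change assertTrue(name, gradOK) to assertTrue(gradOK, name).
        assertTrue(sum == 2, "sum should be 2");
        assertEquals(2, sum);
    }

    @Test
    @Disabled   // replaces @Ignore
    void skippedExample() {
        // intentionally skipped, as @Ignore'd tests were under JUnit 4
    }
}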
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java index 3911d13bd..2f9fbf18c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; @@ -32,9 +31,9 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Ignore; +import org.junit.jupiter.api.Disabled; import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.rules.ExpectedException; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -42,13 +41,15 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.Random; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertTrue; +@Disabled +@DisplayName("Attention Layer Test") +class AttentionLayerTest extends BaseDL4JTest { -@Ignore -public class AttentionLayerTest extends BaseDL4JTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -58,19 +59,18 @@ public class AttentionLayerTest extends BaseDL4JTest { } @Test - public void testSelfAttentionLayer() { + @DisplayName("Test Self Attention Layer") + void testSelfAttentionLayer() { int nIn = 3; int nOut = 2; int tsLength = 4; int layerSize = 4; - - for (int mb : new int[]{1, 3}) { - for (boolean inputMask : new boolean[]{false, true}) { - for (boolean projectInput : new boolean[]{false, true}) { - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + for (int mb : new int[] { 1, 3 }) { + for (boolean inputMask : new boolean[] { false, true }) { + for (boolean projectInput : new boolean[] { false, true }) { + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength }); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); - INDArray inMask = null; if (inputMask) { inMask = Nd4j.ones(mb, tsLength); @@ -84,54 +84,32 @@ public class AttentionLayerTest extends BaseDL4JTest { } } } - String name = "testSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput; System.out.println("Starting test: " + name); - - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .updater(new NoOp()) - .weightInit(WeightInit.XAVIER) - .list() - .layer(new LSTM.Builder().nOut(layerSize).build()) - .layer( projectInput ? 
- new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build() - : new SelfAttentionLayer.Builder().nHeads(1).projectInput(false).build() - ) - .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) - .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(projectInput ? new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build() : new SelfAttentionLayer.Builder().nHeads(1).projectInput(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) - .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); - assertTrue(name, gradOK); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in).labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); + assertTrue(gradOK,name); } } } } @Test - public void testLearnedSelfAttentionLayer() { + @DisplayName("Test Learned Self Attention Layer") + void testLearnedSelfAttentionLayer() { int nIn = 3; int nOut = 2; int tsLength = 4; int layerSize = 4; int numQueries = 3; - - for (boolean inputMask : new boolean[]{false, true}) { - for (int mb : new int[]{3, 1}) { - for (boolean projectInput : new boolean[]{false, true}) { - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + for (boolean inputMask : new boolean[] { false, true }) { + for (int mb : new int[] { 3, 1 }) { + for (boolean projectInput : new boolean[] { false, true }) { + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength }); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); - INDArray inMask = null; if (inputMask) { inMask = Nd4j.ones(mb, tsLength); @@ -145,75 +123,36 @@ public class AttentionLayerTest extends BaseDL4JTest { } } } - String name = "testLearnedSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput; System.out.println("Starting test: " + name); - - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .updater(new NoOp()) - .weightInit(WeightInit.XAVIER) - .list() - .layer(new LSTM.Builder().nOut(layerSize).build()) - .layer( projectInput ? 
- new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build() - : new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build() - ) - .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) - .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(projectInput ? new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build() : new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) - .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); - assertTrue(name, gradOK); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in).labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); + assertTrue(gradOK,name); } } } } @Test - public void testLearnedSelfAttentionLayer_differentMiniBatchSizes() { + @DisplayName("Test Learned Self Attention Layer _ different Mini Batch Sizes") + void testLearnedSelfAttentionLayer_differentMiniBatchSizes() { int nIn = 3; int nOut = 2; int tsLength = 4; int layerSize = 4; int numQueries = 3; - Random r = new Random(12345); - for (boolean inputMask : new boolean[]{false, true}) { - for (boolean projectInput : new boolean[]{false, true}) { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .updater(new NoOp()) - .weightInit(WeightInit.XAVIER) - .list() - .layer(new LSTM.Builder().nOut(layerSize).build()) - .layer( projectInput ? - new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build() - : new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build() - ) - .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) - .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - for (int mb : new int[]{3, 1}) { - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + for (boolean inputMask : new boolean[] { false, true }) { + for (boolean projectInput : new boolean[] { false, true }) { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(projectInput ? 
new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build() : new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + for (int mb : new int[] { 3, 1 }) { + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength }); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); - INDArray inMask = null; if (inputMask) { inMask = Nd4j.ones(DataType.INT, mb, tsLength); @@ -227,68 +166,47 @@ public class AttentionLayerTest extends BaseDL4JTest { } } } - String name = "testLearnedSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput; System.out.println("Starting test: " + name); - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) - .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); - assertTrue(name, gradOK); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in).labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); + assertTrue(gradOK,name); } } } } @Test - public void testRecurrentAttentionLayer_differingTimeSteps(){ + @DisplayName("Test Recurrent Attention Layer _ differing Time Steps") + void testRecurrentAttentionLayer_differingTimeSteps() { int nIn = 9; int nOut = 5; int layerSize = 8; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.IDENTITY) - .updater(new NoOp()) - .weightInit(WeightInit.XAVIER) - .list() - .layer(new LSTM.Builder().nOut(layerSize).build()) - .layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()) - .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()) - .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.IDENTITY).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - final INDArray initialInput = Nd4j.rand(new int[]{8, nIn, 7}); - final INDArray goodNextInput = Nd4j.rand(new int[]{8, nIn, 7}); - final INDArray badNextInput = Nd4j.rand(new int[]{8, nIn, 12}); - - final INDArray labels = Nd4j.rand(new int[]{8, nOut}); - + final INDArray initialInput = Nd4j.rand(new int[] { 8, nIn, 7 }); + final INDArray goodNextInput = Nd4j.rand(new int[] { 8, nIn, 7 }); + final INDArray badNextInput 
= Nd4j.rand(new int[] { 8, nIn, 12 }); + final INDArray labels = Nd4j.rand(new int[] { 8, nOut }); net.fit(initialInput, labels); net.fit(goodNextInput, labels); - exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("This layer only supports fixed length mini-batches. Expected 7 time steps but got 12."); net.fit(badNextInput, labels); } @Test - public void testRecurrentAttentionLayer() { + @DisplayName("Test Recurrent Attention Layer") + void testRecurrentAttentionLayer() { int nIn = 4; int nOut = 2; int tsLength = 3; int layerSize = 3; - - for (int mb : new int[]{3, 1}) { - for (boolean inputMask : new boolean[]{true, false}) { - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + for (int mb : new int[] { 3, 1 }) { + for (boolean inputMask : new boolean[] { true, false }) { + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength }); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); - INDArray inMask = null; if (inputMask) { inMask = Nd4j.ones(mb, tsLength); @@ -302,51 +220,32 @@ public class AttentionLayerTest extends BaseDL4JTest { } } } - String name = "testRecurrentAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType; System.out.println("Starting test: " + name); - - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.IDENTITY) - .updater(new NoOp()) - .weightInit(WeightInit.XAVIER) - .list() - .layer(new LSTM.Builder().nOut(layerSize).build()) - .layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()) - .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()) - .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.IDENTITY).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - //System.out.println("Original"); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) - .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); - assertTrue(name, gradOK); + // System.out.println("Original"); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in).labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); + assertTrue(gradOK,name); } } } @Test - public void testAttentionVertex() { + @DisplayName("Test Attention Vertex") + void testAttentionVertex() { int nIn = 3; int nOut = 2; int tsLength = 3; int layerSize = 3; - Random r = new Random(12345); - for (boolean inputMask : new boolean[]{false, true}) { - for (int mb : new int[]{3, 1}) { - for (boolean projectInput : new boolean[]{false, true}) { - INDArray in = 
Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + for (boolean inputMask : new boolean[] { false, true }) { + for (int mb : new int[] { 3, 1 }) { + for (boolean projectInput : new boolean[] { false, true }) { + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength }); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); - INDArray inMask = null; if (inputMask) { inMask = Nd4j.ones(mb, tsLength); @@ -360,57 +259,32 @@ public class AttentionLayerTest extends BaseDL4JTest { } } } - String name = "testAttentionVertex() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput; System.out.println("Starting test: " + name); - - - ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .updater(new NoOp()) - .weightInit(WeightInit.XAVIER) - .graphBuilder() - .addInputs("input") - .addLayer("rnnKeys", new SimpleRnn.Builder().nOut(layerSize).build(), "input") - .addLayer("rnnQueries", new SimpleRnn.Builder().nOut(layerSize).build(), "input") - .addLayer("rnnValues", new SimpleRnn.Builder().nOut(layerSize).build(), "input") - .addVertex("attention", - projectInput ? - new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build() - : new AttentionVertex.Builder().nOut(3).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(), "rnnQueries", "rnnKeys", "rnnValues") - .addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention") - .addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling") - .setOutputs("output") - .setInputTypes(InputType.recurrent(nIn)) - .build(); - + ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("input").addLayer("rnnKeys", new SimpleRnn.Builder().nOut(layerSize).build(), "input").addLayer("rnnQueries", new SimpleRnn.Builder().nOut(layerSize).build(), "input").addLayer("rnnValues", new SimpleRnn.Builder().nOut(layerSize).build(), "input").addVertex("attention", projectInput ? new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build() : new AttentionVertex.Builder().nOut(3).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(), "rnnQueries", "rnnKeys", "rnnValues").addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention").addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling").setOutputs("output").setInputTypes(InputType.recurrent(nIn)).build(); ComputationGraph net = new ComputationGraph(graph); net.init(); - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in}) - .labels(new INDArray[]{labels}).inputMask(inMask != null ? 
new INDArray[]{inMask} : null).subset(true).maxPerParam(100)); - assertTrue(name, gradOK); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { in }).labels(new INDArray[] { labels }).inputMask(inMask != null ? new INDArray[] { inMask } : null).subset(true).maxPerParam(100)); + assertTrue(gradOK,name); } } } } @Test - public void testAttentionVertexSameInput() { + @DisplayName("Test Attention Vertex Same Input") + void testAttentionVertexSameInput() { int nIn = 3; int nOut = 2; int tsLength = 4; int layerSize = 4; - Random r = new Random(12345); - for (boolean inputMask : new boolean[]{false, true}) { - for (int mb : new int[]{3, 1}) { - for (boolean projectInput : new boolean[]{false, true}) { - INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength}); + for (boolean inputMask : new boolean[] { false, true }) { + for (int mb : new int[] { 3, 1 }) { + for (boolean projectInput : new boolean[] { false, true }) { + INDArray in = Nd4j.rand(new int[] { mb, nIn, tsLength }); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); - INDArray inMask = null; if (inputMask) { inMask = Nd4j.ones(mb, tsLength); @@ -424,35 +298,13 @@ public class AttentionLayerTest extends BaseDL4JTest { } } } - String name = "testAttentionVertex() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput; System.out.println("Starting test: " + name); - - - ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .updater(new NoOp()) - .weightInit(WeightInit.XAVIER) - .graphBuilder() - .addInputs("input") - .addLayer("rnn", new SimpleRnn.Builder().activation(Activation.TANH).nOut(layerSize).build(), "input") - .addVertex("attention", - projectInput ? - new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build() - : new AttentionVertex.Builder().nOut(4).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(), "rnn", "rnn", "rnn") - .addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention") - .addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling") - .setOutputs("output") - .setInputTypes(InputType.recurrent(nIn)) - .build(); - + ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("input").addLayer("rnn", new SimpleRnn.Builder().activation(Activation.TANH).nOut(layerSize).build(), "input").addVertex("attention", projectInput ? 
new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build() : new AttentionVertex.Builder().nOut(4).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(), "rnn", "rnn", "rnn").addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention").addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling").setOutputs("output").setInputTypes(InputType.recurrent(nIn)).build(); ComputationGraph net = new ComputationGraph(graph); net.init(); - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in}) - .labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null)); - assertTrue(name, gradOK); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { in }).labels(new INDArray[] { labels }).inputMask(inMask != null ? new INDArray[] { inMask } : null)); + assertTrue(gradOK,name); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java index 2106ea4be..f728f3f29 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; @@ -34,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -48,18 +47,18 @@ import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.profiler.OpProfiler; import org.nd4j.linalg.profiler.ProfilerConfig; - import java.util.Arrays; import java.util.HashSet; import java.util.Random; import java.util.Set; - -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** - * */ -public class BNGradientCheckTest extends BaseDL4JTest { +@DisplayName("Bn Gradient Check Test") +class BNGradientCheckTest extends BaseDL4JTest { static { Nd4j.setDataType(DataType.DOUBLE); @@ -71,7 +70,8 @@ public class BNGradientCheckTest extends BaseDL4JTest { } @Test - public void testGradient2dSimple() { + @DisplayName("Test Gradient 2 d Simple") + void testGradient2dSimple() { DataNormalization scaler = new NormalizerMinMaxScaler(); DataSetIterator iter = new IrisDataSetIterator(150, 150); scaler.fit(iter); @@ -79,181 +79,117 @@ public class BNGradientCheckTest extends BaseDL4JTest { DataSet ds = iter.next(); INDArray input = ds.getFeatures(); INDArray labels = ds.getLabels(); - - for 
(boolean useLogStd : new boolean[]{true, false}) { - - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dataType(DataType.DOUBLE) - .seed(12345L) - .dist(new NormalDistribution(0, 1)).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3) - .activation(Activation.IDENTITY).build()) - .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).nOut(3).build()) - .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3).build()); - + for (boolean useLogStd : new boolean[] { true, false }) { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).seed(12345L).dist(new NormalDistribution(0, 1)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).nOut(3).build()).layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()); MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); - -// for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); - - //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc - //i.e., runningMean = decay * runningMean + (1-decay) * batchMean - //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" + // for (int j = 0; j < mln.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); + // Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc + // i.e., runningMean = decay * runningMean + (1-decay) * batchMean + // However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) - .labels(labels).excludeParams(excludeParams)); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); } } @Test - public void testGradientCnnSimple() { + @DisplayName("Test Gradient Cnn Simple") + void testGradientCnnSimple() { Nd4j.getRandom().setSeed(12345); int minibatch = 10; int depth = 1; int hw = 4; int nOut = 4; - INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}); + INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw }); INDArray labels = Nd4j.zeros(minibatch, nOut); Random r = new Random(12345); for (int i = 0; i < minibatch; i++) { labels.putScalar(i, r.nextInt(nOut), 1.0); } - - for (boolean useLogStd : new boolean[]{true, false}) { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()).seed(12345L) - .dist(new NormalDistribution(0, 2)).list() - .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2) - .activation(Activation.IDENTITY).build()) - .layer(1, new 
BatchNormalization.Builder().useLogStd(useLogStd).build()) - .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(hw, hw, depth)); - + for (boolean useLogStd : new boolean[] { true, false }) { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).seed(12345L).dist(new NormalDistribution(0, 2)).list().layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()).layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(nOut).build()).setInputType(InputType.convolutional(hw, hw, depth)); MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); - -// for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); - - //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc - //i.e., runningMean = decay * runningMean + (1-decay) * batchMean - //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" + // for (int j = 0; j < mln.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); + // Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc + // i.e., runningMean = decay * runningMean + (1-decay) * batchMean + // However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) - .labels(labels).excludeParams(excludeParams)); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); } } @Test - public void testGradientBNWithCNNandSubsampling() { - //Parameterized test, testing combinations of: + @DisplayName("Test Gradient BN With CN Nand Subsampling") + void testGradientBNWithCNNandSubsampling() { + // Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (c) Loss function (with specified output activations) // (d) l1 and l2 values - Activation[] activFns = {Activation.SIGMOID, Activation.TANH, Activation.IDENTITY}; - boolean[] characteristic = {true}; //If true: run some backprop steps first - - LossFunctions.LossFunction[] lossFunctions = - {LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE}; - Activation[] outputActivations = {Activation.SOFTMAX, Activation.TANH}; //i.e., lossFunctions[i] used with outputActivations[i] here - - double[] l2vals = {0.0, 0.1, 0.1}; - double[] l1vals = {0.0, 0.0, 0.2}; //i.e., use l2vals[j] with l1vals[j] - + Activation[] activFns = { Activation.SIGMOID, Activation.TANH, Activation.IDENTITY }; + // If true: run some 
backprop steps first + boolean[] characteristic = { true }; + LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE }; + // i.e., lossFunctions[i] used with outputActivations[i] here + Activation[] outputActivations = { Activation.SOFTMAX, Activation.TANH }; + double[] l2vals = { 0.0, 0.1, 0.1 }; + // i.e., use l2vals[j] with l1vals[j] + double[] l1vals = { 0.0, 0.0, 0.2 }; Nd4j.getRandom().setSeed(12345); int minibatch = 4; int depth = 2; int hw = 5; int nOut = 2; - INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}).muli(5).subi(2.5); + INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw }).muli(5).subi(2.5); INDArray labels = TestUtils.randomOneHot(minibatch, nOut); - DataSet ds = new DataSet(input, labels); Random rng = new Random(12345); - for (boolean useLogStd : new boolean[]{true, false}) { + for (boolean useLogStd : new boolean[] { true, false }) { for (Activation afn : activFns) { for (boolean doLearningFirst : characteristic) { for (int i = 0; i < lossFunctions.length; i++) { for (int j = 0; j < l2vals.length; j++) { - //Skip 2 of every 3 tests: from 24 cases to 8, still with decent coverage + // Skip 2 of every 3 tests: from 24 cases to 8, still with decent coverage if (rng.nextInt(3) != 0) continue; - LossFunctions.LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345) - .dataType(DataType.DOUBLE) - .l2(l2vals[j]) - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) - .updater(new NoOp()) - .dist(new UniformDistribution(-2, 2)).seed(12345L).list() - .layer(0, new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3) - .activation(afn).build()) - .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()) - .layer(2, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2).stride(1, 1).build()) - .layer(3, new BatchNormalization()) - .layer(4, new ActivationLayer.Builder().activation(afn).build()) - .layer(5, new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut) - .build()) - .setInputType(InputType.convolutional(hw, hw, depth)); - + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345).dataType(DataType.DOUBLE).l2(l2vals[j]).optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).updater(new NoOp()).dist(new UniformDistribution(-2, 2)).seed(12345L).list().layer(0, new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3).activation(afn).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()).layer(2, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(1, 1).build()).layer(3, new BatchNormalization()).layer(4, new ActivationLayer.Builder().activation(afn).build()).layer(5, new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut).build()).setInputType(InputType.convolutional(hw, hw, depth)); MultiLayerConfiguration conf = builder.build(); - MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); String name = new Object() { }.getClass().getEnclosingMethod().getName(); - -// System.out.println("Num params: " + mln.numParams()); - + // System.out.println("Num params: " + mln.numParams()); if (doLearningFirst) { - //Run a number of iterations of learning + // Run a number of iterations of learning mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); 
mln.computeGradientAndScore(); double scoreBefore = mln.score(); - for (int k = 0; k < 20; k++) - mln.fit(ds); + for (int k = 0; k < 20; k++) mln.fit(ds); mln.computeGradientAndScore(); double scoreAfter = mln.score(); - //Can't test in 'characteristic mode of operation' if not learning - String msg = name - + " - score did not (sufficiently) decrease during learning - activationFn=" - + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation - + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore - + ", scoreAfter=" + scoreAfter + ")"; - assertTrue(msg, scoreAfter < 0.9 * scoreBefore); + // Can't test in 'characteristic mode of operation' if not learning + String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")"; + assertTrue(scoreAfter < 0.9 * scoreBefore,msg); } - - System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf - + ", outputActivation=" + outputActivation + ", doLearningFirst=" - + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); -// for (int k = 0; k < mln.getnLayers(); k++) -// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams()); - - //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc - //i.e., runningMean = decay * runningMean + (1-decay) * batchMean - //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" + System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); + // for (int k = 0; k < mln.getnLayers(); k++) + // System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams()); + // Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc + // i.e., runningMean = decay * runningMean + (1-decay) * batchMean + // However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) - .labels(labels).excludeParams(excludeParams).subset(true).maxPerParam(25)); //Most params are in output layer, only these should be skipped with this threshold - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams).subset(true).maxPerParam(// Most params are in output layer, only these should be skipped with this threshold + 25)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); } @@ -263,101 +199,68 @@ public class BNGradientCheckTest extends BaseDL4JTest { } } - @Test - public void testGradientDense() { - //Parameterized test, testing combinations of: + @DisplayName("Test Gradient Dense") + void testGradientDense() { + // Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (c) Loss function (with specified output activations) // (d) l1 and l2 values - Activation[] activFns = {Activation.TANH, 
Activation.IDENTITY}; - boolean[] characteristic = {true}; //If true: run some backprop steps first - - LossFunctions.LossFunction[] lossFunctions = - {LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE}; - Activation[] outputActivations = {Activation.SOFTMAX, Activation.TANH}; //i.e., lossFunctions[i] used with outputActivations[i] here - - double[] l2vals = {0.0, 0.1}; - double[] l1vals = {0.0, 0.2}; //i.e., use l2vals[j] with l1vals[j] - + Activation[] activFns = { Activation.TANH, Activation.IDENTITY }; + // If true: run some backprop steps first + boolean[] characteristic = { true }; + LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE }; + // i.e., lossFunctions[i] used with outputActivations[i] here + Activation[] outputActivations = { Activation.SOFTMAX, Activation.TANH }; + double[] l2vals = { 0.0, 0.1 }; + // i.e., use l2vals[j] with l1vals[j] + double[] l1vals = { 0.0, 0.2 }; Nd4j.getRandom().setSeed(12345); int minibatch = 10; int nIn = 5; int nOut = 3; - INDArray input = Nd4j.rand(new int[]{minibatch, nIn}); + INDArray input = Nd4j.rand(new int[] { minibatch, nIn }); INDArray labels = Nd4j.zeros(minibatch, nOut); Random r = new Random(12345); for (int i = 0; i < minibatch; i++) { labels.putScalar(i, r.nextInt(nOut), 1.0); } - DataSet ds = new DataSet(input, labels); - - for (boolean useLogStd : new boolean[]{true, false}) { + for (boolean useLogStd : new boolean[] { true, false }) { for (Activation afn : activFns) { for (boolean doLearningFirst : characteristic) { for (int i = 0; i < lossFunctions.length; i++) { for (int j = 0; j < l2vals.length; j++) { LossFunctions.LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .l2(l2vals[j]) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) - .updater(new NoOp()) - .dist(new UniformDistribution(-2, 2)).seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(4) - .activation(afn).build()) - .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()) - .layer(2, new DenseLayer.Builder().nIn(4).nOut(4).build()) - .layer(3, new BatchNormalization.Builder().useLogStd(useLogStd).build()) - .layer(4, new OutputLayer.Builder(lf) - .activation(outputActivation).nOut(nOut) - .build()); - + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).l2(l2vals[j]).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()).dist(new UniformDistribution(-2, 2)).seed(12345L).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(4).activation(afn).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()).layer(2, new DenseLayer.Builder().nIn(4).nOut(4).build()).layer(3, new BatchNormalization.Builder().useLogStd(useLogStd).build()).layer(4, new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut).build()); MultiLayerConfiguration conf = builder.build(); - MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - String name = new Object() { }.getClass().getEnclosingMethod().getName(); - if (doLearningFirst) { - //Run a number of iterations of learning + // Run a number of iterations of learning mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); double scoreBefore = mln.score(); - for (int k = 0; k < 10; k++) - 
mln.fit(ds); + for (int k = 0; k < 10; k++) mln.fit(ds); mln.computeGradientAndScore(); double scoreAfter = mln.score(); - //Can't test in 'characteristic mode of operation' if not learning - String msg = name - + " - score did not (sufficiently) decrease during learning - activationFn=" - + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation - + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore - + ", scoreAfter=" + scoreAfter + ")"; - assertTrue(msg, scoreAfter < 0.8 * scoreBefore); + // Can't test in 'characteristic mode of operation' if not learning + String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")"; + assertTrue(scoreAfter < 0.8 * scoreBefore,msg); } - - System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf - + ", outputActivation=" + outputActivation + ", doLearningFirst=" - + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); -// for (int k = 0; k < mln.getnLayers(); k++) -// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams()); - - //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc - //i.e., runningMean = decay * runningMean + (1-decay) * batchMean - //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" + System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); + // for (int k = 0; k < mln.getnLayers(); k++) + // System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams()); + // Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc + // i.e., runningMean = decay * runningMean + (1-decay) * batchMean + // However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) - .labels(labels).excludeParams(excludeParams)); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); } @@ -368,7 +271,8 @@ public class BNGradientCheckTest extends BaseDL4JTest { } @Test - public void testGradient2dFixedGammaBeta() { + @DisplayName("Test Gradient 2 d Fixed Gamma Beta") + void testGradient2dFixedGammaBeta() { DataNormalization scaler = new NormalizerMinMaxScaler(); DataSetIterator iter = new IrisDataSetIterator(150, 150); scaler.fit(iter); @@ -376,219 +280,142 @@ public class BNGradientCheckTest extends BaseDL4JTest { DataSet ds = iter.next(); INDArray input = ds.getFeatures(); INDArray labels = ds.getLabels(); - - for (boolean useLogStd : new boolean[]{true, false}) { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dataType(DataType.DOUBLE) - .seed(12345L) - .dist(new NormalDistribution(0, 1)).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build()) - .layer(1, 
new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true).gamma(2.0).beta(0.5).nOut(3) - .build()) - .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3).build()); - + for (boolean useLogStd : new boolean[] { true, false }) { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).seed(12345L).dist(new NormalDistribution(0, 1)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true).gamma(2.0).beta(0.5).nOut(3).build()).layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()); MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); - -// for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); - - //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc - //i.e., runningMean = decay * runningMean + (1-decay) * batchMean - //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" + // for (int j = 0; j < mln.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); + // Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc + // i.e., runningMean = decay * runningMean + (1-decay) * batchMean + // However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) - .labels(labels).excludeParams(excludeParams)); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); } } @Test - public void testGradientCnnFixedGammaBeta() { + @DisplayName("Test Gradient Cnn Fixed Gamma Beta") + void testGradientCnnFixedGammaBeta() { Nd4j.getRandom().setSeed(12345); int minibatch = 10; int depth = 1; int hw = 4; int nOut = 4; - INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}); + INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw }); INDArray labels = Nd4j.zeros(minibatch, nOut); Random r = new Random(12345); for (int i = 0; i < minibatch; i++) { labels.putScalar(i, r.nextInt(nOut), 1.0); } - - for (boolean useLogStd : new boolean[]{true, false}) { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dataType(DataType.DOUBLE) - .seed(12345L) - .dist(new NormalDistribution(0, 2)).list() - .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2) - .activation(Activation.IDENTITY).build()) - .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true).gamma(2.0).beta(0.5).build()) - .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()) - .layer(3, new 
OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(hw, hw, depth)); - + for (boolean useLogStd : new boolean[] { true, false }) { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).seed(12345L).dist(new NormalDistribution(0, 2)).list().layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true).gamma(2.0).beta(0.5).build()).layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(nOut).build()).setInputType(InputType.convolutional(hw, hw, depth)); MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); - -// for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); - - //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc - //i.e., runningMean = decay * runningMean + (1-decay) * batchMean - //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" + // for (int j = 0; j < mln.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); + // Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc + // i.e., runningMean = decay * runningMean + (1-decay) * batchMean + // However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) - .labels(labels).excludeParams(excludeParams)); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(mln); } } @Test - public void testBatchNormCompGraphSimple() { - + @DisplayName("Test Batch Norm Comp Graph Simple") + void testBatchNormCompGraphSimple() { int numClasses = 2; int height = 3; int width = 3; int channels = 1; long seed = 123; - int minibatchSize = 3; - - for (boolean useLogStd : new boolean[]{true, false}) { - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).updater(new NoOp()) - .dataType(DataType.DOUBLE) - .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") - .setInputTypes(InputType.convolutional(height, width, channels)) - .addLayer("bn", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "in") - .addLayer("out", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(numClasses).build(), "bn") - .setOutputs("out").build(); - + for (boolean useLogStd : new boolean[] { true, false }) { + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).updater(new NoOp()).dataType(DataType.DOUBLE).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in").setInputTypes(InputType.convolutional(height, width, channels)).addLayer("bn", new BatchNormalization.Builder().useLogStd(useLogStd).build(), 
"in").addLayer("out", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(numClasses).build(), "bn").setOutputs("out").build(); ComputationGraph net = new ComputationGraph(conf); net.init(); - Random r = new Random(12345); - INDArray input = Nd4j.rand(new int[]{minibatchSize, channels, height, width}); //Order: examples, channels, height, width + // Order: examples, channels, height, width + INDArray input = Nd4j.rand(new int[] { minibatchSize, channels, height, width }); INDArray labels = Nd4j.zeros(minibatchSize, numClasses); for (int i = 0; i < minibatchSize; i++) { - labels.putScalar(new int[]{i, r.nextInt(numClasses)}, 1.0); + labels.putScalar(new int[] { i, r.nextInt(numClasses) }, 1.0); } - - //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc - //i.e., runningMean = decay * runningMean + (1-decay) * batchMean - //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" + // Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc + // i.e., runningMean = decay * runningMean + (1-decay) * batchMean + // However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("bn_mean", "bn_var")); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{input}) - .labels(new INDArray[]{labels}).excludeParams(excludeParams)); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { input }).labels(new INDArray[] { labels }).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(net); } } - @Test - public void testGradientBNWithCNNandSubsamplingCompGraph() { - //Parameterized test, testing combinations of: + @DisplayName("Test Gradient BN With CN Nand Subsampling Comp Graph") + void testGradientBNWithCNNandSubsamplingCompGraph() { + // Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (c) Loss function (with specified output activations) // (d) l1 and l2 values - Activation[] activFns = {Activation.TANH, Activation.IDENTITY}; + Activation[] activFns = { Activation.TANH, Activation.IDENTITY }; boolean doLearningFirst = true; - - LossFunctions.LossFunction[] lossFunctions = {LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD}; - Activation[] outputActivations = {Activation.SOFTMAX}; //i.e., lossFunctions[i] used with outputActivations[i] here - - double[] l2vals = {0.0, 0.1}; - double[] l1vals = {0.0, 0.2}; //i.e., use l2vals[j] with l1vals[j] - + LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD }; + // i.e., lossFunctions[i] used with outputActivations[i] here + Activation[] outputActivations = { Activation.SOFTMAX }; + double[] l2vals = { 0.0, 0.1 }; + // i.e., use l2vals[j] with l1vals[j] + double[] l1vals = { 0.0, 0.2 }; Nd4j.getRandom().setSeed(12345); int minibatch = 10; int depth = 2; int hw = 5; int nOut = 3; - INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}); + INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw }); INDArray labels = Nd4j.zeros(minibatch, nOut); Random r = new Random(12345); for (int i = 0; i < minibatch; i++) 
{ labels.putScalar(i, r.nextInt(nOut), 1.0); } - DataSet ds = new DataSet(input, labels); - - for (boolean useLogStd : new boolean[]{true, false}) { + for (boolean useLogStd : new boolean[] { true, false }) { for (Activation afn : activFns) { for (int i = 0; i < lossFunctions.length; i++) { for (int j = 0; j < l2vals.length; j++) { LossFunctions.LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) - .dataType(DataType.DOUBLE) - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) - .updater(new NoOp()) - .dist(new UniformDistribution(-2, 2)).seed(12345L).graphBuilder() - .addInputs("in") - .addLayer("0", new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3) - .activation(afn).build(), "in") - .addLayer("1", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "0") - .addLayer("2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2).stride(1, 1).build(), "1") - .addLayer("3", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "2") - .addLayer("4", new ActivationLayer.Builder().activation(afn).build(), "3") - .addLayer("5", new OutputLayer.Builder(lf).activation(outputActivation) - .nOut(nOut).build(), "4") - .setOutputs("5").setInputTypes(InputType.convolutional(hw, hw, depth)) - .build(); - + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).dataType(DataType.DOUBLE).optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).updater(new NoOp()).dist(new UniformDistribution(-2, 2)).seed(12345L).graphBuilder().addInputs("in").addLayer("0", new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3).activation(afn).build(), "in").addLayer("1", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "0").addLayer("2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(1, 1).build(), "1").addLayer("3", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "2").addLayer("4", new ActivationLayer.Builder().activation(afn).build(), "3").addLayer("5", new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut).build(), "4").setOutputs("5").setInputTypes(InputType.convolutional(hw, hw, depth)).build(); ComputationGraph net = new ComputationGraph(conf); net.init(); String name = new Object() { }.getClass().getEnclosingMethod().getName(); - if (doLearningFirst) { - //Run a number of iterations of learning + // Run a number of iterations of learning net.setInput(0, ds.getFeatures()); net.setLabels(ds.getLabels()); net.computeGradientAndScore(); double scoreBefore = net.score(); - for (int k = 0; k < 20; k++) - net.fit(ds); + for (int k = 0; k < 20; k++) net.fit(ds); net.computeGradientAndScore(); double scoreAfter = net.score(); - //Can't test in 'characteristic mode of operation' if not learning - String msg = name - + " - score did not (sufficiently) decrease during learning - activationFn=" - + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation - + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore - + ", scoreAfter=" + scoreAfter + ")"; - assertTrue(msg, scoreAfter < 0.9 * scoreBefore); + // Can't test in 'characteristic mode of operation' if not learning + String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore + 
", scoreAfter=" + scoreAfter + ")"; + assertTrue(scoreAfter < 0.9 * scoreBefore,msg); } - - System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf - + ", outputActivation=" + outputActivation + ", doLearningFirst=" - + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); -// for (int k = 0; k < net.getNumLayers(); k++) -// System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams()); - - //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc - //i.e., runningMean = decay * runningMean + (1-decay) * batchMean - //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" + System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); + // for (int k = 0; k < net.getNumLayers(); k++) + // System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams()); + // Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc + // i.e., runningMean = decay * runningMean + (1-decay) * batchMean + // However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev")); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{input}) - .labels(new INDArray[]{labels}).excludeParams(excludeParams)); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { input }).labels(new INDArray[] { labels }).excludeParams(excludeParams)); assertTrue(gradOK); TestUtils.testModelSerialization(net); } @@ -596,5 +423,4 @@ public class BNGradientCheckTest extends BaseDL4JTest { } } } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index 6151c4099..f85a426d2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.gradientcheck; import lombok.extern.slf4j.Slf4j; @@ -35,7 +34,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.Convolution1DUtils; import org.deeplearning4j.util.ConvolutionUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -44,18 +43,24 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.io.File; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class CNN1DGradientCheckTest extends BaseDL4JTest { +@DisplayName("Cnn 1 D Gradient Check Test") +class CNN1DGradientCheckTest extends BaseDL4JTest { + private static final boolean PRINT_RESULTS = true; + private static final boolean RETURN_ON_FIRST_FAILURE = false; + private static final double DEFAULT_EPS = 1e-6; + private static final double DEFAULT_MAX_REL_ERROR = 1e-3; + private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; static { @@ -68,148 +73,91 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnn1DWithLocallyConnected1D() { + @DisplayName("Test Cnn 1 D With Locally Connected 1 D") + void testCnn1DWithLocallyConnected1D() { Nd4j.getRandom().setSeed(1337); - - int[] minibatchSizes = {2, 3}; + int[] minibatchSizes = { 2, 3 }; int length = 7; int convNIn = 2; int convNOut1 = 3; int convNOut2 = 4; int finalNOut = 4; - - int[] kernels = {1}; + int[] kernels = { 1 }; int stride = 1; int padding = 0; - - Activation[] activations = {Activation.SIGMOID}; - + Activation[] activations = { Activation.SIGMOID }; for (Activation afn : activations) { for (int minibatchSize : minibatchSizes) { for (int kernel : kernels) { - INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length}); + INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length }); INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length); for (int i = 0; i < minibatchSize; i++) { for (int j = 0; j < length; j++) { - labels.putScalar(new int[]{i, i % finalNOut, j}, 1.0); + labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0); } } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() - .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) - .stride(stride).padding(padding).nIn(convNIn).nOut(convNOut1) - .rnnDataFormat(RNNFormat.NCW) - .build()) - .layer(new LocallyConnected1D.Builder().activation(afn).kernelSize(kernel) - .stride(stride).padding(padding).nIn(convNOut1).nOut(convNOut2).hasBias(false) - .build()) - .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list().layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nIn(convNIn).nOut(convNOut1).rnnDataFormat(RNNFormat.NCW).build()).layer(new LocallyConnected1D.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nIn(convNOut1).nOut(convNOut2).hasBias(false).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "Minibatch=" + minibatchSize + ", activationFn=" - + afn + ", kernel = " + kernel; - + String msg = "Minibatch=" + minibatchSize + ", 
activationFn=" + afn + ", kernel = " + kernel; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } - } } } - @Test - public void testCnn1DWithCropping1D() { + @DisplayName("Test Cnn 1 D With Cropping 1 D") + void testCnn1DWithCropping1D() { Nd4j.getRandom().setSeed(1337); - - int[] minibatchSizes = {1, 3}; + int[] minibatchSizes = { 1, 3 }; int length = 7; int convNIn = 2; int convNOut1 = 3; int convNOut2 = 4; int finalNOut = 4; - - - int[] kernels = {1, 2, 4}; + int[] kernels = { 1, 2, 4 }; int stride = 1; - int padding = 0; int cropping = 1; int croppedLength = length - 2 * cropping; - - Activation[] activations = {Activation.SIGMOID}; - SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, - SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; - + Activation[] activations = { Activation.SIGMOID }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM }; for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { for (int kernel : kernels) { - INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length}); + INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length }); INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, croppedLength); for (int i = 0; i < minibatchSize; i++) { for (int j = 0; j < croppedLength; j++) { - labels.putScalar(new int[]{i, i % finalNOut, j}, 1.0); + labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0); } } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() - .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) - .stride(stride).padding(padding).nOut(convNOut1) - .build()) - .layer(new Cropping1D.Builder(cropping).build()) - .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) - .stride(stride).padding(padding).nOut(convNOut2) - .build()) - .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list().layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut1).build()).layer(new Cropping1D.Builder(cropping).build()).layer(new 
Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut2).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn + ", kernel = " + kernel; - + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } @@ -218,82 +166,50 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } } - @Test - public void testCnn1DWithZeroPadding1D() { + @DisplayName("Test Cnn 1 D With Zero Padding 1 D") + void testCnn1DWithZeroPadding1D() { Nd4j.getRandom().setSeed(1337); - - int[] minibatchSizes = {1, 3}; + int[] minibatchSizes = { 1, 3 }; int length = 7; int convNIn = 2; int convNOut1 = 3; int convNOut2 = 4; int finalNOut = 4; - - - int[] kernels = {1, 2, 4}; + int[] kernels = { 1, 2, 4 }; int stride = 1; int pnorm = 2; - int padding = 0; int zeroPadding = 2; int paddedLength = length + 2 * zeroPadding; - - Activation[] activations = {Activation.SIGMOID}; - SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, - SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; - + Activation[] activations = { Activation.SIGMOID }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM }; for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { for (int kernel : kernels) { - INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length}); + INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length }); INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, paddedLength); for (int i = 0; i < minibatchSize; i++) { for (int j = 0; j < paddedLength; j++) { - labels.putScalar(new int[]{i, i % finalNOut, j}, 1.0); + labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0); } } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() - .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) - .stride(stride).padding(padding).nOut(convNOut1) - 
.build()) - .layer(new ZeroPadding1DLayer.Builder(zeroPadding).build()) - .layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) - .stride(stride).padding(padding).nOut(convNOut2) - .build()) - .layer(new ZeroPadding1DLayer.Builder(0).build()) - .layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel) - .stride(stride).padding(padding).pnorm(pnorm).build()) - .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list().layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut1).build()).layer(new ZeroPadding1DLayer.Builder(zeroPadding).build()).layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut2).build()).layer(new ZeroPadding1DLayer.Builder(0).build()).layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel).stride(stride).padding(padding).pnorm(pnorm).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn + ", kernel = " + kernel; - + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); TestUtils.testModelSerialization(net); } } @@ -301,76 +217,48 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } } - @Test - public void testCnn1DWithSubsampling1D() { + @DisplayName("Test Cnn 1 D With Subsampling 1 D") + void testCnn1DWithSubsampling1D() { Nd4j.getRandom().setSeed(12345); - - int[] minibatchSizes = {1, 3}; + int[] minibatchSizes = { 1, 3 }; int length = 7; int convNIn = 2; int convNOut1 = 3; int convNOut2 = 4; int finalNOut = 4; - - int[] kernels = {1, 2, 4}; + int[] kernels = { 1, 2, 4 }; int stride = 1; int padding = 0; int pnorm = 2; - - Activation[] activations = {Activation.SIGMOID, Activation.TANH}; - SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, - SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; - + Activation[] activations = { Activation.SIGMOID, 
Activation.TANH }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM }; for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { for (int kernel : kernels) { - INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length}); + INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length }); INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length); for (int i = 0; i < minibatchSize; i++) { for (int j = 0; j < length; j++) { - labels.putScalar(new int[]{i, i % finalNOut, j}, 1.0); + labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0); } } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() - .layer(0, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) - .stride(stride).padding(padding).nOut(convNOut1) - .build()) - .layer(1, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel) - .stride(stride).padding(padding).nOut(convNOut2) - .build()) - .layer(2, new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel) - .stride(stride).padding(padding).pnorm(pnorm).build()) - .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list().layer(0, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut1).build()).layer(1, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut2).build()).layer(2, new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel).stride(stride).padding(padding).pnorm(pnorm).build()).layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn + ", kernel = " + kernel; - + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, 
labels); - - assertTrue(msg, gradOK); TestUtils.testModelSerialization(net); } } @@ -379,66 +267,34 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnn1dWithMasking(){ + @DisplayName("Test Cnn 1 d With Masking") + void testCnn1dWithMasking() { int length = 12; int convNIn = 2; int convNOut1 = 3; int convNOut2 = 4; int finalNOut = 3; - int pnorm = 2; - - SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[] {SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG}; - + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG }; for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { - for(ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Same, ConvolutionMode.Truncate}) { - for( int stride : new int[]{1, 2}){ + for (ConvolutionMode cm : new ConvolutionMode[] { ConvolutionMode.Same, ConvolutionMode.Truncate }) { + for (int stride : new int[] { 1, 2 }) { String s = cm + ", stride=" + stride + ", pooling=" + poolingType; log.info("Starting test: " + s); Nd4j.getRandom().setSeed(12345); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .activation(Activation.TANH) - .dist(new NormalDistribution(0, 1)).convolutionMode(cm) - .seed(12345) - .list() - .layer(new Convolution1DLayer.Builder().kernelSize(2) - .rnnDataFormat(RNNFormat.NCW) - .stride(stride).nIn(convNIn).nOut(convNOut1) - .build()) - .layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(2) - .stride(stride).pnorm(pnorm).build()) - .layer(new Convolution1DLayer.Builder().kernelSize(2) - .rnnDataFormat(RNNFormat.NCW) - .stride(stride).nIn(convNOut1).nOut(convNOut2) - .build()) - .layer(new GlobalPoolingLayer(PoolingType.AVG)) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).activation(Activation.TANH).dist(new NormalDistribution(0, 1)).convolutionMode(cm).seed(12345).list().layer(new Convolution1DLayer.Builder().kernelSize(2).rnnDataFormat(RNNFormat.NCW).stride(stride).nIn(convNIn).nOut(convNOut1).build()).layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(2).stride(stride).pnorm(pnorm).build()).layer(new Convolution1DLayer.Builder().kernelSize(2).rnnDataFormat(RNNFormat.NCW).stride(stride).nIn(convNOut1).nOut(convNOut2).build()).layer(new GlobalPoolingLayer(PoolingType.AVG)).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - INDArray f = Nd4j.rand(new int[]{2, convNIn, length}); + INDArray f = Nd4j.rand(new int[] { 2, convNIn, length }); INDArray fm = Nd4j.create(2, length); fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1); - fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0,6)).assign(1); - + fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, 6)).assign(1); INDArray label = TestUtils.randomOneHot(2, finalNOut); - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) - .labels(label).inputMask(fm)); - - assertTrue(s, gradOK); + boolean gradOK 
= GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm)); + assertTrue(gradOK,s); TestUtils.testModelSerialization(net); - - //TODO also check that masked step values don't impact forward pass, score or gradients - - DataSet ds = new DataSet(f,label,fm,null); + // TODO also check that masked step values don't impact forward pass, score or gradients + DataSet ds = new DataSet(f, label, fm, null); double scoreBefore = net.score(ds); net.setInput(f); net.setLabels(label); @@ -453,7 +309,6 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { net.setLayerMaskArrays(fm, null); net.computeGradientAndScore(); INDArray gradAfter = net.getFlattenedGradients().dup(); - assertEquals(scoreBefore, scoreAfter, 1e-6); assertEquals(gradBefore, gradAfter); } @@ -462,18 +317,18 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnn1Causal() throws Exception { + @DisplayName("Test Cnn 1 Causal") + void testCnn1Causal() throws Exception { int convNIn = 2; int convNOut1 = 3; int convNOut2 = 4; int finalNOut = 3; - - int[] lengths = {11, 12, 13, 9, 10, 11}; - int[] kernels = {2, 3, 2, 4, 2, 3}; - int[] dilations = {1, 1, 2, 1, 2, 1}; - int[] strides = {1, 2, 1, 2, 1, 1}; - boolean[] masks = {false, true, false, true, false, true}; - boolean[] hasB = {true, false, true, false, true, true}; + int[] lengths = { 11, 12, 13, 9, 10, 11 }; + int[] kernels = { 2, 3, 2, 4, 2, 3 }; + int[] dilations = { 1, 1, 2, 1, 2, 1 }; + int[] strides = { 1, 2, 1, 2, 1, 1 }; + boolean[] masks = { false, true, false, true, false, true }; + boolean[] hasB = { true, false, true, false, true, true }; for (int i = 0; i < lengths.length; i++) { int length = lengths[i]; int k = kernels[i]; @@ -481,36 +336,13 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { int st = strides[i]; boolean mask = masks[i]; boolean hasBias = hasB[i]; - //TODO has bias + // TODO has bias String s = "k=" + k + ", s=" + st + " d=" + d + ", seqLen=" + length; log.info("Starting test: " + s); Nd4j.getRandom().setSeed(12345); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .activation(Activation.TANH) - .weightInit(new NormalDistribution(0, 1)) - .seed(12345) - .list() - .layer(new Convolution1DLayer.Builder().kernelSize(k) - .dilation(d) - .hasBias(hasBias) - .convolutionMode(ConvolutionMode.Causal) - .stride(st).nOut(convNOut1) - .build()) - .layer(new Convolution1DLayer.Builder().kernelSize(k) - .dilation(d) - .convolutionMode(ConvolutionMode.Causal) - .stride(st).nOut(convNOut2) - .build()) - .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).activation(Activation.TANH).weightInit(new NormalDistribution(0, 1)).seed(12345).list().layer(new Convolution1DLayer.Builder().kernelSize(k).dilation(d).hasBias(hasBias).convolutionMode(ConvolutionMode.Causal).stride(st).nOut(convNOut1).build()).layer(new Convolution1DLayer.Builder().kernelSize(k).dilation(d).convolutionMode(ConvolutionMode.Causal).stride(st).nOut(convNOut2).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length, 
RNNFormat.NCW)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length); INDArray fm = null; if (mask) { @@ -518,16 +350,11 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1); fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length - 2)).assign(1); } - long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d); long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d); - - INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2); - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) - .labels(label).inputMask(fm)); - - assertTrue(s, gradOK); + INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int) outSize2); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm)); + assertTrue(gradOK,s); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java index b3649f97f..122f3ff86 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.gradientcheck; import lombok.extern.java.Log; @@ -33,7 +32,7 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D; import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; @@ -41,18 +40,24 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Log -public class CNN3DGradientCheckTest extends BaseDL4JTest { +@DisplayName("Cnn 3 D Gradient Check Test") +class CNN3DGradientCheckTest extends BaseDL4JTest { + private static final boolean PRINT_RESULTS = true; + private static final boolean RETURN_ON_FIRST_FAILURE = false; + private static final double DEFAULT_EPS = 1e-6; + private static final double DEFAULT_MAX_REL_ERROR = 1e-3; + private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; static { @@ -65,30 +70,23 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnn3DPlain() { + @DisplayName("Test Cnn 3 D Plain") + void testCnn3DPlain() { Nd4j.getRandom().setSeed(1337); - // Note: we checked 
this with a variety of parameters, but it takes a lot of time. - int[] depths = {6}; - int[] heights = {6}; - int[] widths = {6}; - - - int[] minibatchSizes = {3}; + int[] depths = { 6 }; + int[] heights = { 6 }; + int[] widths = { 6 }; + int[] minibatchSizes = { 3 }; int convNIn = 2; int convNOut1 = 3; int convNOut2 = 4; int denseNOut = 5; int finalNOut = 42; - - - int[][] kernels = {{2, 2, 2}}; - int[][] strides = {{1, 1, 1}}; - - Activation[] activations = {Activation.SIGMOID}; - - ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same}; - + int[][] kernels = { { 2, 2, 2 } }; + int[][] strides = { { 1, 1, 1 } }; + Activation[] activations = { Activation.SIGMOID }; + ConvolutionMode[] modes = { ConvolutionMode.Truncate, ConvolutionMode.Same }; for (Activation afn : activations) { for (int miniBatchSize : minibatchSizes) { for (int depth : depths) { @@ -98,71 +96,34 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { for (int[] kernel : kernels) { for (int[] stride : strides) { for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) { - - int outDepth = mode == ConvolutionMode.Same ? - depth / stride[0] : (depth - kernel[0]) / stride[0] + 1; - int outHeight = mode == ConvolutionMode.Same ? - height / stride[1] : (height - kernel[1]) / stride[1] + 1; - int outWidth = mode == ConvolutionMode.Same ? - width / stride[2] : (width - kernel[2]) / stride[2] + 1; - + int outDepth = mode == ConvolutionMode.Same ? depth / stride[0] : (depth - kernel[0]) / stride[0] + 1; + int outHeight = mode == ConvolutionMode.Same ? height / stride[1] : (height - kernel[1]) / stride[1] + 1; + int outWidth = mode == ConvolutionMode.Same ? width / stride[2] : (width - kernel[2]) / stride[2] + 1; INDArray input; - if(df == Convolution3D.DataFormat.NDHWC){ - input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn}); + if (df == Convolution3D.DataFormat.NDHWC) { + input = Nd4j.rand(new int[] { miniBatchSize, depth, height, width, convNIn }); } else { - input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); + input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width }); } INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int i = 0; i < miniBatchSize; i++) { - labels.putScalar(new int[]{i, i % finalNOut}, 1.0); + labels.putScalar(new int[] { i, i % finalNOut }, 1.0); } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) - .dist(new NormalDistribution(0, 1)) - .list() - .layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel) - .stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false) - .convolutionMode(mode).dataFormat(df) - .build()) - .layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1) - .nIn(convNOut1).nOut(convNOut2).hasBias(false) - .convolutionMode(mode).dataFormat(df) - .build()) - .layer(2, new DenseLayer.Builder().nOut(denseNOut).build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .inputPreProcessor(2, - new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, - convNOut2, df == Convolution3D.DataFormat.NCDHW)) - .setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new 
NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNOut1).nOut(convNOut2).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(2, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, df == Convolution3D.DataFormat.NCDHW)).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(kernel) + ", stride = " - + Arrays.toString(stride) + ", mode = " + mode.toString() - + ", input depth " + depth + ", input height " + height - + ", input width " + width; - + String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", stride = " + Arrays.toString(stride) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width; if (PRINT_RESULTS) { log.info(msg); -// for (int j = 0; j < net.getnLayers(); j++) { -// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); -// } + // for (int j = 0; j < net.getnLayers(); j++) { + // log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // } } - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) - .labels(labels).subset(true).maxPerParam(128)); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(128)); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } @@ -176,186 +137,98 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnn3DZeroPadding() { + @DisplayName("Test Cnn 3 D Zero Padding") + void testCnn3DZeroPadding() { Nd4j.getRandom().setSeed(42); - int depth = 4; int height = 4; int width = 4; - - - int[] minibatchSizes = {3}; + int[] minibatchSizes = { 3 }; int convNIn = 2; int convNOut1 = 3; int convNOut2 = 4; int denseNOut = 5; int finalNOut = 42; - - - int[] kernel = {2, 2, 2}; - int[] zeroPadding = {1, 1, 2, 2, 3, 3}; - - Activation[] activations = {Activation.SIGMOID}; - - ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same}; - + int[] kernel = { 2, 2, 2 }; + int[] zeroPadding = { 1, 1, 2, 2, 3, 3 }; + Activation[] activations = { Activation.SIGMOID }; + ConvolutionMode[] modes = { ConvolutionMode.Truncate, ConvolutionMode.Same }; for (Activation afn : activations) { for (int miniBatchSize : minibatchSizes) { for (ConvolutionMode mode : modes) { - - int outDepth = mode == ConvolutionMode.Same ? - depth : (depth - kernel[0]) + 1; - int outHeight = mode == ConvolutionMode.Same ? - height : (height - kernel[1]) + 1; - int outWidth = mode == ConvolutionMode.Same ? 
- width : (width - kernel[2]) + 1; - + int outDepth = mode == ConvolutionMode.Same ? depth : (depth - kernel[0]) + 1; + int outHeight = mode == ConvolutionMode.Same ? height : (height - kernel[1]) + 1; + int outWidth = mode == ConvolutionMode.Same ? width : (width - kernel[2]) + 1; outDepth += zeroPadding[0] + zeroPadding[1]; outHeight += zeroPadding[2] + zeroPadding[3]; outWidth += zeroPadding[4] + zeroPadding[5]; - - INDArray input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); + INDArray input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width }); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int i = 0; i < miniBatchSize; i++) { - labels.putScalar(new int[]{i, i % finalNOut}, 1.0); + labels.putScalar(new int[] { i, i % finalNOut }, 1.0); } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) - .dist(new NormalDistribution(0, 1)) - .list() - .layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel) - .nIn(convNIn).nOut(convNOut1).hasBias(false) - .convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW) - .build()) - .layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1) - .nIn(convNOut1).nOut(convNOut2).hasBias(false) - .convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW) - .build()) - .layer(2, new ZeroPadding3DLayer.Builder(zeroPadding).build()) - .layer(3, new DenseLayer.Builder().nOut(denseNOut).build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .inputPreProcessor(3, - new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, - convNOut2, true)) - .setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).nIn(convNIn).nOut(convNOut1).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNOut1).nOut(convNOut2).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(2, new ZeroPadding3DLayer.Builder(zeroPadding).build()).layer(3, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(3, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, true)).setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() - + ", input depth " + depth + ", input height " + height - + ", input width " + width; - + String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width; if 
(PRINT_RESULTS) { log.info(msg); -// for (int j = 0; j < net.getnLayers(); j++) { -// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); -// } + // for (int j = 0; j < net.getnLayers(); j++) { + // log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // } } - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) - .labels(labels).subset(true).maxPerParam(512)); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(512)); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } - } } } - @Test - public void testCnn3DPooling() { + @DisplayName("Test Cnn 3 D Pooling") + void testCnn3DPooling() { Nd4j.getRandom().setSeed(42); - int depth = 4; int height = 4; int width = 4; - - - int[] minibatchSizes = {3}; + int[] minibatchSizes = { 3 }; int convNIn = 2; int convNOut = 4; int denseNOut = 5; int finalNOut = 42; - - int[] kernel = {2, 2, 2}; - - Activation[] activations = {Activation.SIGMOID}; - - Subsampling3DLayer.PoolingType[] poolModes = {Subsampling3DLayer.PoolingType.AVG}; - - ConvolutionMode[] modes = {ConvolutionMode.Truncate}; - + int[] kernel = { 2, 2, 2 }; + Activation[] activations = { Activation.SIGMOID }; + Subsampling3DLayer.PoolingType[] poolModes = { Subsampling3DLayer.PoolingType.AVG }; + ConvolutionMode[] modes = { ConvolutionMode.Truncate }; for (Activation afn : activations) { for (int miniBatchSize : minibatchSizes) { for (Subsampling3DLayer.PoolingType pool : poolModes) { for (ConvolutionMode mode : modes) { for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) { - int outDepth = depth / kernel[0]; int outHeight = height / kernel[1]; int outWidth = width / kernel[2]; - - INDArray input = Nd4j.rand( - df == Convolution3D.DataFormat.NCDHW ? new int[]{miniBatchSize, convNIn, depth, height, width} - : new int[]{miniBatchSize, depth, height, width, convNIn}); + INDArray input = Nd4j.rand(df == Convolution3D.DataFormat.NCDHW ? 
new int[] { miniBatchSize, convNIn, depth, height, width } : new int[] { miniBatchSize, depth, height, width, convNIn }); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int i = 0; i < miniBatchSize; i++) { - labels.putScalar(new int[]{i, i % finalNOut}, 1.0); + labels.putScalar(new int[] { i, i % finalNOut }, 1.0); } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .weightInit(WeightInit.XAVIER) - .dist(new NormalDistribution(0, 1)) - .list() - .layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1) - .nIn(convNIn).nOut(convNOut).hasBias(false) - .convolutionMode(mode).dataFormat(df) - .build()) - .layer(1, new Subsampling3DLayer.Builder(kernel) - .poolingType(pool).convolutionMode(mode).dataFormat(df).build()) - .layer(2, new DenseLayer.Builder().nOut(denseNOut).build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .inputPreProcessor(2, - new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,convNOut, df)) - .setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.XAVIER).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNIn).nOut(convNOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Subsampling3DLayer.Builder(kernel).poolingType(pool).convolutionMode(mode).dataFormat(df).build()).layer(2, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut, df)).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() - + ", input depth " + depth + ", input height " + height - + ", input width " + width + ", dataFormat=" + df; - + String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width + ", dataFormat=" + df; if (PRINT_RESULTS) { log.info(msg); } - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, - DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, - RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } @@ -365,87 +238,47 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnn3DUpsampling() { + @DisplayName("Test Cnn 3 D Upsampling") + void testCnn3DUpsampling() { Nd4j.getRandom().setSeed(42); - int depth = 2; int height = 2; int width = 2; - - - int[] minibatchSizes = 
{3}; + int[] minibatchSizes = { 3 }; int convNIn = 2; int convNOut = 4; int denseNOut = 5; int finalNOut = 42; - - - int[] upsamplingSize = {2, 2, 2}; - - Activation[] activations = {Activation.SIGMOID}; - - - ConvolutionMode[] modes = {ConvolutionMode.Truncate}; - + int[] upsamplingSize = { 2, 2, 2 }; + Activation[] activations = { Activation.SIGMOID }; + ConvolutionMode[] modes = { ConvolutionMode.Truncate }; for (Activation afn : activations) { for (int miniBatchSize : minibatchSizes) { for (ConvolutionMode mode : modes) { - for(Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) { - + for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) { int outDepth = depth * upsamplingSize[0]; int outHeight = height * upsamplingSize[1]; int outWidth = width * upsamplingSize[2]; - INDArray input = df == Convolution3D.DataFormat.NCDHW ? Nd4j.rand(miniBatchSize, convNIn, depth, height, width) : Nd4j.rand(miniBatchSize, depth, height, width, convNIn); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int i = 0; i < miniBatchSize; i++) { - labels.putScalar(new int[]{i, i % finalNOut}, 1.0); + labels.putScalar(new int[] { i, i % finalNOut }, 1.0); } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) - .dist(new NormalDistribution(0, 1)) - .seed(12345) - .list() - .layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1) - .nIn(convNIn).nOut(convNOut).hasBias(false) - .convolutionMode(mode).dataFormat(df) - .build()) - .layer(1, new Upsampling3D.Builder(upsamplingSize[0]).dataFormat(df).build()) - .layer(2, new DenseLayer.Builder().nOut(denseNOut).build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .inputPreProcessor(2, - new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, - convNOut, true)) - .setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).seed(12345).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNIn).nOut(convNOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Upsampling3D.Builder(upsamplingSize[0]).dataFormat(df).build()).layer(2, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut, true)).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(upsamplingSize) + ", mode = " + mode.toString() - + ", input depth " + depth + ", input height " + height - + ", input width " + width; - + String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(upsamplingSize) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + 
width; if (PRINT_RESULTS) { log.info(msg); -// for (int j = 0; j < net.getnLayers(); j++) { -// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); -// } + // for (int j = 0; j < net.getnLayers(); j++) { + // log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // } } - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, - DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, - RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } @@ -454,126 +287,74 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnn3DCropping() { + @DisplayName("Test Cnn 3 D Cropping") + void testCnn3DCropping() { Nd4j.getRandom().setSeed(42); - int depth = 6; int height = 6; int width = 6; - - - int[] minibatchSizes = {3}; + int[] minibatchSizes = { 3 }; int convNIn = 2; int convNOut1 = 3; int convNOut2 = 4; int denseNOut = 5; int finalNOut = 8; - - - int[] kernel = {1, 1, 1}; - int[] cropping = {0, 0, 1, 1, 2, 2}; - - Activation[] activations = {Activation.SIGMOID}; - - ConvolutionMode[] modes = {ConvolutionMode.Same}; - + int[] kernel = { 1, 1, 1 }; + int[] cropping = { 0, 0, 1, 1, 2, 2 }; + Activation[] activations = { Activation.SIGMOID }; + ConvolutionMode[] modes = { ConvolutionMode.Same }; for (Activation afn : activations) { for (int miniBatchSize : minibatchSizes) { for (ConvolutionMode mode : modes) { - - int outDepth = mode == ConvolutionMode.Same ? - depth : (depth - kernel[0]) + 1; - int outHeight = mode == ConvolutionMode.Same ? - height : (height - kernel[1]) + 1; - int outWidth = mode == ConvolutionMode.Same ? - width : (width - kernel[2]) + 1; - + int outDepth = mode == ConvolutionMode.Same ? depth : (depth - kernel[0]) + 1; + int outHeight = mode == ConvolutionMode.Same ? height : (height - kernel[1]) + 1; + int outWidth = mode == ConvolutionMode.Same ? 
width : (width - kernel[2]) + 1; outDepth -= cropping[0] + cropping[1]; outHeight -= cropping[2] + cropping[3]; outWidth -= cropping[4] + cropping[5]; - - INDArray input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); + INDArray input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width }); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int i = 0; i < miniBatchSize; i++) { - labels.putScalar(new int[]{i, i % finalNOut}, 1.0); + labels.putScalar(new int[] { i, i % finalNOut }, 1.0); } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) - .dist(new NormalDistribution(0, 1)) - .list() - .layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel) - .nIn(convNIn).nOut(convNOut1).hasBias(false) - .convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW) - .build()) - .layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1) - .nIn(convNOut1).nOut(convNOut2).hasBias(false) - .convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW) - .build()) - .layer(2, new Cropping3D.Builder(cropping).build()) - .layer(3, new DenseLayer.Builder().nOut(denseNOut).build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .inputPreProcessor(3, - new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, - convNOut2, true)) - .setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).nIn(convNIn).nOut(convNOut1).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNOut1).nOut(convNOut2).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(2, new Cropping3D.Builder(cropping).build()).layer(3, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(3, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, true)).setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() - + ", input depth " + depth + ", input height " + height - + ", input width " + width; - + String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width; if (PRINT_RESULTS) { log.info(msg); -// for (int j = 0; j < net.getnLayers(); j++) { -// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); -// } + // for (int j = 0; j < net.getnLayers(); j++) { + // log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // } 
} - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, - DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, - RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } - } } } @Test - public void testDeconv3d() { + @DisplayName("Test Deconv 3 d") + void testDeconv3d() { Nd4j.getRandom().setSeed(12345); // Note: we checked this with a variety of parameters, but it takes a lot of time. - int[] depths = {8, 8, 9}; - int[] heights = {8, 9, 9}; - int[] widths = {8, 8, 9}; - - - int[][] kernels = {{2, 2, 2}, {3, 3, 3}, {2, 3, 2}}; - int[][] strides = {{1, 1, 1}, {1, 1, 1}, {2, 2, 2}}; - - Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.IDENTITY}; - - ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same, ConvolutionMode.Same}; - int[] mbs = {1, 3, 2}; - Convolution3D.DataFormat[] dataFormats = new Convolution3D.DataFormat[]{Convolution3D.DataFormat.NCDHW, Convolution3D.DataFormat.NDHWC, Convolution3D.DataFormat.NCDHW}; - + int[] depths = { 8, 8, 9 }; + int[] heights = { 8, 9, 9 }; + int[] widths = { 8, 8, 9 }; + int[][] kernels = { { 2, 2, 2 }, { 3, 3, 3 }, { 2, 3, 2 } }; + int[][] strides = { { 1, 1, 1 }, { 1, 1, 1 }, { 2, 2, 2 } }; + Activation[] activations = { Activation.SIGMOID, Activation.TANH, Activation.IDENTITY }; + ConvolutionMode[] modes = { ConvolutionMode.Truncate, ConvolutionMode.Same, ConvolutionMode.Same }; + int[] mbs = { 1, 3, 2 }; + Convolution3D.DataFormat[] dataFormats = new Convolution3D.DataFormat[] { Convolution3D.DataFormat.NCDHW, Convolution3D.DataFormat.NDHWC, Convolution3D.DataFormat.NCDHW }; int convNIn = 2; int finalNOut = 2; - int[] deconvOut = {2, 3, 4}; - + int[] deconvOut = { 2, 3, 4 }; for (int i = 0; i < activations.length; i++) { Activation afn = activations[i]; int miniBatchSize = mbs[i]; @@ -585,57 +366,28 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { int[] stride = strides[i]; Convolution3D.DataFormat df = dataFormats[i]; int dOut = deconvOut[i]; - INDArray input; if (df == Convolution3D.DataFormat.NDHWC) { - input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn}); + input = Nd4j.rand(new int[] { miniBatchSize, depth, height, width, convNIn }); } else { - input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); + input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width }); } INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int j = 0; j < miniBatchSize; j++) { - labels.putScalar(new int[]{j, j % finalNOut}, 1.0); + labels.putScalar(new int[] { j, j % finalNOut }, 1.0); } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .weightInit(new NormalDistribution(0, 0.1)) - .list() - .layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel) - .stride(stride).nIn(convNIn).nOut(dOut).hasBias(false) - .convolutionMode(mode).dataFormat(df) - .build()) - .layer(1, new Deconvolution3D.Builder().activation(afn).kernelSize(kernel) - .stride(stride).nOut(dOut).hasBias(false) - .convolutionMode(mode).dataFormat(df) - .build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - 
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(new NormalDistribution(0, 0.1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).stride(stride).nIn(convNIn).nOut(dOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Deconvolution3D.Builder().activation(afn).kernelSize(kernel).stride(stride).nOut(dOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); String json = conf.toJson(); MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); assertEquals(conf, c2); - MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(kernel) + ", stride = " - + Arrays.toString(stride) + ", mode = " + mode.toString() - + ", input depth " + depth + ", input height " + height - + ", input width " + width; - + String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", stride = " + Arrays.toString(stride) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width; if (PRINT_RESULTS) { log.info(msg); } - - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) - .labels(labels).subset(true).maxPerParam(64)); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(64)); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index c0f333690..475c45142 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; @@ -36,8 +35,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -47,19 +46,25 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.Arrays; - import static org.deeplearning4j.nn.conf.ConvolutionMode.Same; import static 
org.deeplearning4j.nn.conf.ConvolutionMode.Truncate; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @RunWith(Parameterized.class) -public class CNNGradientCheckTest extends BaseDL4JTest { +@DisplayName("Cnn Gradient Check Test") +class CNNGradientCheckTest extends BaseDL4JTest { + private static final boolean PRINT_RESULTS = true; + private static final boolean RETURN_ON_FIRST_FAILURE = false; + private static final double DEFAULT_EPS = 1e-6; + private static final double DEFAULT_MAX_REL_ERROR = 1e-3; + private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; static { @@ -68,12 +73,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest { private CNN2DFormat format; - public CNNGradientCheckTest(CNN2DFormat format){ + public CNNGradientCheckTest(CNN2DFormat format) { this.format = format; } @Parameterized.Parameters(name = "{0}") - public static Object[] params(){ + public static Object[] params() { return CNN2DFormat.values(); } @@ -83,75 +88,55 @@ public class CNNGradientCheckTest extends BaseDL4JTest { } @Test - public void testGradientCNNMLN() { - if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format... + @DisplayName("Test Gradient CNNMLN") + void testGradientCNNMLN() { + if (// Only test NCHW due to flat input format... + this.format != CNN2DFormat.NCHW) return; - - //Parameterized test, testing combinations of: + // Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (c) Loss function (with specified output activations) - Activation[] activFns = {Activation.SIGMOID, Activation.TANH}; - boolean[] characteristic = {false, true}; //If true: run some backprop steps first - - LossFunctions.LossFunction[] lossFunctions = - {LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE}; - Activation[] outputActivations = {Activation.SOFTMAX, Activation.TANH}; //i.e., lossFunctions[i] used with outputActivations[i] here - + Activation[] activFns = { Activation.SIGMOID, Activation.TANH }; + // If true: run some backprop steps first + boolean[] characteristic = { false, true }; + LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE }; + // i.e., lossFunctions[i] used with outputActivations[i] here + Activation[] outputActivations = { Activation.SOFTMAX, Activation.TANH }; DataSet ds = new IrisDataSetIterator(150, 150).next(); ds.normalizeZeroMeanZeroUnitVariance(); INDArray input = ds.getFeatures(); INDArray labels = ds.getLabels(); - for (Activation afn : activFns) { for (boolean doLearningFirst : characteristic) { for (int i = 0; i < lossFunctions.length; i++) { LossFunctions.LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) - .weightInit(WeightInit.XAVIER).seed(12345L).list() - .layer(0, new ConvolutionLayer.Builder(1, 1).nOut(6).activation(afn).build()) - .layer(1, new OutputLayer.Builder(lf).activation(outputActivation).nOut(3).build()) - .setInputType(InputType.convolutionalFlat(1, 4, 1)); - + MultiLayerConfiguration.Builder builder = new 
NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()).weightInit(WeightInit.XAVIER).seed(12345L).list().layer(0, new ConvolutionLayer.Builder(1, 1).nOut(6).activation(afn).build()).layer(1, new OutputLayer.Builder(lf).activation(outputActivation).nOut(3).build()).setInputType(InputType.convolutionalFlat(1, 4, 1)); MultiLayerConfiguration conf = builder.build(); - MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); String name = new Object() { }.getClass().getEnclosingMethod().getName(); - if (doLearningFirst) { - //Run a number of iterations of learning + // Run a number of iterations of learning mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); double scoreBefore = mln.score(); - for (int j = 0; j < 10; j++) - mln.fit(ds); + for (int j = 0; j < 10; j++) mln.fit(ds); mln.computeGradientAndScore(); double scoreAfter = mln.score(); - //Can't test in 'characteristic mode of operation' if not learning - String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" - + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation - + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore - + ", scoreAfter=" + scoreAfter + ")"; - assertTrue(msg, scoreAfter < 0.9 * scoreBefore); + // Can't test in 'characteristic mode of operation' if not learning + String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")"; + assertTrue(scoreAfter < 0.9 * scoreBefore,msg); } - if (PRINT_RESULTS) { - System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" - + outputActivation + ", doLearningFirst=" + doLearningFirst); -// for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); + System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); + // for (int j = 0; j < mln.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); } - - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - + boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(gradOK); TestUtils.testModelSerialization(mln); } @@ -159,364 +144,219 @@ public class CNNGradientCheckTest extends BaseDL4JTest { } } - @Test - public void testGradientCNNL1L2MLN() { - if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format... + @DisplayName("Test Gradient CNNL 1 L 2 MLN") + void testGradientCNNL1L2MLN() { + if (// Only test NCHW due to flat input format... 
+ this.format != CNN2DFormat.NCHW) return; - - //Parameterized test, testing combinations of: + // Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (c) Loss function (with specified output activations) - DataSet ds = new IrisDataSetIterator(150, 150).next(); ds.normalizeZeroMeanZeroUnitVariance(); INDArray input = ds.getFeatures(); INDArray labels = ds.getLabels(); - - //use l2vals[i] with l1vals[i] - double[] l2vals = {0.4, 0.0, 0.4, 0.4}; - double[] l1vals = {0.0, 0.0, 0.5, 0.0}; - double[] biasL2 = {0.0, 0.0, 0.0, 0.2}; - double[] biasL1 = {0.0, 0.0, 0.6, 0.0}; - Activation[] activFns = {Activation.SIGMOID, Activation.TANH, Activation.ELU, Activation.SOFTPLUS}; - boolean[] characteristic = {false, true, false, true}; //If true: run some backprop steps first - - LossFunctions.LossFunction[] lossFunctions = - {LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE, LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE}; - Activation[] outputActivations = {Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.IDENTITY}; //i.e., lossFunctions[i] used with outputActivations[i] here - - for( int i=0; i (mb,4,2,2) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(2 * 2 * 4) - .nOut(nOut).build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth).hasBias(false).nOut(1).build()).layer(new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(2 * 2 * 4).nOut(nOut).build()).setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn; - + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } } @Test - public void testCnnWithSpaceToBatch() { + @DisplayName("Test Cnn With Space To Batch") + void testCnnWithSpaceToBatch() { Nd4j.getRandom().setSeed(12345); int nOut = 4; - - int[] minibatchSizes = {2, 4}; + int[] minibatchSizes = { 2, 4 }; int width = 5; int height = 5; int inputDepth = 1; - - int[] kernel = {2, 2}; - int[] blocks = {2, 2}; - - String[] activations = {"sigmoid", "tanh"}; - 
SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, - SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; - + int[] kernel = { 2, 2 }; + int[] blocks = { 2, 2 }; + String[] activations = { "sigmoid", "tanh" }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM }; boolean nchw = format == CNN2DFormat.NCHW; for (String afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { - long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + long[] inShape = nchw ? new long[] { minibatchSize, inputDepth, height, width } : new long[] { minibatchSize, height, width, inputDepth }; INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut); for (int i = 0; i < 4 * minibatchSize; i++) { - labels.putScalar(new int[]{i, i % nOut}, 1.0); + labels.putScalar(new int[] { i, i % nOut }, 1.0); } - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()).weightInit(new NormalDistribution(0, 1)) - .list() - .layer(new ConvolutionLayer.Builder(kernel) - .nIn(inputDepth).nOut(3) - .dataFormat(format) - .build()) - .layer(new SpaceToBatchLayer.Builder(blocks) - .dataFormat(format) - .build()) //trivial space to batch - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX) - .nOut(nOut).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(new NormalDistribution(0, 1)).list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth).nOut(3).dataFormat(format).build()).layer(new SpaceToBatchLayer.Builder(blocks).dataFormat(format).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(nOut).build()).setInputType(InputType.convolutional(height, width, inputDepth, format)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn; - + String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); - - //Also check compgraph: + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); + // Also check compgraph: ComputationGraph cg = net.toComputationGraph(); - gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(cg).inputs(new 
INDArray[]{input}) - .labels(new INDArray[]{labels})); - assertTrue(msg + " - compgraph", gradOK); - + gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(cg).inputs(new INDArray[] { input }).labels(new INDArray[] { labels })); + assertTrue(gradOK,msg + " - compgraph"); TestUtils.testModelSerialization(net); } } } } - @Test - public void testCnnWithUpsampling() { + @DisplayName("Test Cnn With Upsampling") + void testCnnWithUpsampling() { Nd4j.getRandom().setSeed(12345); int nOut = 4; - - int[] minibatchSizes = {1, 3}; + int[] minibatchSizes = { 1, 3 }; int width = 5; int height = 5; int inputDepth = 1; - - int[] kernel = {2, 2}; - int[] stride = {1, 1}; - int[] padding = {0, 0}; + int[] kernel = { 2, 2 }; + int[] stride = { 1, 1 }; + int[] padding = { 0, 0 }; int size = 2; - boolean nchw = format == CNN2DFormat.NCHW; - for (int minibatchSize : minibatchSizes) { - long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + long[] inShape = nchw ? new long[] { minibatchSize, inputDepth, height, width } : new long[] { minibatchSize, height, width, inputDepth }; INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .dist(new NormalDistribution(0, 1)) - .list().layer(new ConvolutionLayer.Builder(kernel, - stride, padding).nIn(inputDepth) - .dataFormat(format) - .nOut(3).build())//output: (5-2+0)/1+1 = 4 - .layer(new Upsampling2D.Builder().size(size).dataFormat(format).build()) //output: 4*2 =8 -> 8x8x3 - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(8 * 8 * 3) - .nOut(4).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).list().layer(new ConvolutionLayer.Builder(kernel, stride, padding).nIn(inputDepth).dataFormat(format).nOut(3).build()).layer(// output: 4*2 =8 -> 8x8x3 + new Upsampling2D.Builder().size(size).dataFormat(format).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(8 * 8 * 3).nOut(4).build()).setInputType(InputType.convolutional(height, width, inputDepth, format)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - String msg = "Upsampling - minibatch=" + minibatchSize; - if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } - @Test - public void testCnnWithSubsampling() { + @DisplayName("Test Cnn With Subsampling") + void testCnnWithSubsampling() { 
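// Note the recurring assertion change in the hunks above: JUnit 4's assertTrue(String message,
// boolean condition) becomes JUnit 5's assertTrue(boolean condition, String message), so the
// failure message moves to the last argument. A minimal sketch of the two forms (names below are
// illustrative):
import static org.junit.jupiter.api.Assertions.assertTrue;

class AssertionArgumentOrderExample {

    void checkGradients(boolean gradOK, String msg) {
        // JUnit 4 (before): org.junit.Assert.assertTrue(msg, gradOK);
        // JUnit 5 (after):
        assertTrue(gradOK, msg);
    }
}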
Nd4j.getRandom().setSeed(12345); int nOut = 4; - - int[] minibatchSizes = {1, 3}; + int[] minibatchSizes = { 1, 3 }; int width = 5; int height = 5; int inputDepth = 1; - - int[] kernel = {2, 2}; - int[] stride = {1, 1}; - int[] padding = {0, 0}; + int[] kernel = { 2, 2 }; + int[] stride = { 1, 1 }; + int[] padding = { 0, 0 }; int pnorm = 2; - - Activation[] activations = {Activation.SIGMOID, Activation.TANH}; - SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, - SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; - + Activation[] activations = { Activation.SIGMOID, Activation.TANH }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM }; boolean nchw = format == CNN2DFormat.NCHW; - for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { - long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + long[] inShape = nchw ? new long[] { minibatchSize, inputDepth, height, width } : new long[] { minibatchSize, height, width, inputDepth }; INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { - labels.putScalar(new int[]{i, i % nOut}, 1.0); + labels.putScalar(new int[] { i, i % nOut }, 1.0); } - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dataType(DataType.DOUBLE) - .dist(new NormalDistribution(0, 1)) - .list().layer(0, - new ConvolutionLayer.Builder(kernel, - stride, padding).nIn(inputDepth) - .dataFormat(format) - .nOut(3).build())//output: (5-2+0)/1+1 = 4 - .layer(1, new SubsamplingLayer.Builder(poolingType) - .dataFormat(format) - .kernelSize(kernel).stride(stride).padding(padding) - .pnorm(pnorm).build()) //output: (4-2+0)/1+1 =3 -> 3x3x3 - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3 * 3 * 3) - .nOut(4).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).dist(new NormalDistribution(0, 1)).list().layer(0, new ConvolutionLayer.Builder(kernel, stride, padding).nIn(inputDepth).dataFormat(format).nOut(3).build()).layer(1, new SubsamplingLayer.Builder(poolingType).dataFormat(format).kernelSize(kernel).stride(stride).padding(padding).pnorm(pnorm).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3 * 3 * 3).nOut(4).build()).setInputType(InputType.convolutional(height, width, inputDepth, format)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn; - + String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + 
net.getLayer(j).numParams()); } - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } @@ -524,68 +364,37 @@ public class CNNGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnnWithSubsamplingV2() { + @DisplayName("Test Cnn With Subsampling V 2") + void testCnnWithSubsamplingV2() { Nd4j.getRandom().setSeed(12345); int nOut = 4; - - int[] minibatchSizes = {1, 3}; + int[] minibatchSizes = { 1, 3 }; int width = 5; int height = 5; int inputDepth = 1; - - int[] kernel = {2, 2}; - int[] stride = {1, 1}; - int[] padding = {0, 0}; + int[] kernel = { 2, 2 }; + int[] stride = { 1, 1 }; + int[] padding = { 0, 0 }; int pNorm = 3; - - Activation[] activations = {Activation.SIGMOID, Activation.TANH}; - SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, - SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; - + Activation[] activations = { Activation.SIGMOID, Activation.TANH }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM }; boolean nchw = format == CNN2DFormat.NCHW; - for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { - long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + long[] inShape = nchw ? 
new long[] { minibatchSize, inputDepth, height, width } : new long[] { minibatchSize, height, width, inputDepth }; INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { - labels.putScalar(new int[]{i, i % nOut}, 1.0); + labels.putScalar(new int[] { i, i % nOut }, 1.0); } - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dataType(DataType.DOUBLE) - .dist(new NormalDistribution(0, 1)) - .list().layer(0, - new ConvolutionLayer.Builder(kernel, - stride, padding).nIn(inputDepth).dataFormat(format) - .nOut(3).build())//output: (5-2+0)/1+1 = 4 - .layer(1, new SubsamplingLayer.Builder(poolingType).dataFormat(format) - .kernelSize(kernel).stride(stride).padding(padding) - .pnorm(pNorm).build()) //output: (4-2+0)/1+1 =3 -> 3x3x3 - .layer(2, new ConvolutionLayer.Builder(kernel, stride, padding).dataFormat(format) - .nIn(3).nOut(2).build()) //Output: (3-2+0)/1+1 = 2 - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(2 * 2 * 2) - .nOut(4).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).dist(new NormalDistribution(0, 1)).list().layer(0, new ConvolutionLayer.Builder(kernel, stride, padding).nIn(inputDepth).dataFormat(format).nOut(3).build()).layer(1, new SubsamplingLayer.Builder(poolingType).dataFormat(format).kernelSize(kernel).stride(stride).padding(padding).pnorm(pNorm).build()).layer(2, new ConvolutionLayer.Builder(kernel, stride, padding).dataFormat(format).nIn(3).nOut(2).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(4).build()).setInputType(InputType.convolutional(height, width, inputDepth, format)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn; + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; System.out.println(msg); - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - - assertTrue(msg, gradOK); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); + assertTrue(gradOK,msg); TestUtils.testModelSerialization(net); } } @@ -593,132 +402,68 @@ public class CNNGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnnLocallyConnected2D() { + @DisplayName("Test Cnn Locally Connected 2 D") + void testCnnLocallyConnected2D() { int nOut = 3; int width = 5; int height = 5; - Nd4j.getRandom().setSeed(12345); - - int[] inputDepths = new int[]{1, 2, 4}; - Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS}; - int[] minibatch = {2, 1, 3}; - + int[] inputDepths = new int[] { 1, 2, 4 }; + Activation[] activations = { Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS }; + int[] minibatch = { 2, 1, 3 }; boolean nchw = format == CNN2DFormat.NCHW; - - for( int i=0; i trying to predict 1 or -1 - Activation.SIGMOID, //kld -> probab so should be between 0 and 1 - Activation.SOFTMAX, //kld + softmax - 
Activation.TANH, //l1 - Activation.SOFTMAX, //l1 + softmax - Activation.TANH, //l2 - Activation.SOFTMAX, //l2 + softmax - Activation.IDENTITY, //mae - Activation.SOFTMAX, //mae + softmax - Activation.IDENTITY, //mape - Activation.SOFTMAX, //mape + softmax - Activation.SOFTMAX, //mcxent - Activation.IDENTITY, //mse - Activation.SOFTMAX, //mse + softmax - Activation.SIGMOID, //msle - requires positive labels/activations due to log - Activation.SOFTMAX, //msle + softmax - Activation.SIGMOID, //nll - Activation.SOFTMAX, //nll + softmax - Activation.SIGMOID, //poisson - requires positive predictions due to log... not sure if this is the best option - Activation.TANH, //squared hinge - Activation.SIGMOID, //f-measure (binary, single sigmoid output) - Activation.SOFTMAX //f-measure (binary, 2-label softmax output) - }; - - int[] nOut = new int[] {1, //xent - 3, //xent - 5, //cosine - 3, //hinge - 3, //kld - 3, //kld + softmax - 3, //l1 - 3, //l1 + softmax - 3, //l2 - 3, //l2 + softmax - 3, //mae - 3, //mae + softmax - 3, //mape - 3, //mape + softmax - 3, //mcxent - 3, //mse - 3, //mse + softmax - 3, //msle - 3, //msle + softmax - 3, //nll - 3, //nll + softmax - 3, //poisson - 3, //squared hinge - 1, //f-measure (binary, single sigmoid output) - 2, //f-measure (binary, 2-label softmax output) - }; - + @DisplayName("Test Json Loss Functions") + void testJsonLossFunctions() { + ILossFunction[] lossFunctions = new ILossFunction[] { new LossBinaryXENT(), new LossBinaryXENT(), new LossCosineProximity(), new LossHinge(), new LossKLD(), new LossKLD(), new LossL1(), new LossL1(), new LossL2(), new LossL2(), new LossMAE(), new LossMAE(), new LossMAPE(), new LossMAPE(), new LossMCXENT(), new LossMSE(), new LossMSE(), new LossMSLE(), new LossMSLE(), new LossNegativeLogLikelihood(), new LossNegativeLogLikelihood(), new LossPoisson(), new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0) }; + Activation[] outputActivationFn = new Activation[] { // xent + Activation.SIGMOID, // xent + Activation.SIGMOID, // cosine + Activation.TANH, // hinge -> trying to predict 1 or -1 + Activation.TANH, // kld -> probab so should be between 0 and 1 + Activation.SIGMOID, // kld + softmax + Activation.SOFTMAX, // l1 + Activation.TANH, // l1 + softmax + Activation.SOFTMAX, // l2 + Activation.TANH, // l2 + softmax + Activation.SOFTMAX, // mae + Activation.IDENTITY, // mae + softmax + Activation.SOFTMAX, // mape + Activation.IDENTITY, // mape + softmax + Activation.SOFTMAX, // mcxent + Activation.SOFTMAX, // mse + Activation.IDENTITY, // mse + softmax + Activation.SOFTMAX, // msle - requires positive labels/activations due to log + Activation.SIGMOID, // msle + softmax + Activation.SOFTMAX, // nll + Activation.SIGMOID, // nll + softmax + Activation.SOFTMAX, // poisson - requires positive predictions due to log... 
not sure if this is the best option + Activation.SIGMOID, // squared hinge + Activation.TANH, // f-measure (binary, single sigmoid output) + Activation.SIGMOID, // f-measure (binary, 2-label softmax output) + Activation.SOFTMAX }; + int[] nOut = new int[] { // xent + 1, // xent + 3, // cosine + 5, // hinge + 3, // kld + 3, // kld + softmax + 3, // l1 + 3, // l1 + softmax + 3, // l2 + 3, // l2 + softmax + 3, // mae + 3, // mae + softmax + 3, // mape + 3, // mape + softmax + 3, // mcxent + 3, // mse + 3, // mse + softmax + 3, // msle + 3, // msle + softmax + 3, // nll + 3, // nll + softmax + 3, // poisson + 3, // squared hinge + 3, // f-measure (binary, single sigmoid output) + 1, // f-measure (binary, 2-label softmax output) + 2 }; for (int i = 0; i < lossFunctions.length; i++) { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(Updater.ADAM).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(nOut[i]).activation(Activation.TANH).build()) - .layer(1, new LossLayer.Builder().lossFunction(lossFunctions[i]) - .activation(outputActivationFn[i]).build()) - .validateOutputLayerConfig(false).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(Updater.ADAM).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(nOut[i]).activation(Activation.TANH).build()).layer(1, new LossLayer.Builder().lossFunction(lossFunctions[i]).activation(outputActivationFn[i]).build()).validateOutputLayerConfig(false).build(); String json = conf.toJson(); String yaml = conf.toYaml(); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml); - assertEquals(conf, fromJson); assertEquals(conf, fromYaml); } } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java index e80c422bf..e08c01440 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.conf; import lombok.extern.slf4j.Slf4j; @@ -34,41 +33,40 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.io.*; import java.util.Arrays; import java.util.Properties; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { +@DisplayName("Multi Layer Neural Net Configuration Test") 
+class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Test - public void testJson() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()) - .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build(); - + @DisplayName("Test Json") + void testJson() throws Exception { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()).inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build(); String json = conf.toJson(); MultiLayerConfiguration from = MultiLayerConfiguration.fromJson(json); assertEquals(conf.getConf(0), from.getConf(0)); - Properties props = new Properties(); props.put("json", json); String key = props.getProperty("json"); assertEquals(json, key); - File f = testDir.newFile("props"); + File f = testDir.resolve("props").toFile(); f.deleteOnExit(); BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f)); props.store(bos, ""); @@ -82,36 +80,18 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { String json2 = props2.getProperty("json"); MultiLayerConfiguration conf3 = MultiLayerConfiguration.fromJson(json2); assertEquals(conf.getConf(0), conf3.getConf(0)); - } @Test - public void testConvnetJson() { + @DisplayName("Test Convnet Json") + void testConvnetJson() { final int numRows = 76; final int numColumns = 76; int nChannels = 3; int outputNum = 6; int seed = 123; - - //setup the network - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) - .l1(1e-1).l2(2e-4).weightNoise(new DropConnect(0.5)).miniBatch(true) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .build()) - .layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .build()) - .layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) - .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - - .setInputType(InputType.convolutional(numRows, numColumns, nChannels)); - + // setup the network + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).l1(1e-1).l2(2e-4).weightNoise(new DropConnect(0.5)).miniBatch(true).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list().layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(4, new 
DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(numRows, numColumns, nChannels)); MultiLayerConfiguration conf = builder.build(); String json = conf.toJson(); MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json); @@ -119,30 +99,15 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { } @Test - public void testUpsamplingConvnetJson() { + @DisplayName("Test Upsampling Convnet Json") + void testUpsamplingConvnetJson() { final int numRows = 76; final int numColumns = 76; int nChannels = 3; int outputNum = 6; int seed = 123; - - //setup the network - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) - .l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(new Upsampling2D.Builder().size(2).build()) - .layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(new Upsampling2D.Builder().size(2).build()) - .layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) - .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - - .setInputType(InputType.convolutional(numRows, numColumns, nChannels)); - + // setup the network + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list().layer(new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(new Upsampling2D.Builder().size(2).build()).layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(new Upsampling2D.Builder().size(2).build()).layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(numRows, numColumns, nChannels)); MultiLayerConfiguration conf = builder.build(); String json = conf.toJson(); MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json); @@ -150,36 +115,26 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { } @Test - public void testGlobalPoolingJson() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() - .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(5).build()) - .layer(1, new GlobalPoolingLayer.Builder().poolingType(PoolingType.PNORM).pnorm(3).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(3).build()) - .setInputType(InputType.convolutional(32, 32, 1)).build(); - + @DisplayName("Test Global Pooling Json") + void testGlobalPoolingJson() { + 
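// The TemporaryFolder rule in the hunks above is replaced by JUnit 5's @TempDir injection, and
// files are created through Path.resolve(...) instead of TemporaryFolder.newFile(...). A minimal
// sketch of the new pattern, assuming junit-jupiter-api is available (field and file names below
// are illustrative):
import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirExample {

    @TempDir
    Path testDir; // injected by JUnit 5 and cleaned up after each test

    @Test
    void writesIntoTempDir() throws Exception {
        File f = testDir.resolve("props").toFile();
        // the file lives inside the injected temporary directory for the duration of the test
        f.createNewFile();
    }
}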
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()).dist(new NormalDistribution(0, 1.0)).seed(12345L).list().layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(5).build()).layer(1, new GlobalPoolingLayer.Builder().poolingType(PoolingType.PNORM).pnorm(3).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(3).build()).setInputType(InputType.convolutional(32, 32, 1)).build(); String str = conf.toJson(); MultiLayerConfiguration fromJson = conf.fromJson(str); - assertEquals(conf, fromJson); } - @Test - public void testYaml() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()) - .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build(); + @DisplayName("Test Yaml") + void testYaml() throws Exception { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()).inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build(); String json = conf.toYaml(); MultiLayerConfiguration from = MultiLayerConfiguration.fromYaml(json); assertEquals(conf.getConf(0), from.getConf(0)); - Properties props = new Properties(); props.put("json", json); String key = props.getProperty("json"); assertEquals(json, key); - File f = testDir.newFile("props"); + File f = testDir.resolve("props").toFile(); f.deleteOnExit(); BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f)); props.store(bos, ""); @@ -193,17 +148,13 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { String yaml = props2.getProperty("json"); MultiLayerConfiguration conf3 = MultiLayerConfiguration.fromYaml(yaml); assertEquals(conf.getConf(0), conf3.getConf(0)); - } @Test - public void testClone() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().build()) - .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).build()) - .inputPreProcessor(1, new CnnToFeedForwardPreProcessor()).build(); - + @DisplayName("Test Clone") + void testClone() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().build()).layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).build()).inputPreProcessor(1, new CnnToFeedForwardPreProcessor()).build(); MultiLayerConfiguration conf2 = conf.clone(); - assertEquals(conf, conf2); assertNotSame(conf, conf2); assertNotSame(conf.getConfs(), conf2.getConfs()); @@ -217,174 +168,125 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { } @Test - public void testRandomWeightInit() { + @DisplayName("Test Random Weight Init") + void testRandomWeightInit() { MultiLayerNetwork model1 = new MultiLayerNetwork(getConf()); model1.init(); - Nd4j.getRandom().setSeed(12345L); MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); model2.init(); - float[] p1 = model1.params().data().asFloat(); float[] p2 = model2.params().data().asFloat(); System.out.println(Arrays.toString(p1)); System.out.println(Arrays.toString(p2)); - - org.junit.Assert.assertArrayEquals(p1, p2, 0.0f); + assertArrayEquals(p1, p2, 0.0f); } @Test - public void testTrainingListener() { + @DisplayName("Test Training Listener") + void testTrainingListener() { MultiLayerNetwork model1 = new 
MultiLayerNetwork(getConf()); model1.init(); - model1.addListeners( new ScoreIterationListener(1)); - + model1.addListeners(new ScoreIterationListener(1)); MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); - model2.addListeners( new ScoreIterationListener(1)); + model2.addListeners(new ScoreIterationListener(1)); model2.init(); - Layer[] l1 = model1.getLayers(); - for (int i = 0; i < l1.length; i++) - assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1); - + for (int i = 0; i < l1.length; i++) assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1); Layer[] l2 = model2.getLayers(); - for (int i = 0; i < l2.length; i++) - assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1); + for (int i = 0; i < l2.length; i++) assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1); } - private static MultiLayerConfiguration getConf() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345l).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2) - .dist(new NormalDistribution(0, 1)).build()) - .layer(1, new OutputLayer.Builder().nIn(2).nOut(1) - .activation(Activation.TANH) - .dist(new NormalDistribution(0, 1)).lossFunction(LossFunctions.LossFunction.MSE).build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345l).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).dist(new NormalDistribution(0, 1)).build()).layer(1, new OutputLayer.Builder().nIn(2).nOut(1).activation(Activation.TANH).dist(new NormalDistribution(0, 1)).lossFunction(LossFunctions.LossFunction.MSE).build()).build(); return conf; } @Test - public void testInvalidConfig() { - + @DisplayName("Test Invalid Config") + void testInvalidConfig() { try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); fail("No exception thrown for invalid configuration"); } catch (IllegalStateException e) { - //OK - log.error("",e); + // OK + log.error("", e); } catch (Throwable e) { - log.error("",e); + log.error("", e); fail("Unexpected exception thrown for invalid config"); } - try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(1, new DenseLayer.Builder().nIn(3).nOut(4).build()) - .layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(1, new DenseLayer.Builder().nIn(3).nOut(4).build()).layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); fail("No exception thrown for invalid configuration"); } catch (IllegalStateException e) { - //OK + // OK log.info(e.toString()); } catch (Throwable e) { - log.error("",e); + log.error("", e); fail("Unexpected exception thrown for invalid config"); } - try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()) - .layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()).layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()).build(); 
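// The parameter comparison above now resolves assertArrayEquals to org.junit.jupiter.api.Assertions
// instead of org.junit.Assert; the float-array overload keeps the (expected, actual, delta) order,
// and a delta of 0.0f demands an exact element-wise match. A minimal sketch (class and method names
// below are illustrative):
import static org.junit.jupiter.api.Assertions.assertArrayEquals;

class ArrayAssertionExample {

    void compareParams(float[] p1, float[] p2) {
        // fails with an element-wise diff message if any position differs by more than 0.0f
        assertArrayEquals(p1, p2, 0.0f);
    }
}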
MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); fail("No exception thrown for invalid configuration"); } catch (IllegalStateException e) { - //OK + // OK log.info(e.toString()); } catch (Throwable e) { - log.error("",e); + log.error("", e); fail("Unexpected exception thrown for invalid config"); } } @Test - public void testListOverloads() { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()) - .layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()) - .build(); + @DisplayName("Test List Overloads") + void testListOverloads() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DenseLayer dl = (DenseLayer) conf.getConf(0).getLayer(); assertEquals(3, dl.getNIn()); assertEquals(4, dl.getNOut()); OutputLayer ol = (OutputLayer) conf.getConf(1).getLayer(); assertEquals(4, ol.getNIn()); assertEquals(5, ol.getNOut()); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()) - .layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()) - .build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - - MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder().seed(12345) - .list(new DenseLayer.Builder().nIn(3).nOut(4).build(), - new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()) - .build(); + MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder().seed(12345).list(new DenseLayer.Builder().nIn(3).nOut(4).build(), new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net3 = new MultiLayerNetwork(conf3); net3.init(); - - assertEquals(conf, conf2); assertEquals(conf, conf3); } - @Test - public void testBiasLr() { - //setup the network - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(1e-2)) - .biasUpdater(new Adam(0.5)).list() - .layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(1, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) - .layer(2, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)).build(); - + @DisplayName("Test Bias Lr") + void testBiasLr() { + // setup the network + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(1e-2)).biasUpdater(new Adam(0.5)).list().layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(2, new 
DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)).build(); org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer(); org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer(); org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer(); org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer(); - - assertEquals(0.5, ((Adam)l0.getUpdaterByParam("b")).getLearningRate(), 1e-6); - assertEquals(1e-2, ((Adam)l0.getUpdaterByParam("W")).getLearningRate(), 1e-6); - - assertEquals(0.5, ((Adam)l1.getUpdaterByParam("b")).getLearningRate(), 1e-6); - assertEquals(1e-2, ((Adam)l1.getUpdaterByParam("W")).getLearningRate(), 1e-6); - - assertEquals(0.5, ((Adam)l2.getUpdaterByParam("b")).getLearningRate(), 1e-6); - assertEquals(1e-2, ((Adam)l2.getUpdaterByParam("W")).getLearningRate(), 1e-6); - - assertEquals(0.5, ((Adam)l3.getUpdaterByParam("b")).getLearningRate(), 1e-6); - assertEquals(1e-2, ((Adam)l3.getUpdaterByParam("W")).getLearningRate(), 1e-6); + assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6); + assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6); + assertEquals(0.5, ((Adam) l1.getUpdaterByParam("b")).getLearningRate(), 1e-6); + assertEquals(1e-2, ((Adam) l1.getUpdaterByParam("W")).getLearningRate(), 1e-6); + assertEquals(0.5, ((Adam) l2.getUpdaterByParam("b")).getLearningRate(), 1e-6); + assertEquals(1e-2, ((Adam) l2.getUpdaterByParam("W")).getLearningRate(), 1e-6); + assertEquals(0.5, ((Adam) l3.getUpdaterByParam("b")).getLearningRate(), 1e-6); + assertEquals(1e-2, ((Adam) l3.getUpdaterByParam("W")).getLearningRate(), 1e-6); } - @Test - public void testInvalidOutputLayer(){ + @DisplayName("Test Invalid Output Layer") + void testInvalidOutputLayer() { /* Test case (invalid configs) 1. nOut=1 + softmax @@ -393,32 +295,24 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { 4. xent + relu 5. 
mcxent + sigmoid */ - - LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[]{ - LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.XENT, - LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT}; - int[] nOut = new int[]{1, 3, 3, 3, 3}; - Activation[] activations = new Activation[]{Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID}; - for( int i=0; i r = net.getLayer(0).conf().getLayer().getRegularizationByParam("b"); assertEquals(0, r.size()); - r = net.getLayer(1).conf().getLayer().getRegularizationByParam("beta"); assertTrue(r == null || r.isEmpty()); r = net.getLayer(1).conf().getLayer().getRegularizationByParam("gamma"); @@ -315,14 +268,10 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest { } @Test - public void testLayerPretrainConfig() { + @DisplayName("Test Layer Pretrain Config") + void testLayerPretrainConfig() { boolean pretrain = true; - - VariationalAutoencoder layer = new VariationalAutoencoder.Builder() - .nIn(10).nOut(5).updater(new Sgd(1e-1)) - .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build(); - + VariationalAutoencoder layer = new VariationalAutoencoder.Builder().nIn(10).nOut(5).updater(new Sgd(1e-1)).lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build(); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(42).layer(layer).build(); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java index cabb6a73a..73c2385d1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.conf.graph; import org.deeplearning4j.BaseDL4JTest; @@ -30,8 +29,8 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationTanH; @@ -43,194 +42,99 @@ import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; - import java.util.Map; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertArrayEquals; +@DisplayName("Element Wise Vertex Test") +class ElementWiseVertexTest extends BaseDL4JTest { -public class ElementWiseVertexTest extends BaseDL4JTest { @Test - public void testElementWiseVertexNumParams() { + @DisplayName("Test Element Wise Vertex Num Params") + void testElementWiseVertexNumParams() { /* * https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386 * from @agibsonccc: check for 
the basics: like 0 numParams */ - - ElementWiseVertex.Op ops[] = new ElementWiseVertex.Op[] {ElementWiseVertex.Op.Add, - ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Product}; - + ElementWiseVertex.Op[] ops = new ElementWiseVertex.Op[] { ElementWiseVertex.Op.Add, ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Product }; for (ElementWiseVertex.Op op : ops) { ElementWiseVertex ewv = new ElementWiseVertex(op); - Assert.assertEquals(0, ewv.numParams(true)); - Assert.assertEquals(0, ewv.numParams(false)); + Assertions.assertEquals(0, ewv.numParams(true)); + Assertions.assertEquals(0, ewv.numParams(false)); } } @Test - public void testElementWiseVertexForwardAdd() { + @DisplayName("Test Element Wise Vertex Forward Add") + void testElementWiseVertexForwardAdd() { int batchsz = 24; int featuresz = 17; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder() - .addInputs("input1", "input2", "input3") - .addLayer("denselayer", - new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY) - .build(), - "input1") - /* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get - * Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more) - * at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948) - * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409) - * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341) - */ - .addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "input1", - "input2", "input3") - .addLayer("Add", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), - "elementwiseAdd") - .setOutputs("Add", "denselayer").build(); - + ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2", "input3").addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY).build(), "input1").addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "input1", "input2", "input3").addLayer("Add", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "elementwiseAdd").setOutputs("Add", "denselayer").build(); ComputationGraph cg = new ComputationGraph(cgc); cg.init(); - - INDArray input1 = Nd4j.rand(batchsz, featuresz); INDArray input2 = Nd4j.rand(batchsz, featuresz); INDArray input3 = Nd4j.rand(batchsz, featuresz); - INDArray target = input1.dup().addi(input2).addi(input3); - INDArray output = cg.output(input1, input2, input3)[0]; INDArray squared = output.sub(target.castTo(output.dataType())); double rms = squared.mul(squared).sumNumber().doubleValue(); - Assert.assertEquals(0.0, rms, this.epsilon); + Assertions.assertEquals(0.0, rms, this.epsilon); } @Test - public void testElementWiseVertexForwardProduct() { + @DisplayName("Test Element Wise Vertex Forward Product") + void testElementWiseVertexForwardProduct() { int batchsz = 24; int featuresz = 17; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder() - .addInputs("input1", "input2", "input3") - .addLayer("denselayer", - new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY) - .build(), - "input1") - /* denselayer is 
not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get - * Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more) - * at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948) - * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409) - * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341) - */ - .addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "input1", - "input2", "input3") - .addLayer("Product", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), - "elementwiseProduct") - .setOutputs("Product", "denselayer").build(); - + ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2", "input3").addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY).build(), "input1").addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "input1", "input2", "input3").addLayer("Product", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "elementwiseProduct").setOutputs("Product", "denselayer").build(); ComputationGraph cg = new ComputationGraph(cgc); cg.init(); - - INDArray input1 = Nd4j.rand(batchsz, featuresz); INDArray input2 = Nd4j.rand(batchsz, featuresz); INDArray input3 = Nd4j.rand(batchsz, featuresz); - INDArray target = input1.dup().muli(input2).muli(input3); - INDArray output = cg.output(input1, input2, input3)[0]; INDArray squared = output.sub(target.castTo(output.dataType())); double rms = squared.mul(squared).sumNumber().doubleValue(); - Assert.assertEquals(0.0, rms, this.epsilon); + Assertions.assertEquals(0.0, rms, this.epsilon); } @Test - public void testElementWiseVertexForwardSubtract() { + @DisplayName("Test Element Wise Vertex Forward Subtract") + void testElementWiseVertexForwardSubtract() { int batchsz = 24; int featuresz = 17; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder() - .addInputs("input1", "input2") - .addLayer("denselayer", - new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY) - .build(), - "input1") - /* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get - * Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more) - * at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948) - * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409) - * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341) - */ - .addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), - "input1", "input2") - .addLayer("Subtract", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), - "elementwiseSubtract") - .setOutputs("Subtract", "denselayer").build(); - + ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2").addLayer("denselayer", new 
DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY).build(), "input1").addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), "input1", "input2").addLayer("Subtract", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "elementwiseSubtract").setOutputs("Subtract", "denselayer").build(); ComputationGraph cg = new ComputationGraph(cgc); cg.init(); - - INDArray input1 = Nd4j.rand(batchsz, featuresz); INDArray input2 = Nd4j.rand(batchsz, featuresz); - INDArray target = input1.dup().subi(input2); - INDArray output = cg.output(input1, input2)[0]; INDArray squared = output.sub(target); double rms = Math.sqrt(squared.mul(squared).sumNumber().doubleValue()); - Assert.assertEquals(0.0, rms, this.epsilon); + Assertions.assertEquals(0.0, rms, this.epsilon); } @Test - public void testElementWiseVertexFullAdd() { + @DisplayName("Test Element Wise Vertex Full Add") + void testElementWiseVertexFullAdd() { int batchsz = 24; int featuresz = 17; int midsz = 13; int outputsz = 11; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) - .dataType(DataType.DOUBLE) - .biasInit(0.0).updater(new Sgd()) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() - .addInputs("input1", "input2", "input3") - .addLayer("dense1", - new DenseLayer.Builder().nIn(featuresz).nOut(midsz) - .activation(new ActivationTanH()).build(), - "input1") - .addLayer("dense2", - new DenseLayer.Builder().nIn(featuresz).nOut(midsz) - .activation(new ActivationTanH()).build(), - "input2") - .addLayer("dense3", - new DenseLayer.Builder().nIn(featuresz).nOut(midsz) - .activation(new ActivationTanH()).build(), - "input3") - .addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", - "dense2", "dense3") - .addLayer("output", - new OutputLayer.Builder().nIn(midsz).nOut(outputsz) - .activation(new ActivationSigmoid()) - .lossFunction(LossFunction.MSE).build(), - "elementwiseAdd") - .setOutputs("output").build(); - + ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).biasInit(0.0).updater(new Sgd()).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2", "input3").addLayer("dense1", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input1").addLayer("dense2", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input2").addLayer("dense3", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input3").addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2", "dense3").addLayer("output", new OutputLayer.Builder().nIn(midsz).nOut(outputsz).activation(new ActivationSigmoid()).lossFunction(LossFunction.MSE).build(), "elementwiseAdd").setOutputs("output").build(); ComputationGraph cg = new ComputationGraph(cgc); cg.init(); - INDArray input1 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); - INDArray input2 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); - INDArray input3 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); - INDArray target = nullsafe(Nd4j.rand(new int[] {batchsz, outputsz}, new UniformDistribution(0, 1))); + INDArray input1 = Nd4j.rand(new int[] { batchsz, featuresz }, new 
UniformDistribution(-1, 1)); + INDArray input2 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1)); + INDArray input3 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1)); + INDArray target = nullsafe(Nd4j.rand(new int[] { batchsz, outputsz }, new UniformDistribution(0, 1))); cg.setInputs(input1, input2, input3); cg.setLabels(target); - cg.computeGradientAndScore(); - // Let's figure out what our params are now. Map params = cg.paramTable(); INDArray dense1_W = nullsafe(params.get("dense1_W")); @@ -241,35 +145,22 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dense3_b = nullsafe(params.get("dense3_b")); INDArray output_W = nullsafe(params.get("output_W")); INDArray output_b = nullsafe(params.get("output_b")); - // Now, let's calculate what we expect the output to be. - INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1)); INDArray m = (Transforms.tanh(mh)); - INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1)); INDArray n = (Transforms.tanh(nh)); - INDArray oh = input3.mmul(dense3_W).addi(dense3_b.repmat(batchsz, 1)); INDArray o = (Transforms.tanh(oh)); - INDArray middle = Nd4j.zeros(batchsz, midsz); middle.addi(m).addi(n).addi(o); - - INDArray expect = Nd4j.zeros(batchsz, outputsz); expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1)))); - - INDArray output = nullsafe(cg.output(input1, input2, input3)[0]); - - Assert.assertEquals(0.0, mse(output, expect), this.epsilon); - + Assertions.assertEquals(0.0, mse(output, expect), this.epsilon); Pair pgd = cg.gradientAndScore(); - double score = pgd.getSecond(); - Assert.assertEquals(score, mse(output, target), this.epsilon); - + Assertions.assertEquals(score, mse(output, target), this.epsilon); Map gradients = pgd.getFirst().gradientForVariable(); /* * So. Let's say we have inputs a, b, c @@ -305,27 +196,23 @@ public class ElementWiseVertexTest extends BaseDL4JTest { * dmh/db1 = Nd4j.ones(1, batchsz) * */ - INDArray y = output; INDArray s = middle; INDArray W4 = output_W; - INDArray dEdy = Nd4j.zeros(target.shape()); - dEdy.addi(y).subi(target).muli(2); // This should be of size batchsz x outputsz - dEdy.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output. - - INDArray dydyh = y.mul(y.mul(-1).add(1)); // This is of size batchsz x outputsz + // This should be of size batchsz x outputsz + dEdy.addi(y).subi(target).muli(2); + // Why? Because the LossFunction divides by the _element size_ of the output. 
+ dEdy.divi(target.shape()[1]); + // This is of size batchsz x outputsz + INDArray dydyh = y.mul(y.mul(-1).add(1)); INDArray dEdyh = dydyh.mul(dEdy); - INDArray dyhdW4 = s.transpose(); INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh)); - INDArray dyhdb4 = Nd4j.ones(1, batchsz); INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh)); - INDArray dyhds = W4.transpose(); INDArray dEds = dEdyh.mmul(dyhds); - INDArray dsdm = Nd4j.ones(batchsz, midsz); INDArray dEdm = dsdm.mul(dEds); INDArray dmdmh = (m.mul(m)).mul(-1).add(1); @@ -334,7 +221,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh)); INDArray dmhdb1 = Nd4j.ones(1, batchsz); INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh)); - INDArray dsdn = Nd4j.ones(batchsz, midsz); INDArray dEdn = dsdn.mul(dEds); INDArray dndnh = (n.mul(n)).mul(-1).add(1); @@ -343,7 +229,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh)); INDArray dnhdb2 = Nd4j.ones(1, batchsz); INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh)); - INDArray dsdo = Nd4j.ones(batchsz, midsz); INDArray dEdo = dsdo.mul(dEds); INDArray dodoh = (o.mul(o)).mul(-1).add(1); @@ -352,61 +237,33 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dEdW3 = nullsafe(dohdW3.mmul(dEdoh)); INDArray dohdb3 = Nd4j.ones(1, batchsz); INDArray dEdb3 = nullsafe(dohdb3.mmul(dEdoh)); - - - Assert.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon); } @Test - public void testElementWiseVertexFullProduct() { + @DisplayName("Test Element Wise Vertex Full Product") + void testElementWiseVertexFullProduct() { int batchsz = 24; int featuresz = 17; int midsz = 13; int outputsz = 11; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) - .dataType(DataType.DOUBLE) - .biasInit(0.0).updater(new Sgd()) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() - .addInputs("input1", "input2", "input3") - .addLayer("dense1", - new DenseLayer.Builder().nIn(featuresz).nOut(midsz) - .activation(new ActivationTanH()).build(), - "input1") - .addLayer("dense2", - new 
DenseLayer.Builder().nIn(featuresz).nOut(midsz) - .activation(new ActivationTanH()).build(), - "input2") - .addLayer("dense3", - new DenseLayer.Builder().nIn(featuresz).nOut(midsz) - .activation(new ActivationTanH()).build(), - "input3") - .addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1", - "dense2", "dense3") - .addLayer("output", - new OutputLayer.Builder().nIn(midsz).nOut(outputsz) - .activation(new ActivationSigmoid()) - .lossFunction(LossFunction.MSE).build(), - "elementwiseProduct") - .setOutputs("output").build(); - + ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).biasInit(0.0).updater(new Sgd()).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2", "input3").addLayer("dense1", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input1").addLayer("dense2", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input2").addLayer("dense3", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input3").addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1", "dense2", "dense3").addLayer("output", new OutputLayer.Builder().nIn(midsz).nOut(outputsz).activation(new ActivationSigmoid()).lossFunction(LossFunction.MSE).build(), "elementwiseProduct").setOutputs("output").build(); ComputationGraph cg = new ComputationGraph(cgc); cg.init(); - INDArray input1 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); - INDArray input2 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); - INDArray input3 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); - INDArray target = nullsafe(Nd4j.rand(new int[] {batchsz, outputsz}, new UniformDistribution(0, 1))); + INDArray input1 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1)); + INDArray input2 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1)); + INDArray input3 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1)); + INDArray target = nullsafe(Nd4j.rand(new int[] { batchsz, outputsz }, new UniformDistribution(0, 1))); cg.setInputs(input1, input2, input3); cg.setLabels(target); - cg.computeGradientAndScore(); - // Let's figure out what our params are now. Map params = cg.paramTable(); INDArray dense1_W = nullsafe(params.get("dense1_W")); @@ -417,35 +274,22 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dense3_b = nullsafe(params.get("dense3_b")); INDArray output_W = nullsafe(params.get("output_W")); INDArray output_b = nullsafe(params.get("output_b")); - // Now, let's calculate what we expect the output to be. 
- INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1)); INDArray m = (Transforms.tanh(mh)); - INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1)); INDArray n = (Transforms.tanh(nh)); - INDArray oh = input3.mmul(dense3_W).addi(dense3_b.repmat(batchsz, 1)); INDArray o = (Transforms.tanh(oh)); - INDArray middle = Nd4j.ones(batchsz, midsz); middle.muli(m).muli(n).muli(o); - - INDArray expect = Nd4j.zeros(batchsz, outputsz); expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1)))); - - INDArray output = nullsafe(cg.output(input1, input2, input3)[0]); - - Assert.assertEquals(0.0, mse(output, expect), this.epsilon); - + Assertions.assertEquals(0.0, mse(output, expect), this.epsilon); Pair pgd = cg.gradientAndScore(); - double score = pgd.getSecond(); - Assert.assertEquals(score, mse(output, target), this.epsilon); - + Assertions.assertEquals(score, mse(output, target), this.epsilon); Map gradients = pgd.getFirst().gradientForVariable(); /* * So. Let's say we have inputs a, b, c @@ -481,27 +325,23 @@ public class ElementWiseVertexTest extends BaseDL4JTest { * dmh/db1 = Nd4j.ones(1, batchsz) * */ - INDArray y = output; INDArray s = middle; INDArray W4 = output_W; - INDArray dEdy = Nd4j.zeros(target.shape()); - dEdy.addi(y).subi(target).muli(2); // This should be of size batchsz x outputsz - dEdy.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output. - - INDArray dydyh = y.mul(y.mul(-1).add(1)); // This is of size batchsz x outputsz + // This should be of size batchsz x outputsz + dEdy.addi(y).subi(target).muli(2); + // Why? Because the LossFunction divides by the _element size_ of the output. + dEdy.divi(target.shape()[1]); + // This is of size batchsz x outputsz + INDArray dydyh = y.mul(y.mul(-1).add(1)); INDArray dEdyh = dydyh.mul(dEdy); - INDArray dyhdW4 = s.transpose(); INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh)); - INDArray dyhdb4 = Nd4j.ones(1, batchsz); INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh)); - INDArray dyhds = W4.transpose(); INDArray dEds = dEdyh.mmul(dyhds); - INDArray dsdm = Nd4j.ones(batchsz, midsz).muli(n).muli(o); INDArray dEdm = dsdm.mul(dEds); INDArray dmdmh = (m.mul(m)).mul(-1).add(1); @@ -510,7 +350,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh)); INDArray dmhdb1 = Nd4j.ones(1, batchsz); INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh)); - INDArray dsdn = Nd4j.ones(batchsz, midsz).muli(m).muli(o); INDArray dEdn = dsdn.mul(dEds); INDArray dndnh = (n.mul(n)).mul(-1).add(1); @@ -519,7 +358,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh)); INDArray dnhdb2 = Nd4j.ones(1, batchsz); INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh)); - INDArray dsdo = Nd4j.ones(batchsz, midsz).muli(m).muli(n); INDArray dEdo = dsdo.mul(dEds); INDArray dodoh = (o.mul(o)).mul(-1).add(1); @@ -528,55 +366,32 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dEdW3 = nullsafe(dohdW3.mmul(dEdoh)); INDArray dohdb3 = Nd4j.ones(1, batchsz); INDArray dEdb3 = nullsafe(dohdb3.mmul(dEdoh)); - - Assert.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon); - Assert.assertEquals(0, 
mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon); } @Test - public void testElementWiseVertexFullSubtract() { + @DisplayName("Test Element Wise Vertex Full Subtract") + void testElementWiseVertexFullSubtract() { int batchsz = 24; int featuresz = 17; int midsz = 13; int outputsz = 11; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) - .dataType(DataType.DOUBLE) - .biasInit(0.0).updater(new Sgd()) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() - .addInputs("input1", "input2") - .addLayer("dense1", - new DenseLayer.Builder().nIn(featuresz).nOut(midsz) - .activation(new ActivationTanH()).build(), - "input1") - .addLayer("dense2", - new DenseLayer.Builder().nIn(featuresz).nOut(midsz) - .activation(new ActivationTanH()).build(), - "input2") - .addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), - "dense1", "dense2") - .addLayer("output", - new OutputLayer.Builder().nIn(midsz).nOut(outputsz) - .activation(new ActivationSigmoid()) - .lossFunction(LossFunction.MSE).build(), - "elementwiseSubtract") - .setOutputs("output").build(); - + ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).biasInit(0.0).updater(new Sgd()).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2").addLayer("dense1", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input1").addLayer("dense2", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input2").addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), "dense1", "dense2").addLayer("output", new OutputLayer.Builder().nIn(midsz).nOut(outputsz).activation(new ActivationSigmoid()).lossFunction(LossFunction.MSE).build(), "elementwiseSubtract").setOutputs("output").build(); ComputationGraph cg = new ComputationGraph(cgc); cg.init(); - INDArray input1 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); - INDArray input2 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1)); - INDArray target = nullsafe(Nd4j.rand(new int[] {batchsz, outputsz}, new UniformDistribution(0, 1))); + INDArray input1 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1)); + INDArray input2 = Nd4j.rand(new int[] { batchsz, featuresz }, new 
UniformDistribution(-1, 1)); + INDArray target = nullsafe(Nd4j.rand(new int[] { batchsz, outputsz }, new UniformDistribution(0, 1))); cg.setInputs(input1, input2); cg.setLabels(target); - cg.computeGradientAndScore(); - // Let's figure out what our params are now. Map params = cg.paramTable(); INDArray dense1_W = nullsafe(params.get("dense1_W")); @@ -585,32 +400,20 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dense2_b = nullsafe(params.get("dense2_b")); INDArray output_W = nullsafe(params.get("output_W")); INDArray output_b = nullsafe(params.get("output_b")); - // Now, let's calculate what we expect the output to be. - INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1)); INDArray m = (Transforms.tanh(mh)); - INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1)); INDArray n = (Transforms.tanh(nh)); - INDArray middle = Nd4j.zeros(batchsz, midsz); middle.addi(m).subi(n); - - INDArray expect = Nd4j.zeros(batchsz, outputsz); expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1)))); - - INDArray output = nullsafe(cg.output(input1, input2)[0]); - - Assert.assertEquals(0.0, mse(output, expect), this.epsilon); - + Assertions.assertEquals(0.0, mse(output, expect), this.epsilon); Pair pgd = cg.gradientAndScore(); - double score = pgd.getSecond(); - Assert.assertEquals(score, mse(output, target), this.epsilon); - + Assertions.assertEquals(score, mse(output, target), this.epsilon); Map gradients = pgd.getFirst().gradientForVariable(); /* * So. Let's say we have inputs a, b, c @@ -644,27 +447,23 @@ public class ElementWiseVertexTest extends BaseDL4JTest { * dmh/db1 = Nd4j.ones(1, batchsz) * */ - INDArray y = output; INDArray s = middle; INDArray W4 = output_W; - INDArray dEdy = Nd4j.zeros(target.shape()); - dEdy.addi(y).subi(target).muli(2); // This should be of size batchsz x outputsz - dEdy.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output. - - INDArray dydyh = y.mul(y.mul(-1).add(1)); // This is of size batchsz x outputsz + // This should be of size batchsz x outputsz + dEdy.addi(y).subi(target).muli(2); + // Why? Because the LossFunction divides by the _element size_ of the output. 
+ dEdy.divi(target.shape()[1]); + // This is of size batchsz x outputsz + INDArray dydyh = y.mul(y.mul(-1).add(1)); INDArray dEdyh = dydyh.mul(dEdy); - INDArray dyhdW4 = s.transpose(); INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh)); - INDArray dyhdb4 = Nd4j.ones(1, batchsz); INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh)); - INDArray dyhds = W4.transpose(); INDArray dEds = dEdyh.mmul(dyhds); - INDArray dsdm = Nd4j.ones(batchsz, midsz); INDArray dEdm = dsdm.mul(dEds); INDArray dmdmh = (m.mul(m)).mul(-1).add(1); @@ -673,7 +472,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh)); INDArray dmhdb1 = Nd4j.ones(1, batchsz); INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh)); - INDArray dsdn = Nd4j.ones(batchsz, midsz).muli(-1); INDArray dEdn = dsdn.mul(dEds); INDArray dndnh = (n.mul(n)).mul(-1).add(1); @@ -682,20 +480,16 @@ public class ElementWiseVertexTest extends BaseDL4JTest { INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh)); INDArray dnhdb2 = Nd4j.ones(1, batchsz); INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh)); - - - Assert.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon); - Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon); + Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon); } - private static double mse(INDArray output, INDArray target) { - double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue() - / (output.columns() * output.rows()); + double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue() / (output.columns() * output.rows()); return mse_expect; } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java index 7b8e90419..3db72291b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.conf.graph; import org.deeplearning4j.BaseDL4JTest; @@ -30,8 +29,8 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import 
org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.impl.ActivationSigmoid; @@ -42,86 +41,70 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.common.primitives.Pair; - import java.util.Map; import java.util.TreeMap; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; + +@DisplayName("Shift Vertex Test") +class ShiftVertexTest extends BaseDL4JTest { -public class ShiftVertexTest extends BaseDL4JTest { @Test - public void testShiftVertexNumParamsTrue() { + @DisplayName("Test Shift Vertex Num Params True") + void testShiftVertexNumParamsTrue() { /* * https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386 * from @agibsonccc: check for the basics: like 0 numParams */ - - ShiftVertex sv = new ShiftVertex(0.7); // The 0.7 doesn't really matter. - Assert.assertEquals(0, sv.numParams(true)); - } - - @Test - public void testShiftVertexNumParamsFalse() { - /* - * https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386 - * from @agibsonccc: check for the basics: like 0 numParams - */ - - ShiftVertex sv = new ShiftVertex(0.7); // The 0.7 doesn't really matter. - Assert.assertEquals(0, sv.numParams(false)); - } - - @Test - public void testGet() { + // The 0.7 doesn't really matter. ShiftVertex sv = new ShiftVertex(0.7); - Assert.assertEquals(0.7, sv.getShiftFactor(), this.epsilon); + Assertions.assertEquals(0, sv.numParams(true)); } @Test - public void testSimple() { + @DisplayName("Test Shift Vertex Num Params False") + void testShiftVertexNumParamsFalse() { + /* + * https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386 + * from @agibsonccc: check for the basics: like 0 numParams + */ + // The 0.7 doesn't really matter. + ShiftVertex sv = new ShiftVertex(0.7); + Assertions.assertEquals(0, sv.numParams(false)); + } + + @Test + @DisplayName("Test Get") + void testGet() { + ShiftVertex sv = new ShiftVertex(0.7); + Assertions.assertEquals(0.7, sv.getShiftFactor(), this.epsilon); + } + + @Test + @DisplayName("Test Simple") + void testSimple() { /* * This function _simply_ tests whether ShiftVertex is _in fact_ adding the shift value to it's inputs. */ // Just first n primes / 10. 
- INDArray input = Nd4j - .create(new double[][] {{0.2, 0.3, 0.5}, {0.7, 1.1, 1.3}, {1.7, 1.9, 2.3}, {2.9, 3.1, 3.7}}); + INDArray input = Nd4j.create(new double[][] { { 0.2, 0.3, 0.5 }, { 0.7, 1.1, 1.3 }, { 1.7, 1.9, 2.3 }, { 2.9, 3.1, 3.7 } }); double sf = 4.1; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input") - .addLayer("denselayer", - new DenseLayer.Builder().nIn(input.columns()).nOut(1) - .activation(Activation.IDENTITY).build(), - "input") - /* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get - * Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more) - * at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820) - * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948) - * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409) - * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341) - */ - .addLayer("identityinputactivation", - new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "input") - .addVertex("shiftvertex", new ShiftVertex(sf), "identityinputactivation") - .addLayer("identityshiftvertex", - new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), - "shiftvertex") - .setOutputs("identityshiftvertex", "denselayer").build(); - + ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input").addLayer("denselayer", new DenseLayer.Builder().nIn(input.columns()).nOut(1).activation(Activation.IDENTITY).build(), "input").addLayer("identityinputactivation", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "input").addVertex("shiftvertex", new ShiftVertex(sf), "identityinputactivation").addLayer("identityshiftvertex", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "shiftvertex").setOutputs("identityshiftvertex", "denselayer").build(); ComputationGraph cg = new ComputationGraph(cgc); cg.init(); - // We can call outputSingle, because we only have a single output layer. It has nothing to do with minibatches. INDArray output = cg.output(true, input)[0]; INDArray target = Nd4j.zeros(input.shape()); target.addi(input); target.addi(sf); - INDArray squared = output.sub(target); double rms = squared.mul(squared).sumNumber().doubleValue(); - Assert.assertEquals(0.0, rms, this.epsilon); + Assertions.assertEquals(0.0, rms, this.epsilon); } @Test - public void testComprehensive() { + @DisplayName("Test Comprehensive") + void testComprehensive() { /* * This function tests ShiftVertex more comprehensively. Specifically, it verifies that the lossfunction works as * expected on a ComputationGraph _with_ a ShiftVertex and it verifies that the derivatives produced by @@ -130,29 +113,12 @@ public class ShiftVertexTest extends BaseDL4JTest { BaseActivationFunction a1 = new ActivationTanH(); BaseActivationFunction a2 = new ActivationSigmoid(); // Just first n primes / 10. 
- INDArray input = Nd4j - .create(new double[][] {{0.2, 0.3, 0.5}, {0.7, 1.1, 1.3}, {1.7, 1.9, 2.3}, {2.9, 3.1, 3.7}}); + INDArray input = Nd4j.create(new double[][] { { 0.2, 0.3, 0.5 }, { 0.7, 1.1, 1.3 }, { 1.7, 1.9, 2.3 }, { 2.9, 3.1, 3.7 } }); double sf = 4.1; // Actually, given that I'm using a sigmoid on the output, // these should really be between 0 and 1 - INDArray target = Nd4j.create(new double[][] {{0.05, 0.10, 0.15, 0.20, 0.25}, {0.30, 0.35, 0.40, 0.45, 0.50}, - {0.55, 0.60, 0.65, 0.70, 0.75}, {0.80, 0.85, 0.90, 0.95, 0.99}}); - - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) - .dataType(DataType.DOUBLE) - .updater(new Sgd(0.01)) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() - .addInputs("input") - .addLayer("denselayer", - new DenseLayer.Builder().nIn(input.columns()).nOut(input.columns()) - .activation(a1).build(), - "input") - .addVertex("shiftvertex", new ShiftVertex(sf), "denselayer") - .addLayer("output", - new OutputLayer.Builder().nIn(input.columns()).nOut(target.columns()) - .activation(a2).lossFunction(LossFunction.MSE).build(), - "shiftvertex") - .setOutputs("output").build(); + INDArray target = Nd4j.create(new double[][] { { 0.05, 0.10, 0.15, 0.20, 0.25 }, { 0.30, 0.35, 0.40, 0.45, 0.50 }, { 0.55, 0.60, 0.65, 0.70, 0.75 }, { 0.80, 0.85, 0.90, 0.95, 0.99 } }); + ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).updater(new Sgd(0.01)).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input").addLayer("denselayer", new DenseLayer.Builder().nIn(input.columns()).nOut(input.columns()).activation(a1).build(), "input").addVertex("shiftvertex", new ShiftVertex(sf), "denselayer").addLayer("output", new OutputLayer.Builder().nIn(input.columns()).nOut(target.columns()).activation(a2).lossFunction(LossFunction.MSE).build(), "shiftvertex").setOutputs("output").build(); ComputationGraph cg = new ComputationGraph(cgc); cg.init(); cg.setInput(0, input); @@ -163,26 +129,23 @@ public class ShiftVertexTest extends BaseDL4JTest { Gradient g = cg.gradient(); Map gradients = g.gradientForVariable(); Map manual_gradients = new TreeMap(); - INDArray W = nullsafe(weights.get("denselayer_W")); INDArray b = nullsafe(weights.get("denselayer_b")); INDArray V = nullsafe(weights.get("output_W")); INDArray c = nullsafe(weights.get("output_b")); - Map manual_weights = new TreeMap(); manual_weights.put("denselayer_W", W); manual_weights.put("denselayer_b", b); manual_weights.put("output_W", V); manual_weights.put("output_b", c); - // First things first, let's calculate the score. long batchsz = input.shape()[0]; INDArray z = input.castTo(W.dataType()).mmul(W).add(b.repmat(batchsz, 1)); - INDArray a = a1.getActivation(z.dup(), true).add(sf); // activation modifies it's input!! + // activation modifies it's input!! + INDArray a = a1.getActivation(z.dup(), true).add(sf); INDArray q = a.mmul(V).add(c.repmat(batchsz, 1)); INDArray o = nullsafe(a2.getActivation(q.dup(), true)); double score_manual = sum_errors(o, target) / (o.columns() * o.rows()); - /* * So. We have * z5 = input1 * W15 + input2 * W25 + input3 * W35 + b5 @@ -197,12 +160,15 @@ public class ShiftVertexTest extends BaseDL4JTest { * dq1/dv11 = a1 dq2/dV12 = a1 dq3/dV13 = a1 ... * dq1/dv21 = a2 dq2... 
*/ - INDArray dEdo = target.like(); //Nd4j.zeros(target.shape()); - dEdo.addi(o.castTo(dEdo.dataType())).subi(target).muli(2); // This should be of size batchsz x outputsz - dEdo.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output. - + // Nd4j.zeros(target.shape()); + INDArray dEdo = target.like(); + // This should be of size batchsz x outputsz + dEdo.addi(o.castTo(dEdo.dataType())).subi(target).muli(2); + // Why? Because the LossFunction divides by the _element size_ of the output. + dEdo.divi(target.shape()[1]); Pair derivs2 = a2.backprop(q, dEdo); - INDArray dEdq = derivs2.getFirst(); // This should be of size batchsz x outputsz (dE/do * do/dq) this _should_ be o * (1-o) * dE/do for Sigmoid. + // This should be of size batchsz x outputsz (dE/do * do/dq) this _should_ be o * (1-o) * dE/do for Sigmoid. + INDArray dEdq = derivs2.getFirst(); // Should be o = q^3 do/dq = 3 q^2 for Cube. /* INDArray dodq = q.mul(q).mul(3); @@ -213,26 +179,23 @@ public class ShiftVertexTest extends BaseDL4JTest { System.err.println(tbv); System.err.println(dEdq); */ - INDArray dqdc = Nd4j.ones(1, batchsz); - INDArray dEdc = dqdc.mmul(dEdq); // This should be of size 1 x outputsz + // This should be of size 1 x outputsz + INDArray dEdc = dqdc.mmul(dEdq); INDArray dEdV = a.transpose().mmul(dEdq); - INDArray dEda = dEdq.mmul(V.transpose()); // This should be dEdo * dodq * dqda - + // This should be dEdo * dodq * dqda + INDArray dEda = dEdq.mmul(V.transpose()); Pair derivs1 = a1.backprop(z, dEda); INDArray dEdz = derivs1.getFirst(); INDArray dzdb = Nd4j.ones(1, batchsz); INDArray dEdb = dzdb.mmul(dEdz); INDArray dEdW = input.transpose().mmul(dEdz); - manual_gradients.put("output_b", dEdc); manual_gradients.put("output_W", dEdV); manual_gradients.put("denselayer_b", dEdb); manual_gradients.put("denselayer_W", dEdW); - double summse = Math.pow((score_manual - score_dl4j), 2); int denominator = 1; - for (Map.Entry mesi : gradients.entrySet()) { String name = mesi.getKey(); INDArray dl4j_gradient = nullsafe(mesi.getValue()); @@ -241,9 +204,7 @@ public class ShiftVertexTest extends BaseDL4JTest { summse += se; denominator += dl4j_gradient.columns() * dl4j_gradient.rows(); } - - Assert.assertEquals(0.0, summse / denominator, this.epsilon); - + Assertions.assertEquals(0.0, summse / denominator, this.epsilon); } private static double sum_errors(INDArray a, INDArray b) { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java index be3475631..807a3861a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.conf.layers; import org.deeplearning4j.BaseDL4JTest; @@ -25,7 +24,7 @@ import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import 
org.nd4j.linalg.activations.impl.ActivationSoftmax; @@ -34,45 +33,62 @@ import org.nd4j.linalg.convolution.Convolution; import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; - import java.io.*; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Jeffrey Tang. */ -public class LayerBuilderTest extends BaseDL4JTest { +@DisplayName("Layer Builder Test") +class LayerBuilderTest extends BaseDL4JTest { + final double DELTA = 1e-15; int numIn = 10; + int numOut = 5; + double drop = 0.3; + IActivation act = new ActivationSoftmax(); + PoolingType poolType = PoolingType.MAX; - int[] kernelSize = new int[] {2, 2}; - int[] stride = new int[] {2, 2}; - int[] padding = new int[] {1, 1}; + + int[] kernelSize = new int[] { 2, 2 }; + + int[] stride = new int[] { 2, 2 }; + + int[] padding = new int[] { 1, 1 }; + int k = 1; + Convolution.Type convType = Convolution.Type.VALID; + LossFunction loss = LossFunction.MCXENT; + WeightInit weight = WeightInit.XAVIER; + double corrupt = 0.4; + double sparsity = 0.3; + double corruptionLevel = 0.5; + double dropOut = 0.1; + IUpdater updater = new AdaGrad(); + GradientNormalization gradNorm = GradientNormalization.ClipL2PerParamType; + double gradNormThreshold = 8; @Test - public void testLayer() throws Exception { - DenseLayer layer = new DenseLayer.Builder().activation(act).weightInit(weight).dropOut(dropOut) - .updater(updater).gradientNormalization(gradNorm) - .gradientNormalizationThreshold(gradNormThreshold).build(); - + @DisplayName("Test Layer") + void testLayer() throws Exception { + DenseLayer layer = new DenseLayer.Builder().activation(act).weightInit(weight).dropOut(dropOut).updater(updater).gradientNormalization(gradNorm).gradientNormalizationThreshold(gradNormThreshold).build(); checkSerialization(layer); - assertEquals(act, layer.getActivationFn()); assertEquals(weight.getWeightInitFunction(), layer.getWeightInitFn()); assertEquals(new Dropout(dropOut), layer.getIDropout()); @@ -82,34 +98,30 @@ public class LayerBuilderTest extends BaseDL4JTest { } @Test - public void testFeedForwardLayer() throws Exception { + @DisplayName("Test Feed Forward Layer") + void testFeedForwardLayer() throws Exception { DenseLayer ff = new DenseLayer.Builder().nIn(numIn).nOut(numOut).build(); - checkSerialization(ff); - assertEquals(numIn, ff.getNIn()); assertEquals(numOut, ff.getNOut()); } @Test - public void testConvolutionLayer() throws Exception { + @DisplayName("Test Convolution Layer") + void testConvolutionLayer() throws Exception { ConvolutionLayer conv = new ConvolutionLayer.Builder(kernelSize, stride, padding).build(); - checkSerialization(conv); - - // assertEquals(convType, conv.getConvolutionType()); + // assertEquals(convType, conv.getConvolutionType()); assertArrayEquals(kernelSize, conv.getKernelSize()); assertArrayEquals(stride, conv.getStride()); assertArrayEquals(padding, conv.getPadding()); } @Test - public void testSubsamplingLayer() throws Exception { - SubsamplingLayer sample = - new SubsamplingLayer.Builder(poolType, stride).kernelSize(kernelSize).padding(padding).build(); - + @DisplayName("Test Subsampling Layer") + void testSubsamplingLayer() throws Exception { + SubsamplingLayer sample = new SubsamplingLayer.Builder(poolType, stride).kernelSize(kernelSize).padding(padding).build(); 
checkSerialization(sample); - assertArrayEquals(padding, sample.getPadding()); assertArrayEquals(kernelSize, sample.getKernelSize()); assertEquals(poolType, sample.getPoolingType()); @@ -117,36 +129,33 @@ public class LayerBuilderTest extends BaseDL4JTest { } @Test - public void testOutputLayer() throws Exception { + @DisplayName("Test Output Layer") + void testOutputLayer() throws Exception { OutputLayer out = new OutputLayer.Builder(loss).build(); - checkSerialization(out); } @Test - public void testRnnOutputLayer() throws Exception { + @DisplayName("Test Rnn Output Layer") + void testRnnOutputLayer() throws Exception { RnnOutputLayer out = new RnnOutputLayer.Builder(loss).build(); - checkSerialization(out); } @Test - public void testAutoEncoder() throws Exception { + @DisplayName("Test Auto Encoder") + void testAutoEncoder() throws Exception { AutoEncoder enc = new AutoEncoder.Builder().corruptionLevel(corruptionLevel).sparsity(sparsity).build(); - checkSerialization(enc); - assertEquals(corruptionLevel, enc.getCorruptionLevel(), DELTA); assertEquals(sparsity, enc.getSparsity(), DELTA); } @Test - public void testGravesLSTM() throws Exception { - GravesLSTM glstm = new GravesLSTM.Builder().forgetGateBiasInit(1.5).activation(Activation.TANH).nIn(numIn) - .nOut(numOut).build(); - + @DisplayName("Test Graves LSTM") + void testGravesLSTM() throws Exception { + GravesLSTM glstm = new GravesLSTM.Builder().forgetGateBiasInit(1.5).activation(Activation.TANH).nIn(numIn).nOut(numOut).build(); checkSerialization(glstm); - assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0); assertEquals(glstm.nIn, numIn); assertEquals(glstm.nOut, numOut); @@ -154,12 +163,10 @@ public class LayerBuilderTest extends BaseDL4JTest { } @Test - public void testGravesBidirectionalLSTM() throws Exception { - final GravesBidirectionalLSTM glstm = new GravesBidirectionalLSTM.Builder().forgetGateBiasInit(1.5) - .activation(Activation.TANH).nIn(numIn).nOut(numOut).build(); - + @DisplayName("Test Graves Bidirectional LSTM") + void testGravesBidirectionalLSTM() throws Exception { + final GravesBidirectionalLSTM glstm = new GravesBidirectionalLSTM.Builder().forgetGateBiasInit(1.5).activation(Activation.TANH).nIn(numIn).nOut(numOut).build(); checkSerialization(glstm); - assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0); assertEquals(glstm.nIn, numIn); assertEquals(glstm.nOut, numOut); @@ -167,21 +174,19 @@ public class LayerBuilderTest extends BaseDL4JTest { } @Test - public void testEmbeddingLayer() throws Exception { + @DisplayName("Test Embedding Layer") + void testEmbeddingLayer() throws Exception { EmbeddingLayer el = new EmbeddingLayer.Builder().nIn(10).nOut(5).build(); checkSerialization(el); - assertEquals(10, el.getNIn()); assertEquals(5, el.getNOut()); } @Test - public void testBatchNormLayer() throws Exception { - BatchNormalization bN = new BatchNormalization.Builder().nIn(numIn).nOut(numOut).gamma(2).beta(1).decay(0.5) - .lockGammaBeta(true).build(); - + @DisplayName("Test Batch Norm Layer") + void testBatchNormLayer() throws Exception { + BatchNormalization bN = new BatchNormalization.Builder().nIn(numIn).nOut(numOut).gamma(2).beta(1).decay(0.5).lockGammaBeta(true).build(); checkSerialization(bN); - assertEquals(numIn, bN.nIn); assertEquals(numOut, bN.nOut); assertEquals(true, bN.isLockGammaBeta()); @@ -191,42 +196,38 @@ public class LayerBuilderTest extends BaseDL4JTest { } @Test - public void testActivationLayer() throws Exception { + @DisplayName("Test Activation Layer") + void testActivationLayer() throws 
Exception { ActivationLayer activationLayer = new ActivationLayer.Builder().activation(act).build(); - checkSerialization(activationLayer); - assertEquals(act, activationLayer.activationFn); } private void checkSerialization(Layer layer) throws Exception { NeuralNetConfiguration confExpected = new NeuralNetConfiguration.Builder().layer(layer).build(); NeuralNetConfiguration confActual; - // check Java serialization byte[] data; - try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutput out = new ObjectOutputStream(bos)) { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = new ObjectOutputStream(bos)) { out.writeObject(confExpected); data = bos.toByteArray(); } - try (ByteArrayInputStream bis = new ByteArrayInputStream(data); ObjectInput in = new ObjectInputStream(bis)) { + try (ByteArrayInputStream bis = new ByteArrayInputStream(data); + ObjectInput in = new ObjectInputStream(bis)) { confActual = (NeuralNetConfiguration) in.readObject(); } - assertEquals("unequal Java serialization", confExpected.getLayer(), confActual.getLayer()); - + assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal Java serialization"); // check JSON String json = confExpected.toJson(); confActual = NeuralNetConfiguration.fromJson(json); - assertEquals("unequal JSON serialization", confExpected.getLayer(), confActual.getLayer()); - + assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal JSON serialization"); // check YAML String yaml = confExpected.toYaml(); confActual = NeuralNetConfiguration.fromYaml(yaml); - assertEquals("unequal YAML serialization", confExpected.getLayer(), confActual.getLayer()); - + assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal YAML serialization"); // check the layer's use of callSuper on equals method confActual.getLayer().setIDropout(new Dropout(new java.util.Random().nextDouble())); - assertNotEquals("broken equals method (missing callSuper?)", confExpected.getLayer(), confActual.getLayer()); + assertNotEquals(confExpected.getLayer(), confActual.getLayer(), "broken equals method (missing callSuper?)"); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java index d9316e37a..71d867701 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.conf.layers; import org.deeplearning4j.BaseDL4JTest; @@ -30,7 +29,7 @@ import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitDistribution; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.AdaDelta; import org.nd4j.linalg.learning.config.Adam; @@ -38,89 +37,170 @@ import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.schedule.MapSchedule; import org.nd4j.linalg.schedule.ScheduleType; - import java.util.HashMap; 
import java.util.Map; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +/* + @Test + public void testLearningRatePolicyExponential() { + double lr = 2; + double lrDecayRate = 5; + int iterations = 1; + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr) + .updater(Updater.SGD) + .learningRateDecayPolicy(LearningRatePolicy.Exponential).lrPolicyDecayRate(lrDecayRate).list() + .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); -public class LayerConfigTest extends BaseDL4JTest { + assertEquals(LearningRatePolicy.Exponential, conf.getConf(0).getLearningRatePolicy()); + assertEquals(LearningRatePolicy.Exponential, conf.getConf(1).getLearningRatePolicy()); + assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0); + assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0); + } @Test - public void testLayerName() { + public void testLearningRatePolicyInverse() { + double lr = 2; + double lrDecayRate = 5; + double power = 3; + int iterations = 1; + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) + .learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(lrDecayRate) + .lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + assertEquals(LearningRatePolicy.Inverse, conf.getConf(0).getLearningRatePolicy()); + assertEquals(LearningRatePolicy.Inverse, conf.getConf(1).getLearningRatePolicy()); + assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0); + assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0); + assertEquals(power, conf.getConf(0).getLrPolicyPower(), 0.0); + assertEquals(power, conf.getConf(1).getLrPolicyPower(), 0.0); + } + + + @Test + public void testLearningRatePolicySteps() { + double lr = 2; + double lrDecayRate = 5; + double steps = 4; + int iterations = 1; + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) + .learningRateDecayPolicy(LearningRatePolicy.Step).lrPolicyDecayRate(lrDecayRate) + .lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + assertEquals(LearningRatePolicy.Step, conf.getConf(0).getLearningRatePolicy()); + assertEquals(LearningRatePolicy.Step, conf.getConf(1).getLearningRatePolicy()); + assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0); + assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0); + assertEquals(steps, conf.getConf(0).getLrPolicySteps(), 0.0); + assertEquals(steps, conf.getConf(1).getLrPolicySteps(), 0.0); + } + + @Test + public void testLearningRatePolicyPoly() { + double lr = 2; + double lrDecayRate = 5; + double power = 3; + int iterations = 1; + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) 
+ .learningRateDecayPolicy(LearningRatePolicy.Poly).lrPolicyDecayRate(lrDecayRate) + .lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + assertEquals(LearningRatePolicy.Poly, conf.getConf(0).getLearningRatePolicy()); + assertEquals(LearningRatePolicy.Poly, conf.getConf(1).getLearningRatePolicy()); + assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0); + assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0); + assertEquals(power, conf.getConf(0).getLrPolicyPower(), 0.0); + assertEquals(power, conf.getConf(1).getLrPolicyPower(), 0.0); + } + + @Test + public void testLearningRatePolicySigmoid() { + double lr = 2; + double lrDecayRate = 5; + double steps = 4; + int iterations = 1; + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) + .learningRateDecayPolicy(LearningRatePolicy.Sigmoid).lrPolicyDecayRate(lrDecayRate) + .lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + assertEquals(LearningRatePolicy.Sigmoid, conf.getConf(0).getLearningRatePolicy()); + assertEquals(LearningRatePolicy.Sigmoid, conf.getConf(1).getLearningRatePolicy()); + assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0); + assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0); + assertEquals(steps, conf.getConf(0).getLrPolicySteps(), 0.0); + assertEquals(steps, conf.getConf(1).getLrPolicySteps(), 0.0); + } + +*/ +@DisplayName("Layer Config Test") +class LayerConfigTest extends BaseDL4JTest { + + @Test + @DisplayName("Test Layer Name") + void testLayerName() { String name1 = "genisys"; String name2 = "bill"; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).name(name1).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).name(name2).build()).build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).name(name1).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).name(name2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(name1, conf.getConf(0).getLayer().getLayerName()); assertEquals(name2, conf.getConf(1).getLayer().getLayerName()); - } @Test - public void testActivationLayerwiseOverride() { - //Without layerwise override: - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + @DisplayName("Test Activation Layerwise Override") + void testActivationLayerwiseOverride() { + // Without layerwise override: + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString()); - assertEquals("relu", ((BaseLayer) 
conf.getConf(1).getLayer()).getActivationFn().toString()); - - //With - conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).activation(Activation.TANH).build()).build(); - + assertEquals(((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString(), "relu"); + assertEquals(((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString(), "relu"); + // With + conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).activation(Activation.TANH).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - - assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString()); - assertEquals("tanh", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString()); + assertEquals(((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString(), "relu"); + assertEquals(((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString(), "tanh"); } - @Test - public void testWeightBiasInitLayerwiseOverride() { - //Without layerwise override: + @DisplayName("Test Weight Bias Init Layerwise Override") + void testWeightBiasInitLayerwiseOverride() { + // Without layerwise override: final Distribution defaultDistribution = new NormalDistribution(0, 1.0); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dist(defaultDistribution).biasInit(1).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dist(defaultDistribution).biasInit(1).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn()); - assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0); assertEquals(1, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0); - - //With: + // With: final Distribution overriddenDistribution = new UniformDistribution(0, 1); - conf = new NeuralNetConfiguration.Builder() - .dist(defaultDistribution).biasInit(1).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, - new DenseLayer.Builder().nIn(2).nOut(2) - .dist(overriddenDistribution).biasInit(0).build()) - .build(); - + conf = new NeuralNetConfiguration.Builder().dist(defaultDistribution).biasInit(1).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).dist(overriddenDistribution).biasInit(0).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn()); assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn()); - assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0); assertEquals(0, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0); } @@ -176,101 +256,65 @@ 
public class LayerConfigTest extends BaseDL4JTest { assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0); assertEquals(0.8, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0); }*/ - - - @Test - public void testDropoutLayerwiseOverride() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + @DisplayName("Test Dropout Layerwise Override") + void testDropoutLayerwiseOverride() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(new Dropout(1.0), conf.getConf(0).getLayer().getIDropout()); assertEquals(new Dropout(1.0), conf.getConf(1).getLayer().getIDropout()); - - conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).dropOut(2.0).build()).build(); - + conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).dropOut(2.0).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - assertEquals(new Dropout(1.0), conf.getConf(0).getLayer().getIDropout()); assertEquals(new Dropout(2.0), conf.getConf(1).getLayer().getIDropout()); } @Test - public void testMomentumLayerwiseOverride() { + @DisplayName("Test Momentum Layerwise Override") + void testMomentumLayerwiseOverride() { Map testMomentumAfter = new HashMap<>(); testMomentumAfter.put(0, 0.1); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))) - .list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); - assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); - + assertEquals(0.1, ((Nesterovs) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0); + assertEquals(0.1, ((Nesterovs) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0); Map testMomentumAfter2 = new HashMap<>(); testMomentumAfter2.put(0, 0.2); - - conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter) )) - .list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder() - .nIn(2).nOut(2).updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter2))).build()) - .build(); - + conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new 
MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter2))).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); - assertEquals(0.2, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); + assertEquals(0.1, ((Nesterovs) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0); + assertEquals(0.2, ((Nesterovs) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0); } @Test - public void testUpdaterRhoRmsDecayLayerwiseOverride() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new AdaDelta(0.5, 0.9)).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.01,0.9)).build()).build(); + @DisplayName("Test Updater Rho Rms Decay Layerwise Override") + void testUpdaterRhoRmsDecayLayerwiseOverride() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new AdaDelta(0.5, 0.9)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.01, 0.9)).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta); assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta); - assertEquals(0.5, ((AdaDelta)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0); - assertEquals(0.01, ((AdaDelta)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); - - conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.5,AdaDelta.DEFAULT_ADADELTA_EPSILON)).build()) - .build(); - + assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0); + assertEquals(0.01, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); + conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.5, AdaDelta.DEFAULT_ADADELTA_EPSILON)).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp); assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta); assertEquals(1.0, ((RmsProp) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0); assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); } - @Test - public void testUpdaterAdamParamsLayerwiseOverride() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - 
.updater(new Adam(1.0, 0.5, 0.5, 1e-8)) - .list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Adam(1.0, 0.6, 0.7, 1e-8)).build()) - .build(); + @DisplayName("Test Updater Adam Params Layerwise Override") + void testUpdaterAdamParamsLayerwiseOverride() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(1.0, 0.5, 0.5, 1e-8)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Adam(1.0, 0.6, 0.7, 1e-8)).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0); assertEquals(0.6, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0); assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0); @@ -278,45 +322,25 @@ public class LayerConfigTest extends BaseDL4JTest { } @Test - public void testGradientNormalizationLayerwiseOverride() { - - //Learning rate without layerwise override: - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + @DisplayName("Test Gradient Normalization Layerwise Override") + void testGradientNormalizationLayerwiseOverride() { + // Learning rate without layerwise override: + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, - ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization()); - assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, - ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization()); + assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization()); + assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization()); assertEquals(10, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalizationThreshold(), 0.0); assertEquals(10, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalizationThreshold(), 0.0); - - //With: - conf = new NeuralNetConfiguration.Builder() - .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2) - .gradientNormalization(GradientNormalization.None) - .gradientNormalizationThreshold(2.5).build()) - .build(); - + // With: + conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new 
DenseLayer.Builder().nIn(2).nOut(2).gradientNormalization(GradientNormalization.None).gradientNormalizationThreshold(2.5).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - - assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, - ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization()); + assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization()); assertEquals(GradientNormalization.None, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization()); assertEquals(10, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalizationThreshold(), 0.0); assertEquals(2.5, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalizationThreshold(), 0.0); } - - /* @Test public void testLearningRatePolicyExponential() { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java index 5ff503c3a..dc0837911 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.conf.layers; import org.deeplearning4j.BaseDL4JTest; @@ -35,8 +34,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Nesterovs; @@ -44,107 +43,89 @@ import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.schedule.MapSchedule; import org.nd4j.linalg.schedule.ScheduleType; - import java.util.HashMap; import java.util.Map; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; - -public class LayerConfigValidationTest extends BaseDL4JTest { - +@DisplayName("Layer Config Validation Test") +class LayerConfigValidationTest extends BaseDL4JTest { @Test - public void testDropConnect() { + @DisplayName("Test Drop Connect") + void testDropConnect() { // Warning thrown only since some layers may not have l1 or l2 - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).weightNoise(new DropConnect(0.5)) - .list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).weightNoise(new DropConnect(0.5)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new 
DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); } - @Test - public void testL1L2NotSet() { + @DisplayName("Test L 1 L 2 Not Set") + void testL1L2NotSet() { // Warning thrown only since some layers may not have l1 or l2 - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)) - .list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - } - - @Test(expected = IllegalStateException.class) - @Ignore //Old assumption: throw exception on l1 but no regularization. Current design: warn, not exception - public void testRegNotSetL1Global() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).l1(0.5).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - } - - @Test(expected = IllegalStateException.class) - @Ignore //Old assumption: throw exception on l1 but no regularization. Current design: warn, not exception - public void testRegNotSetL2Local() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); } @Test - public void testWeightInitDistNotSet() { + @Disabled + @DisplayName("Test Reg Not Set L 1 Global") + void testRegNotSetL1Global() { + assertThrows(IllegalStateException.class, () -> { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).l1(0.5).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + }); + } + + @Test + @Disabled + @DisplayName("Test Reg Not Set L 2 Local") + void testRegNotSetL2Local() { + assertThrows(IllegalStateException.class, () -> { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + }); + } + + @Test + @DisplayName("Test Weight Init Dist Not Set") + void testWeightInitDistNotSet() { // Warning thrown only since global dist can be set with a different weight init locally - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).dist(new GaussianDistribution(1e-3, 2)) - .list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).dist(new GaussianDistribution(1e-3, 2)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new 
MultiLayerNetwork(conf); net.init(); } @Test - public void testNesterovsNotSetGlobal() { + @DisplayName("Test Nesterovs Not Set Global") + void testNesterovsNotSetGlobal() { // Warnings only thrown Map testMomentumAfter = new HashMap<>(); testMomentumAfter.put(0, 0.1); - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); } @Test - public void testCompGraphNullLayer() { - ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.01)) - .seed(42).miniBatch(false).l1(0.2).l2(0.2) - /* Graph Builder */ - .updater(Updater.RMSPROP).graphBuilder().addInputs("in") - .addLayer("L" + 1, - new GravesLSTM.Builder().nIn(20).updater(Updater.RMSPROP).nOut(10) - .weightInit(WeightInit.XAVIER) - .dropOut(0.4).l1(0.3).activation(Activation.SIGMOID).build(), - "in") - .addLayer("output", - new RnnOutputLayer.Builder().nIn(20).nOut(10).activation(Activation.SOFTMAX) - .weightInit(WeightInit.RELU_UNIFORM).build(), - "L" + 1) - .setOutputs("output"); + @DisplayName("Test Comp Graph Null Layer") + void testCompGraphNullLayer() { + ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.01)).seed(42).miniBatch(false).l1(0.2).l2(0.2).updater(Updater.RMSPROP).graphBuilder().addInputs("in").addLayer("L" + 1, new GravesLSTM.Builder().nIn(20).updater(Updater.RMSPROP).nOut(10).weightInit(WeightInit.XAVIER).dropOut(0.4).l1(0.3).activation(Activation.SIGMOID).build(), "in").addLayer("output", new RnnOutputLayer.Builder().nIn(20).nOut(10).activation(Activation.SOFTMAX).weightInit(WeightInit.RELU_UNIFORM).build(), "L" + 1).setOutputs("output"); ComputationGraphConfiguration conf = gb.build(); ComputationGraph cg = new ComputationGraph(conf); cg.init(); } - @Test - public void testPredefinedConfigValues() { + @DisplayName("Test Predefined Config Values") + void testPredefinedConfigValues() { double expectedMomentum = 0.9; double expectedAdamMeanDecay = 0.9; double expectedAdamVarDecay = 0.999; @@ -152,59 +133,38 @@ public class LayerConfigValidationTest extends BaseDL4JTest { Distribution expectedDist = new NormalDistribution(0, 1); double expectedL1 = 0.0; double expectedL2 = 0.0; - // Nesterovs Updater - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(0.9)) - .list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(0.3, 0.4)).build()).build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(0.9)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(0.3, 0.4)).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - 
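// Illustrative sketch of the JUnit 5 (Jupiter) idioms this migration applies throughout
// these test classes: @BeforeEach for @Before, @Disabled for @Ignore, @DisplayName on
// classes and methods, Assertions.* for org.junit.Assert.*, and assertThrows in place of
// @Test(expected = ...). FooTest and its members are hypothetical names used only for
// illustration; they do not appear in this patch.
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

@DisplayName("Foo Test")
class FooTest {

    private int state;

    @BeforeEach                     // replaces JUnit 4 @Before
    void setUp() {
        state = 1;
    }

    @Test
    @DisplayName("Adds One")
    void addsOne() {
        assertEquals(2, state + 1); // Assertions.assertEquals replaces org.junit.Assert.assertEquals
    }

    @Test
    @Disabled                       // replaces JUnit 4 @Ignore
    @DisplayName("Rejects Invalid State")
    void rejectsInvalidState() {
        // replaces @Test(expected = IllegalStateException.class): the expected exception
        // is asserted explicitly around the code that should throw it
        assertThrows(IllegalStateException.class, () -> {
            throw new IllegalStateException("invalid state");
        });
    }
}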
BaseLayer layerConf = (BaseLayer) net.getLayer(0).conf().getLayer(); assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3); assertNull(TestUtils.getL1Reg(layerConf.getRegularization())); assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3); - BaseLayer layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer(); assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3); - // Adam Updater - conf = new NeuralNetConfiguration.Builder().updater(new Adam(0.3)) - .weightInit(new WeightInitDistribution(expectedDist)).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).l1(0.3).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + conf = new NeuralNetConfiguration.Builder().updater(new Adam(0.3)).weightInit(new WeightInitDistribution(expectedDist)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).l1(0.3).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - layerConf = (BaseLayer) net.getLayer(0).conf().getLayer(); assertEquals(0.3, TestUtils.getL1(layerConf), 1e-3); assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3); - layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer(); assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3); assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3); assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn()); assertNull(TestUtils.getL1Reg(layerConf1.getRegularization())); assertNull(TestUtils.getL2Reg(layerConf1.getRegularization())); - - //RMSProp Updater - conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(0.3)).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(0.3, 0.4, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()).build(); + // RMSProp Updater + conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(0.3)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(0.3, 0.4, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - layerConf = (BaseLayer) net.getLayer(0).conf().getLayer(); assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3); assertNull(TestUtils.getL1Reg(layerConf.getRegularization())); assertNull(TestUtils.getL2Reg(layerConf.getRegularization())); - layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer(); assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3); - - } - } - - diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java index db53d7cf0..a79583eaa 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.conf.preprocessor; import org.deeplearning4j.BaseDL4JTest; @@ -28,7 +27,7 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import 
org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -36,29 +35,33 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** - **/ + */ +@DisplayName("Cnn Processor Test") +class CNNProcessorTest extends BaseDL4JTest { -public class CNNProcessorTest extends BaseDL4JTest { private static int rows = 28; + private static int cols = 28; + private static INDArray in2D = Nd4j.create(DataType.FLOAT, 1, 784); + private static INDArray in3D = Nd4j.create(DataType.FLOAT, 20, 784, 7); + private static INDArray in4D = Nd4j.create(DataType.FLOAT, 20, 1, 28, 28); - @Test - public void testFeedForwardToCnnPreProcessor() { + @DisplayName("Test Feed Forward To Cnn Pre Processor") + void testFeedForwardToCnnPreProcessor() { FeedForwardToCnnPreProcessor convProcessor = new FeedForwardToCnnPreProcessor(rows, cols, 1); - INDArray check2to4 = convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces()); int val2to4 = check2to4.shape().length; assertTrue(val2to4 == 4); assertEquals(Nd4j.create(DataType.FLOAT, 1, 1, 28, 28), check2to4); - INDArray check4to4 = convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces()); int val4to4 = check4to4.shape().length; assertTrue(val4to4 == 4); @@ -66,42 +69,41 @@ public class CNNProcessorTest extends BaseDL4JTest { } @Test - public void testFeedForwardToCnnPreProcessor2() { - int[] nRows = {1, 5, 20}; - int[] nCols = {1, 5, 20}; - int[] nDepth = {1, 3}; - int[] nMiniBatchSize = {1, 5}; + @DisplayName("Test Feed Forward To Cnn Pre Processor 2") + void testFeedForwardToCnnPreProcessor2() { + int[] nRows = { 1, 5, 20 }; + int[] nCols = { 1, 5, 20 }; + int[] nDepth = { 1, 3 }; + int[] nMiniBatchSize = { 1, 5 }; for (int rows : nRows) { for (int cols : nCols) { for (int d : nDepth) { FeedForwardToCnnPreProcessor convProcessor = new FeedForwardToCnnPreProcessor(rows, cols, d); - for (int miniBatch : nMiniBatchSize) { - long[] ffShape = new long[] {miniBatch, rows * cols * d}; + long[] ffShape = new long[] { miniBatch, rows * cols * d }; INDArray rand = Nd4j.rand(ffShape); INDArray ffInput_c = Nd4j.create(DataType.FLOAT, ffShape, 'c'); INDArray ffInput_f = Nd4j.create(DataType.FLOAT, ffShape, 'f'); ffInput_c.assign(rand); ffInput_f.assign(rand); assertEquals(ffInput_c, ffInput_f); - - //Test forward pass: + // Test forward pass: INDArray convAct_c = convProcessor.preProcess(ffInput_c, -1, LayerWorkspaceMgr.noWorkspaces()); INDArray convAct_f = convProcessor.preProcess(ffInput_f, -1, LayerWorkspaceMgr.noWorkspaces()); - long[] convShape = {miniBatch, d, rows, cols}; + long[] convShape = { miniBatch, d, rows, cols }; assertArrayEquals(convShape, convAct_c.shape()); assertArrayEquals(convShape, convAct_f.shape()); assertEquals(convAct_c, convAct_f); - - //Check values: - //CNN reshaping (for each example) takes a 1d vector and converts it to 3d + // Check values: + // CNN reshaping (for each example) takes a 1d vector and converts it to 3d // (4d 
total, for minibatch data) - //1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc + // 1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc for (int ex = 0; ex < miniBatch; ex++) { for (int r = 0; r < rows; r++) { for (int c = 0; c < cols; c++) { for (int depth = 0; depth < d; depth++) { - int origPosition = depth * (rows * cols) + r * cols + c; //pos in vector + // pos in vector + int origPosition = depth * (rows * cols) + r * cols + c; double vecValue = ffInput_c.getDouble(ex, origPosition); double convValue = convAct_c.getDouble(ex, depth, r, c); assertEquals(vecValue, convValue, 0.0); @@ -109,9 +111,8 @@ public class CNNProcessorTest extends BaseDL4JTest { } } } - - //Test backward pass: - //Idea is that backward pass should do opposite to forward pass + // Test backward pass: + // Idea is that backward pass should do opposite to forward pass INDArray epsilon4_c = Nd4j.create(DataType.FLOAT, convShape, 'c'); INDArray epsilon4_f = Nd4j.create(DataType.FLOAT, convShape, 'f'); epsilon4_c.assign(convAct_c); @@ -126,12 +127,11 @@ public class CNNProcessorTest extends BaseDL4JTest { } } - @Test - public void testFeedForwardToCnnPreProcessorBackprop() { + @DisplayName("Test Feed Forward To Cnn Pre Processor Backprop") + void testFeedForwardToCnnPreProcessorBackprop() { FeedForwardToCnnPreProcessor convProcessor = new FeedForwardToCnnPreProcessor(rows, cols, 1); convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces()); - INDArray check2to2 = convProcessor.backprop(in2D, -1, LayerWorkspaceMgr.noWorkspaces()); int val2to2 = check2to2.shape().length; assertTrue(val2to2 == 2); @@ -139,14 +139,13 @@ public class CNNProcessorTest extends BaseDL4JTest { } @Test - public void testCnnToFeedForwardProcessor() { + @DisplayName("Test Cnn To Feed Forward Processor") + void testCnnToFeedForwardProcessor() { CnnToFeedForwardPreProcessor convProcessor = new CnnToFeedForwardPreProcessor(rows, cols, 1); - INDArray check2to4 = convProcessor.backprop(in2D, -1, LayerWorkspaceMgr.noWorkspaces()); int val2to4 = check2to4.shape().length; assertTrue(val2to4 == 4); assertEquals(Nd4j.create(DataType.FLOAT, 1, 1, 28, 28), check2to4); - INDArray check4to4 = convProcessor.backprop(in4D, -1, LayerWorkspaceMgr.noWorkspaces()); int val4to4 = check4to4.shape().length; assertTrue(val4to4 == 4); @@ -154,15 +153,14 @@ public class CNNProcessorTest extends BaseDL4JTest { } @Test - public void testCnnToFeedForwardPreProcessorBackprop() { + @DisplayName("Test Cnn To Feed Forward Pre Processor Backprop") + void testCnnToFeedForwardPreProcessorBackprop() { CnnToFeedForwardPreProcessor convProcessor = new CnnToFeedForwardPreProcessor(rows, cols, 1); convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces()); - INDArray check2to2 = convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces()); int val2to2 = check2to2.shape().length; assertTrue(val2to2 == 2); assertEquals(Nd4j.create(DataType.FLOAT, 1, 784), check2to2); - INDArray check4to2 = convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces()); int val4to2 = check4to2.shape().length; assertTrue(val4to2 == 2); @@ -170,42 +168,41 @@ public class CNNProcessorTest extends BaseDL4JTest { } @Test - public void testCnnToFeedForwardPreProcessor2() { - int[] nRows = {1, 5, 20}; - int[] nCols = {1, 5, 20}; - int[] nDepth = {1, 3}; - int[] nMiniBatchSize = {1, 5}; + @DisplayName("Test Cnn To Feed Forward Pre Processor 2") + void testCnnToFeedForwardPreProcessor2() { + int[] 
nRows = { 1, 5, 20 }; + int[] nCols = { 1, 5, 20 }; + int[] nDepth = { 1, 3 }; + int[] nMiniBatchSize = { 1, 5 }; for (int rows : nRows) { for (int cols : nCols) { for (int d : nDepth) { CnnToFeedForwardPreProcessor convProcessor = new CnnToFeedForwardPreProcessor(rows, cols, d); - for (int miniBatch : nMiniBatchSize) { - long[] convActShape = new long[] {miniBatch, d, rows, cols}; + long[] convActShape = new long[] { miniBatch, d, rows, cols }; INDArray rand = Nd4j.rand(convActShape); INDArray convInput_c = Nd4j.create(DataType.FLOAT, convActShape, 'c'); INDArray convInput_f = Nd4j.create(DataType.FLOAT, convActShape, 'f'); convInput_c.assign(rand); convInput_f.assign(rand); assertEquals(convInput_c, convInput_f); - - //Test forward pass: + // Test forward pass: INDArray ffAct_c = convProcessor.preProcess(convInput_c, -1, LayerWorkspaceMgr.noWorkspaces()); INDArray ffAct_f = convProcessor.preProcess(convInput_f, -1, LayerWorkspaceMgr.noWorkspaces()); - long[] ffActShape = {miniBatch, d * rows * cols}; + long[] ffActShape = { miniBatch, d * rows * cols }; assertArrayEquals(ffActShape, ffAct_c.shape()); assertArrayEquals(ffActShape, ffAct_f.shape()); assertEquals(ffAct_c, ffAct_f); - - //Check values: - //CNN reshaping (for each example) takes a 1d vector and converts it to 3d + // Check values: + // CNN reshaping (for each example) takes a 1d vector and converts it to 3d // (4d total, for minibatch data) - //1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc + // 1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc for (int ex = 0; ex < miniBatch; ex++) { for (int r = 0; r < rows; r++) { for (int c = 0; c < cols; c++) { for (int depth = 0; depth < d; depth++) { - int vectorPosition = depth * (rows * cols) + r * cols + c; //pos in vector after reshape + // pos in vector after reshape + int vectorPosition = depth * (rows * cols) + r * cols + c; double vecValue = ffAct_c.getDouble(ex, vectorPosition); double convValue = convInput_c.getDouble(ex, depth, r, c); assertEquals(convValue, vecValue, 0.0); @@ -213,9 +210,8 @@ public class CNNProcessorTest extends BaseDL4JTest { } } } - - //Test backward pass: - //Idea is that backward pass should do opposite to forward pass + // Test backward pass: + // Idea is that backward pass should do opposite to forward pass INDArray epsilon2_c = Nd4j.create(DataType.FLOAT, ffActShape, 'c'); INDArray epsilon2_f = Nd4j.create(DataType.FLOAT, ffActShape, 'f'); epsilon2_c.assign(ffAct_c); @@ -231,79 +227,32 @@ public class CNNProcessorTest extends BaseDL4JTest { } @Test - public void testInvalidInputShape(){ - - NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder() - .seed(123) - .miniBatch(true) - .cacheMode(CacheMode.DEVICE) - .updater(new Nesterovs(0.9)) - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); - - int[] kernelArray = new int[]{3,3}; - int[] strideArray = new int[]{1,1}; - int[] zeroPaddingArray = new int[]{0,0}; + @DisplayName("Test Invalid Input Shape") + void testInvalidInputShape() { + NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).miniBatch(true).cacheMode(CacheMode.DEVICE).updater(new Nesterovs(0.9)).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); + int[] kernelArray = new int[] { 3, 3 }; + int[] strideArray = new int[] { 1, 1 }; + 
int[] zeroPaddingArray = new int[] { 0, 0 }; int processWidth = 4; - - NeuralNetConfiguration.ListBuilder listBuilder = builder.list(); // Building the DL4J network - - listBuilder = listBuilder.layer(0, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) - .name("cnn1") - .convolutionMode(ConvolutionMode.Strict) - .nIn(2) // 2 input channels - .nOut(processWidth) - .weightInit(WeightInit.XAVIER_UNIFORM) - .activation(Activation.RELU) - .biasInit(1e-2).build()); - - listBuilder = listBuilder.layer(1, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) - .name("cnn2") - .convolutionMode(ConvolutionMode.Strict) - .nOut(processWidth) - .weightInit(WeightInit.XAVIER_UNIFORM) - .activation(Activation.RELU) - .biasInit(1e-2) - .build()); - - listBuilder = listBuilder.layer(2, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) - .name("cnn3") - .convolutionMode(ConvolutionMode.Strict) - .nOut(processWidth) - .weightInit(WeightInit.XAVIER_UNIFORM) - .activation(Activation.RELU).build()); - - listBuilder = listBuilder.layer(3, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) - .name("cnn4") - .convolutionMode(ConvolutionMode.Strict) - .nOut(processWidth) - .weightInit(WeightInit.XAVIER_UNIFORM) - .activation(Activation.RELU).build()); - - listBuilder = listBuilder - .layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) - .name("output") - .nOut(1) - .activation(Activation.TANH) - .build()); - - MultiLayerConfiguration conf = listBuilder - - - .setInputType(InputType.convolutional(20, 10, 2)) - .build(); - + // Building the DL4J network + NeuralNetConfiguration.ListBuilder listBuilder = builder.list(); + listBuilder = listBuilder.layer(0, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn1").convolutionMode(ConvolutionMode.Strict).nIn(// 2 input channels + 2).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).biasInit(1e-2).build()); + listBuilder = listBuilder.layer(1, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn2").convolutionMode(ConvolutionMode.Strict).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).biasInit(1e-2).build()); + listBuilder = listBuilder.layer(2, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn3").convolutionMode(ConvolutionMode.Strict).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).build()); + listBuilder = listBuilder.layer(3, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn4").convolutionMode(ConvolutionMode.Strict).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).build()); + listBuilder = listBuilder.layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).name("output").nOut(1).activation(Activation.TANH).build()); + MultiLayerConfiguration conf = listBuilder.setInputType(InputType.convolutional(20, 10, 2)).build(); // For some reason, this model works MultiLayerNetwork niceModel = new MultiLayerNetwork(conf); niceModel.init(); - - niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 20, 10)); //Valid - + // Valid + niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 20, 10)); try { niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 10, 20)); fail("Expected exception"); - } catch (IllegalStateException e){ - //OK + } catch (IllegalStateException e) { + // OK } } } diff --git 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java index dcd4a2e50..946af34f4 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.conf.preprocessor; import org.deeplearning4j.BaseDL4JTest; @@ -27,44 +26,33 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.preprocessor.custom.MyCustomPreprocessor; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass; import org.nd4j.shade.jackson.databind.jsontype.NamedType; - import java.util.Collection; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class CustomPreprocessorTest extends BaseDL4JTest { +@DisplayName("Custom Preprocessor Test") +class CustomPreprocessorTest extends BaseDL4JTest { @Test - public void testCustomPreprocessor() { - //Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works... - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10) - .activation(Activation.SOFTMAX).nOut(10).build()) - .inputPreProcessor(0, new MyCustomPreprocessor()) - .build(); - + @DisplayName("Test Custom Preprocessor") + void testCustomPreprocessor() { + // Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works... 
+ MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10).activation(Activation.SOFTMAX).nOut(10).build()).inputPreProcessor(0, new MyCustomPreprocessor()).build(); String json = conf.toJson(); String yaml = conf.toYaml(); - -// System.out.println(json); - + // System.out.println(json); MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); assertEquals(conf, confFromJson); - MultiLayerConfiguration confFromYaml = MultiLayerConfiguration.fromYaml(yaml); assertEquals(conf, confFromYaml); - assertTrue(confFromJson.getInputPreProcess(0) instanceof MyCustomPreprocessor); - } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java index c1e22efed..bf69c638b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import org.deeplearning4j.BaseDL4JTest; @@ -35,7 +34,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationELU; import org.nd4j.linalg.activations.impl.ActivationRationalTanh; @@ -46,31 +45,27 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.List; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** */ - -public class ActivationLayerTest extends BaseDL4JTest { +@DisplayName("Activation Layer Test") +class ActivationLayerTest extends BaseDL4JTest { @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testInputTypes() { - org.deeplearning4j.nn.conf.layers.ActivationLayer l = - new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU) - .build(); - - + @DisplayName("Test Input Types") + void testInputTypes() { + org.deeplearning4j.nn.conf.layers.ActivationLayer l = new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU).build(); InputType in1 = InputType.feedForward(20); InputType in2 = InputType.convolutional(28, 28, 1); - assertEquals(in1, l.getOutputType(0, in1)); assertEquals(in2, l.getOutputType(0, in2)); assertNull(l.getPreProcessorForInputType(in1)); @@ -78,252 +73,132 @@ public class ActivationLayerTest extends BaseDL4JTest { } @Test - public void testDenseActivationLayer() throws Exception { + @DisplayName("Test Dense Activation Layer") + void testDenseActivationLayer() throws Exception { DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSet 
next = iter.next(); - // Run without separate activation layer - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).nIn(10).nOut(10).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); network.fit(next); - - // Run with separate activation layer - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.IDENTITY) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder() - .activation(Activation.RELU).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10) - .build()) - .build(); - + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); MultiLayerNetwork network2 = new MultiLayerNetwork(conf2); network2.init(); network2.fit(next); - // check parameters assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W")); assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W")); assertEquals(network.getLayer(0).getParam("b"), network2.getLayer(0).getParam("b")); assertEquals(network.getLayer(1).getParam("b"), network2.getLayer(2).getParam("b")); - // check activations network.init(); network.setInput(next.getFeatures()); List activations = network.feedForward(true); - network2.init(); network2.setInput(next.getFeatures()); List activations2 = network2.feedForward(true); - assertEquals(activations.get(1).reshape(activations2.get(2).shape()), activations2.get(2)); assertEquals(activations.get(2), activations2.get(3)); - - } @Test - public void testAutoEncoderActivationLayer() throws Exception { - + @DisplayName("Test Auto Encoder Activation Layer") + void testAutoEncoderActivationLayer() throws Exception { int minibatch = 3; int nIn = 5; int layerSize = 5; int nOut = 3; - - INDArray next = Nd4j.rand(new int[] {minibatch, nIn}); + INDArray next = Nd4j.rand(new int[] { minibatch, nIn }); 
INDArray labels = Nd4j.zeros(minibatch, nOut); for (int i = 0; i < minibatch; i++) { labels.putScalar(i, i % nOut, 1.0); } - // Run without separate activation layer Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0) - .activation(Activation.SIGMOID).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY) - .activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut) - .build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0).activation(Activation.SIGMOID).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - network.fit(next, labels); //Labels are necessary for this test: layer activation function affect pretraining results, otherwise - - + // Labels are necessary for this test: layer activation function affect pretraining results, otherwise + network.fit(next, labels); // Run with separate activation layer Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0) - .activation(Activation.IDENTITY).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder() - .activation(Activation.SIGMOID).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY) - .activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut) - .build()) - .build(); - + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0).activation(Activation.IDENTITY).build()).layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.SIGMOID).build()).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build()).build(); MultiLayerNetwork network2 = new MultiLayerNetwork(conf2); network2.init(); network2.fit(next, labels); - // check parameters assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W")); assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W")); assertEquals(network.getLayer(0).getParam("b"), network2.getLayer(0).getParam("b")); assertEquals(network.getLayer(1).getParam("b"), network2.getLayer(2).getParam("b")); - // check activations network.init(); network.setInput(next); List activations = network.feedForward(true); - network2.init(); network2.setInput(next); List activations2 = network2.feedForward(true); - assertEquals(activations.get(1).reshape(activations2.get(2).shape()), 
activations2.get(2)); assertEquals(activations.get(2), activations2.get(3)); - - } @Test - public void testCNNActivationLayer() throws Exception { + @DisplayName("Test CNN Activation Layer") + void testCNNActivationLayer() throws Exception { DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSet next = iter.next(); - // Run without separate activation layer - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20) - .activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); network.fit(next); - - // Run with separate activation layer - MultiLayerConfiguration conf2 = - new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .seed(123).list() - .layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20) - .activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER) - .build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder() - .activation(Activation.RELU).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .nOut(10).build()) - - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network2 = new MultiLayerNetwork(conf2); network2.init(); network2.fit(next); - // check parameters assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W")); assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W")); assertEquals(network.getLayer(0).getParam("b"), network2.getLayer(0).getParam("b")); - // check activations network.init(); network.setInput(next.getFeatures()); List activations = network.feedForward(true); - network2.init(); network2.setInput(next.getFeatures()); List activations2 = network2.feedForward(true); - assertEquals(activations.get(1).reshape(activations2.get(2).shape()), 
activations2.get(2)); assertEquals(activations.get(2), activations2.get(3)); } - @Test - public void testActivationInheritance() { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .weightInit(WeightInit.XAVIER) - .activation(Activation.RATIONALTANH) - .list() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(new ActivationLayer()) - .layer(new ActivationLayer.Builder().build()) - .layer(new ActivationLayer.Builder().activation(Activation.ELU).build()) - .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(10).nOut(10).build()) - .build(); - + @DisplayName("Test Activation Inheritance") + void testActivationInheritance() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).weightInit(WeightInit.XAVIER).activation(Activation.RATIONALTANH).list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(new ActivationLayer()).layer(new ActivationLayer.Builder().build()).layer(new ActivationLayer.Builder().activation(Activation.ELU).build()).layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - - assertNotNull(((ActivationLayer)network.getLayer(1).conf().getLayer()).getActivationFn()); - - assertTrue(((DenseLayer)network.getLayer(0).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer(1).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer(2).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer(3).conf().getLayer()).getActivationFn() instanceof ActivationELU); - assertTrue(((OutputLayer)network.getLayer(4).conf().getLayer()).getActivationFn() instanceof ActivationSoftmax); + assertNotNull(((ActivationLayer) network.getLayer(1).conf().getLayer()).getActivationFn()); + assertTrue(((DenseLayer) network.getLayer(0).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer) network.getLayer(1).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer) network.getLayer(2).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer) network.getLayer(3).conf().getLayer()).getActivationFn() instanceof ActivationELU); + assertTrue(((OutputLayer) network.getLayer(4).conf().getLayer()).getActivationFn() instanceof ActivationSoftmax); } @Test - public void testActivationInheritanceCG() { - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .weightInit(WeightInit.XAVIER) - .activation(Activation.RATIONALTANH) - .graphBuilder() - .addInputs("in") - .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") - .addLayer("1", new ActivationLayer(), "0") - .addLayer("2", new ActivationLayer.Builder().build(), "1") - .addLayer("3", new ActivationLayer.Builder().activation(Activation.ELU).build(), "2") - .addLayer("4", new 
org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "3") - .setOutputs("4") - .build(); - + @DisplayName("Test Activation Inheritance CG") + void testActivationInheritanceCG() { + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).weightInit(WeightInit.XAVIER).activation(Activation.RATIONALTANH).graphBuilder().addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1", new ActivationLayer(), "0").addLayer("2", new ActivationLayer.Builder().build(), "1").addLayer("3", new ActivationLayer.Builder().activation(Activation.ELU).build(), "2").addLayer("4", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "3").setOutputs("4").build(); ComputationGraph network = new ComputationGraph(conf); network.init(); - - assertNotNull(((ActivationLayer)network.getLayer("1").conf().getLayer()).getActivationFn()); - - assertTrue(((DenseLayer)network.getLayer("0").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer("1").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer("2").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer("3").conf().getLayer()).getActivationFn() instanceof ActivationELU); - assertTrue(((OutputLayer)network.getLayer("4").conf().getLayer()).getActivationFn() instanceof ActivationSoftmax); + assertNotNull(((ActivationLayer) network.getLayer("1").conf().getLayer()).getActivationFn()); + assertTrue(((DenseLayer) network.getLayer("0").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer) network.getLayer("1").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer) network.getLayer("2").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer) network.getLayer("3").conf().getLayer()).getActivationFn() instanceof ActivationELU); + assertTrue(((OutputLayer) network.getLayer("4").conf().getLayer()).getActivationFn() instanceof ActivationSoftmax); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/AutoEncoderTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/AutoEncoderTest.java index 0d0f22e46..05f40cf77 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/AutoEncoderTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/AutoEncoderTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import org.deeplearning4j.BaseDL4JTest; @@ -31,49 +30,30 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import 
org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -public class AutoEncoderTest extends BaseDL4JTest { +@DisplayName("Auto Encoder Test") +class AutoEncoderTest extends BaseDL4JTest { @Test - public void sanityCheckIssue5662(){ + @DisplayName("Sanity Check Issue 5662") + void sanityCheckIssue5662() { int mergeSize = 50; int encdecSize = 25; int in1Size = 20; int in2Size = 15; int hiddenSize = 10; - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - .weightInit(WeightInit.XAVIER) - .graphBuilder() - .addInputs("in1", "in2") - .addLayer("1", new DenseLayer.Builder().nOut(mergeSize).build(), "in1") - .addLayer("2", new DenseLayer.Builder().nOut(mergeSize).build(), "in2") - .addVertex("merge", new MergeVertex(), "1", "2") - .addLayer("e",new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(),"merge") - .addLayer("hidden",new AutoEncoder.Builder().nOut(hiddenSize).build(),"e") - .addLayer("decoder",new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(),"hidden") - .addLayer("L4", new DenseLayer.Builder().nOut(mergeSize).build(), "decoder") - .addLayer("out1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in1Size).build(),"L4") - .addLayer("out2",new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in2Size).build(),"L4") - .setOutputs("out1","out2") - .setInputTypes(InputType.feedForward(in1Size), InputType.feedForward(in2Size)) - - .build(); - + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in1", "in2").addLayer("1", new DenseLayer.Builder().nOut(mergeSize).build(), "in1").addLayer("2", new DenseLayer.Builder().nOut(mergeSize).build(), "in2").addVertex("merge", new MergeVertex(), "1", "2").addLayer("e", new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(), "merge").addLayer("hidden", new AutoEncoder.Builder().nOut(hiddenSize).build(), "e").addLayer("decoder", new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(), "hidden").addLayer("L4", new DenseLayer.Builder().nOut(mergeSize).build(), "decoder").addLayer("out1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in1Size).build(), "L4").addLayer("out2", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in2Size).build(), "L4").setOutputs("out1", "out2").setInputTypes(InputType.feedForward(in1Size), InputType.feedForward(in2Size)).build(); ComputationGraph net = new ComputationGraph(conf); net.init(); - - MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet( - new INDArray[]{Nd4j.create(1, in1Size), Nd4j.create(1, in2Size)}, - new INDArray[]{Nd4j.create(1, in1Size), Nd4j.create(1, in2Size)}); - + MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { Nd4j.create(1, in1Size), Nd4j.create(1, in2Size) }, new INDArray[] { Nd4j.create(1, in1Size), Nd4j.create(1, in2Size) }); net.summary(InputType.feedForward(in1Size), InputType.feedForward(in2Size)); net.fit(new SingletonMultiDataSetIterator(mds)); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java index ea032ecce..9e3bf4df1 100644 --- 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import lombok.val; @@ -29,46 +28,47 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; - import java.util.HashMap; import java.util.Map; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +@DisplayName("Base Layer Test") +class BaseLayerTest extends BaseDL4JTest { -public class BaseLayerTest extends BaseDL4JTest { + protected INDArray weight = Nd4j.create(new double[] { 0.10, -0.20, -0.15, 0.05 }, new int[] { 2, 2 }); + + protected INDArray bias = Nd4j.create(new double[] { 0.5, 0.5 }, new int[] { 1, 2 }); - protected INDArray weight = Nd4j.create(new double[] {0.10, -0.20, -0.15, 0.05}, new int[] {2, 2}); - protected INDArray bias = Nd4j.create(new double[] {0.5, 0.5}, new int[] {1, 2}); protected Map<String, INDArray> paramTable; - @Before - public void doBefore() { + @BeforeEach + void doBefore() { paramTable = new HashMap<>(); paramTable.put("W", weight); paramTable.put("b", bias); - } @Test - public void testSetExistingParamsConvolutionSingleLayer() { + @DisplayName("Test Set Existing Params Convolution Single Layer") + void testSetExistingParamsConvolutionSingleLayer() { Layer layer = configureSingleLayer(); assertNotEquals(paramTable, layer.paramTable()); - layer.setParamTable(paramTable); assertEquals(paramTable, layer.paramTable()); } - @Test - public void testSetExistingParamsDenseMultiLayer() { + @DisplayName("Test Set Existing Params Dense Multi Layer") + void testSetExistingParamsDenseMultiLayer() { MultiLayerNetwork net = configureMultiLayer(); - for (Layer layer : net.getLayers()) { assertNotEquals(paramTable, layer.paramTable()); layer.setParamTable(paramTable); @@ -76,31 +76,21 @@ public class BaseLayerTest extends BaseDL4JTest { } } - public Layer configureSingleLayer() { int nIn = 2; int nOut = 2; - - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .layer(new ConvolutionLayer.Builder().nIn(nIn).nOut(nOut).build()).build(); - + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new ConvolutionLayer.Builder().nIn(nIn).nOut(nOut).build()).build(); val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); } - public MultiLayerNetwork configureMultiLayer() { int nIn = 2; int nOut = 2; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()) - .layer(1, new
OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()).layer(1, new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); return net; } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java index f20accbe5..853bf75d0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import org.deeplearning4j.BaseDL4JTest; @@ -28,77 +27,58 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class CacheModeTest extends BaseDL4JTest { +@DisplayName("Cache Mode Test") +class CacheModeTest extends BaseDL4JTest { @Test - public void testConvCacheModeSimple(){ - + @DisplayName("Test Conv Cache Mode Simple") + void testConvCacheModeSimple() { MultiLayerConfiguration conf1 = getConf(CacheMode.NONE); MultiLayerConfiguration conf2 = getConf(CacheMode.DEVICE); - MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - - INDArray in = Nd4j.rand(3, 28*28); + INDArray in = Nd4j.rand(3, 28 * 28); INDArray labels = TestUtils.randomOneHot(3, 10); - INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); assertEquals(out1, out2); - assertEquals(net1.params(), net2.params()); net1.fit(in, labels); net2.fit(in, labels); assertEquals(net1.params(), net2.params()); } - private static MultiLayerConfiguration getConf(CacheMode cacheMode){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .inferenceWorkspaceMode(WorkspaceMode.ENABLED) - .trainingWorkspaceMode(WorkspaceMode.ENABLED) - .seed(12345) - .cacheMode(cacheMode) - .list() - .layer(new ConvolutionLayer.Builder().nOut(3).build()) - .layer(new ConvolutionLayer.Builder().nOut(3).build()) - .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); - + private static MultiLayerConfiguration getConf(CacheMode cacheMode) { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).inferenceWorkspaceMode(WorkspaceMode.ENABLED).trainingWorkspaceMode(WorkspaceMode.ENABLED).seed(12345).cacheMode(cacheMode).list().layer(new 
ConvolutionLayer.Builder().nOut(3).build()).layer(new ConvolutionLayer.Builder().nOut(3).build()).layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); return conf; } @Test - public void testLSTMCacheModeSimple(){ - - for(boolean graves : new boolean[]{true, false}) { - + @DisplayName("Test LSTM Cache Mode Simple") + void testLSTMCacheModeSimple() { + for (boolean graves : new boolean[] { true, false }) { MultiLayerConfiguration conf1 = getConfLSTM(CacheMode.NONE, graves); MultiLayerConfiguration conf2 = getConfLSTM(CacheMode.DEVICE, graves); - MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - - INDArray in = Nd4j.rand(new int[]{3, 3, 10}); + INDArray in = Nd4j.rand(new int[] { 3, 3, 10 }); INDArray labels = TestUtils.randomOneHotTimeSeries(3, 10, 10); - INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); assertEquals(out1, out2); - assertEquals(net1.params(), net2.params()); net1.fit(in, labels); net2.fit(in, labels); @@ -106,68 +86,33 @@ public class CacheModeTest extends BaseDL4JTest { } } - private static MultiLayerConfiguration getConfLSTM(CacheMode cacheMode, boolean graves){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .inferenceWorkspaceMode(WorkspaceMode.ENABLED) - .trainingWorkspaceMode(WorkspaceMode.ENABLED) - .seed(12345) - .cacheMode(cacheMode) - .list() - .layer(graves ? - new GravesLSTM.Builder().nIn(3).nOut(3).build() : - new LSTM.Builder().nIn(3).nOut(3).build()) - .layer(graves ? - new GravesLSTM.Builder().nIn(3).nOut(3).build() : - new LSTM.Builder().nIn(3).nOut(3).build()) - .layer(new RnnOutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .build(); - + private static MultiLayerConfiguration getConfLSTM(CacheMode cacheMode, boolean graves) { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).inferenceWorkspaceMode(WorkspaceMode.ENABLED).trainingWorkspaceMode(WorkspaceMode.ENABLED).seed(12345).cacheMode(cacheMode).list().layer(graves ? new GravesLSTM.Builder().nIn(3).nOut(3).build() : new LSTM.Builder().nIn(3).nOut(3).build()).layer(graves ? 
new GravesLSTM.Builder().nIn(3).nOut(3).build() : new LSTM.Builder().nIn(3).nOut(3).build()).layer(new RnnOutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()).build(); return conf; } - @Test - public void testConvCacheModeSimpleCG(){ - + @DisplayName("Test Conv Cache Mode Simple CG") + void testConvCacheModeSimpleCG() { ComputationGraphConfiguration conf1 = getConfCG(CacheMode.NONE); ComputationGraphConfiguration conf2 = getConfCG(CacheMode.DEVICE); - ComputationGraph net1 = new ComputationGraph(conf1); net1.init(); ComputationGraph net2 = new ComputationGraph(conf2); net2.init(); - - INDArray in = Nd4j.rand(3, 28*28); + INDArray in = Nd4j.rand(3, 28 * 28); INDArray labels = TestUtils.randomOneHot(3, 10); - INDArray out1 = net1.outputSingle(in); INDArray out2 = net2.outputSingle(in); assertEquals(out1, out2); - assertEquals(net1.params(), net2.params()); net1.fit(new DataSet(in, labels)); net2.fit(new DataSet(in, labels)); assertEquals(net1.params(), net2.params()); } - private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .inferenceWorkspaceMode(WorkspaceMode.ENABLED) - .trainingWorkspaceMode(WorkspaceMode.ENABLED) - .seed(12345) - .cacheMode(cacheMode) - .graphBuilder() - .addInputs("in") - .layer("0", new ConvolutionLayer.Builder().nOut(3).build(), "in") - .layer("1", new ConvolutionLayer.Builder().nOut(3).build(), "0") - .layer("2", new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build(), "1") - .setOutputs("2") - .setInputTypes(InputType.convolutionalFlat(28, 28, 1)) - .build(); - + private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode) { + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).inferenceWorkspaceMode(WorkspaceMode.ENABLED).trainingWorkspaceMode(WorkspaceMode.ENABLED).seed(12345).cacheMode(cacheMode).graphBuilder().addInputs("in").layer("0", new ConvolutionLayer.Builder().nOut(3).build(), "in").layer("1", new ConvolutionLayer.Builder().nOut(3).build(), "0").layer("2", new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build(), "1").setOutputs("2").setInputTypes(InputType.convolutionalFlat(28, 28, 1)).build(); return conf; } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java index a7c304a83..778bc332d 100755 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import org.deeplearning4j.BaseDL4JTest; @@ -34,8 +33,8 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import 
org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -44,73 +43,40 @@ import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; - import java.util.Random; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertNotEquals; - -public class CenterLossOutputLayerTest extends BaseDL4JTest { +@DisplayName("Center Loss Output Layer Test") +class CenterLossOutputLayerTest extends BaseDL4JTest { private ComputationGraph getGraph(int numLabels, double lambda) { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .dist(new NormalDistribution(0, 1)).updater(new NoOp()) - .graphBuilder().addInputs("input1") - .addLayer("l1", new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.RELU).build(), - "input1") - .addLayer("lossLayer", new CenterLossOutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(numLabels) - .lambda(lambda).activation(Activation.SOFTMAX).build(), "l1") - .setOutputs("lossLayer").build(); - + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).dist(new NormalDistribution(0, 1)).updater(new NoOp()).graphBuilder().addInputs("input1").addLayer("l1", new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.RELU).build(), "input1").addLayer("lossLayer", new CenterLossOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(numLabels).lambda(lambda).activation(Activation.SOFTMAX).build(), "l1").setOutputs("lossLayer").build(); ComputationGraph graph = new ComputationGraph(conf); graph.init(); - return graph; } public ComputationGraph getCNNMnistConfig() { - - int nChannels = 1; // Number of input channels - int outputNum = 10; // The number of possible outcomes - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) // Training iterations as above - .l2(0.0005).weightInit(WeightInit.XAVIER) - .updater(new Nesterovs(0.01, 0.9)) - .graphBuilder().addInputs("input") - .setInputTypes(InputType.convolutionalFlat(28, 28, 1)) - .addLayer("0", new ConvolutionLayer.Builder(5, 5) - //nIn and nOut specify channels. 
nIn here is the nChannels and nOut is the number of filters to be applied - .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build(), - "input") - .addLayer("1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) - .stride(2, 2).build(), "0") - .addLayer("2", new ConvolutionLayer.Builder(5, 5) - //Note that nIn need not be specified in later layers - .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build(), "1") - .addLayer("3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) - .stride(2, 2).build(), "2") - .addLayer("4", new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build(), "3") - .addLayer("output", - new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder( - LossFunction.MCXENT).nOut(outputNum) - .activation(Activation.SOFTMAX).build(), - "4") - .setOutputs("output").build(); - + // Number of input channels + int nChannels = 1; + // The number of possible outcomes + int outputNum = 10; + ComputationGraphConfiguration conf = // Training iterations as above + new NeuralNetConfiguration.Builder().seed(12345).l2(0.0005).weightInit(WeightInit.XAVIER).updater(new Nesterovs(0.01, 0.9)).graphBuilder().addInputs("input").setInputTypes(InputType.convolutionalFlat(28, 28, 1)).addLayer("0", new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build(), "input").addLayer("1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build(), "0").addLayer("2", new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build(), "1").addLayer("3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build(), "2").addLayer("4", new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build(), "3").addLayer("output", new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder(LossFunction.MCXENT).nOut(outputNum).activation(Activation.SOFTMAX).build(), "4").setOutputs("output").build(); ComputationGraph graph = new ComputationGraph(conf); graph.init(); - return graph; } @Test - public void testLambdaConf() { - double[] lambdas = new double[] {0.1, 0.01}; + @DisplayName("Test Lambda Conf") + void testLambdaConf() { + double[] lambdas = new double[] { 0.1, 0.01 }; double[] results = new double[2]; int numClasses = 2; - INDArray input = Nd4j.rand(150, 4); INDArray labels = Nd4j.zeros(150, numClasses); Random r = new Random(12345); @@ -118,7 +84,6 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest { labels.putScalar(i, r.nextInt(numClasses), 1.0); } ComputationGraph graph; - for (int i = 0; i < lambdas.length; i++) { graph = getGraph(numClasses, lambdas[i]); graph.setInput(0, input); @@ -126,27 +91,23 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest { graph.computeGradientAndScore(); results[i] = graph.score(); } - assertNotEquals(results[0], results[1]); } - - @Test - @Ignore //Should be run manually - public void testMNISTConfig() throws Exception { - int batchSize = 64; // Test batch size + @Disabled + @DisplayName("Test MNIST Config") + void testMNISTConfig() throws Exception { + // Test batch size + int batchSize = 64; DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345); - ComputationGraph net = getCNNMnistConfig(); net.init(); net.setListeners(new ScoreIterationListener(1)); - for (int i = 0; i < 50; i++) { net.fit(mnistTrain.next()); Thread.sleep(1000); } - 
Thread.sleep(100000); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java index b22f4c869..679b0ac47 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import org.deeplearning4j.BaseDL4JTest; @@ -36,7 +35,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -44,30 +43,30 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** */ -public class DropoutLayerTest extends BaseDL4JTest { +@DisplayName("Dropout Layer Test") +class DropoutLayerTest extends BaseDL4JTest { @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testInputTypes() { + @DisplayName("Test Input Types") + void testInputTypes() { DropoutLayer config = new DropoutLayer.Builder(0.5).build(); - InputType in1 = InputType.feedForward(20); InputType in2 = InputType.convolutional(28, 28, 1); - assertEquals(in1, config.getOutputType(0, in1)); assertEquals(in2, config.getOutputType(0, in2)); assertNull(config.getPreProcessorForInputType(in1)); @@ -75,58 +74,30 @@ public class DropoutLayerTest extends BaseDL4JTest { } @Test - public void testDropoutLayerWithoutTraining() throws Exception { - MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(3648) - .list().layer(0, - new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1).dropOut(0.25) - .activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER) - .build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX) - .weightInit(WeightInit.XAVIER).dropOut(0.25) - .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(2, 2, 1)).build(); - + @DisplayName("Test Dropout Layer Without Training") + void testDropoutLayerWithoutTraining() throws Exception { + MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(3648).list().layer(0, new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1).dropOut(0.25).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(1, new 
OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).dropOut(0.25).nOut(4).build()).setInputType(InputType.convolutionalFlat(2, 2, 1)).build(); MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated); netIntegrated.init(); netIntegrated.getLayer(0).setParam("W", Nd4j.eye(1)); netIntegrated.getLayer(0).setParam("b", Nd4j.zeros(1, 1)); netIntegrated.getLayer(1).setParam("W", Nd4j.eye(4)); netIntegrated.getLayer(1).setParam("b", Nd4j.zeros(4, 1)); - - MultiLayerConfiguration confSeparate = - new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .seed(3648) - .list().layer(0, - new DropoutLayer.Builder(0.25) - .build()) - .layer(1, new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1) - .activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER) - .build()) - .layer(2, new DropoutLayer.Builder(0.25).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .nOut(4).build()) - - .setInputType(InputType.convolutionalFlat(2, 2, 1)).build(); - + MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(3648).list().layer(0, new DropoutLayer.Builder(0.25).build()).layer(1, new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(2, new DropoutLayer.Builder(0.25).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(4).build()).setInputType(InputType.convolutionalFlat(2, 2, 1)).build(); MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate); netSeparate.init(); netSeparate.getLayer(1).setParam("W", Nd4j.eye(1)); netSeparate.getLayer(1).setParam("b", Nd4j.zeros(1, 1)); netSeparate.getLayer(3).setParam("W", Nd4j.eye(4)); netSeparate.getLayer(3).setParam("b", Nd4j.zeros(4, 1)); - - //Disable input modification for this test: - for(Layer l : netIntegrated.getLayers()){ + // Disable input modification for this test: + for (Layer l : netIntegrated.getLayers()) { l.allowInputModification(false); } - for(Layer l : netSeparate.getLayers()){ + for (Layer l : netSeparate.getLayers()) { l.allowInputModification(false); } - - INDArray in = Nd4j.arange(1, 5).reshape(1,4); + INDArray in = Nd4j.arange(1, 5).reshape(1, 4); Nd4j.getRandom().setSeed(12345); List<INDArray> actTrainIntegrated = netIntegrated.feedForward(in.dup(), true); Nd4j.getRandom().setSeed(12345); @@ -135,15 +106,10 @@ public class DropoutLayerTest extends BaseDL4JTest { List<INDArray> actTestIntegrated = netIntegrated.feedForward(in.dup(), false); Nd4j.getRandom().setSeed(12345); List<INDArray> actTestSeparate = netSeparate.feedForward(in.dup(), false); - - //Check masks: - INDArray maskIntegrated = ((Dropout)netIntegrated.getLayer(0).conf().getLayer().getIDropout()).getMask(); - INDArray maskSeparate = ((Dropout)netSeparate.getLayer(0).conf().getLayer().getIDropout()).getMask(); + // Check masks: + INDArray maskIntegrated = ((Dropout) netIntegrated.getLayer(0).conf().getLayer().getIDropout()).getMask(); + INDArray maskSeparate = ((Dropout) netSeparate.getLayer(0).conf().getLayer().getIDropout()).getMask(); assertEquals(maskIntegrated, maskSeparate); - - - - assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(2)); assertEquals(actTrainIntegrated.get(2),
actTrainSeparate.get(4)); assertEquals(actTestIntegrated.get(1), actTestSeparate.get(2)); @@ -151,68 +117,41 @@ public class DropoutLayerTest extends BaseDL4JTest { } @Test - public void testDropoutLayerWithDenseMnist() throws Exception { + @DisplayName("Test Dropout Layer With Dense Mnist") + void testDropoutLayerWithDenseMnist() throws Exception { DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSet next = iter.next(); - // Run without separate activation layer - MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10) - .activation(Activation.RELU).weightInit( - WeightInit.XAVIER) - .build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.25) - .nIn(10).nOut(10).build()) - .build(); - + MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.25).nIn(10).nOut(10).build()).build(); MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated); netIntegrated.init(); netIntegrated.fit(next); - // Run with separate activation layer - MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new DropoutLayer.Builder(0.25).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10) - .build()) - .build(); - + MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new DropoutLayer.Builder(0.25).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate); netSeparate.init(); netSeparate.fit(next); - - //Disable input modification for this test: - for(Layer l : netIntegrated.getLayers()){ + // Disable input modification for this test: + for (Layer l : netIntegrated.getLayers()) { l.allowInputModification(false); } - for(Layer l : netSeparate.getLayers()){ + for (Layer l : netSeparate.getLayers()) { l.allowInputModification(false); } - // check parameters assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W")); assertEquals(netIntegrated.getLayer(0).getParam("b"), netSeparate.getLayer(0).getParam("b")); assertEquals(netIntegrated.getLayer(1).getParam("W"), netSeparate.getLayer(2).getParam("W")); assertEquals(netIntegrated.getLayer(1).getParam("b"), netSeparate.getLayer(2).getParam("b")); - // check activations 
netIntegrated.setInput(next.getFeatures()); netSeparate.setInput(next.getFeatures()); - Nd4j.getRandom().setSeed(12345); List<INDArray> actTrainIntegrated = netIntegrated.feedForward(true); Nd4j.getRandom().setSeed(12345); List<INDArray> actTrainSeparate = netSeparate.feedForward(true); assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1)); assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3)); - Nd4j.getRandom().setSeed(12345); List<INDArray> actTestIntegrated = netIntegrated.feedForward(false); Nd4j.getRandom().setSeed(12345); @@ -222,77 +161,49 @@ } @Test - public void testDropoutLayerWithConvMnist() throws Exception { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); //Set to double datatype - MKL-DNN not used for CPU (otherwise different strides due to Dl4J impl permutes) + @DisplayName("Test Dropout Layer With Conv Mnist") + void testDropoutLayerWithConvMnist() throws Exception { + // Set to double datatype - MKL-DNN not used for CPU (otherwise different strides due to Dl4J impl permutes) + Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSet next = iter.next(); - // Run without separate activation layer Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(123) - .list().layer(0, - new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20) - .activation(Activation.TANH).weightInit(WeightInit.XAVIER) - .build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.5) - .nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - + MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.5).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); // Run with separate activation layer Nd4j.getRandom().setSeed(12345); - - //Manually configure preprocessors - //This is necessary, otherwise CnnToFeedForwardPreprocessor will be in different locatinos - //i.e., dropout on 4d activations in latter, and dropout on 2d activations in former + // Manually configure preprocessors + // This is necessary, otherwise CnnToFeedForwardPreprocessor will be in different locations + // i.e., dropout on 4d activations in latter, and dropout on 2d activations in former Map<Integer, InputPreProcessor> preProcessorMap = new HashMap<>(); preProcessorMap.put(1, new CnnToFeedForwardPreProcessor(13, 13, 20)); - - MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().seed(123).list() - .layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20) - .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()) - .layer(1, new DropoutLayer.Builder(0.5).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()) - .inputPreProcessors(preProcessorMap) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - - + MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2,
2).nIn(1).nOut(20).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new DropoutLayer.Builder(0.5).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).inputPreProcessors(preProcessorMap).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); Nd4j.getRandom().setSeed(12345); MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated); netIntegrated.init(); - Nd4j.getRandom().setSeed(12345); MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate); netSeparate.init(); - assertEquals(netIntegrated.params(), netSeparate.params()); - Nd4j.getRandom().setSeed(12345); netIntegrated.fit(next); - Nd4j.getRandom().setSeed(12345); netSeparate.fit(next); - assertEquals(netIntegrated.params(), netSeparate.params()); - // check parameters assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W")); assertEquals(netIntegrated.getLayer(0).getParam("b"), netSeparate.getLayer(0).getParam("b")); assertEquals(netIntegrated.getLayer(1).getParam("W"), netSeparate.getLayer(2).getParam("W")); assertEquals(netIntegrated.getLayer(1).getParam("b"), netSeparate.getLayer(2).getParam("b")); - // check activations netIntegrated.setInput(next.getFeatures().dup()); netSeparate.setInput(next.getFeatures().dup()); - Nd4j.getRandom().setSeed(12345); List actTrainIntegrated = netIntegrated.feedForward(true); Nd4j.getRandom().setSeed(12345); List actTrainSeparate = netSeparate.feedForward(true); assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1)); assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3)); - netIntegrated.setInput(next.getFeatures().dup()); netSeparate.setInput(next.getFeatures().dup()); Nd4j.getRandom().setSeed(12345); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java index 9849810b4..09d467f8d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import lombok.extern.slf4j.Slf4j; @@ -31,116 +30,69 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.List; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class FrozenLayerTest extends BaseDL4JTest { +@DisplayName("Frozen Layer Test") +class FrozenLayerTest extends BaseDL4JTest { /* A model with a few frozen layers == Model with non frozen layers set with the 
output of the forward pass of the frozen layers */ @Test - public void testFrozen() { + @DisplayName("Test Frozen") + void testFrozen() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .activation(Activation.IDENTITY); - + NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY); FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build(); - - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); - + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build()); modelToFineTune.init(); List<INDArray> ff = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false); INDArray asFrozenFeatures = ff.get(2); - - MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(finetune) - .setFeatureExtractor(1).build(); - - INDArray paramsLastTwoLayers = - Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); - MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build(), paramsLastTwoLayers); - - // assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf()); //Equal, other than names - // assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf()); //Equal, other than names - - //Check: forward pass + MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(finetune).setFeatureExtractor(1).build(); + INDArray paramsLastTwoLayers = Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); + MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build(), paramsLastTwoLayers); + // assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf()); //Equal, other than names + // assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf()); //Equal, other than names + // Check: forward pass INDArray outNow = modelNow.output(randomData.getFeatures()); INDArray outNotFrozen = notFrozen.output(asFrozenFeatures); assertEquals(outNow, outNotFrozen); - for (int i = 0; i < 5; i++) { notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
modelNow.fit(randomData); } - - INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), - notFrozen.params()); + INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), notFrozen.params()); INDArray act = modelNow.params(); assertEquals(expected, act); } - @Test - public void cloneMLNFrozen() { - + @DisplayName("Clone MLN Frozen") + void cloneMLNFrozen() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .activation(Activation.IDENTITY); - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); - + NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build()); modelToFineTune.init(); INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2); MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).setFeatureExtractor(1).build(); - MultiLayerNetwork clonedModel = modelNow.clone(); - - //Check json + // Check json assertEquals(modelNow.getLayerWiseConfigurations().toJson(), clonedModel.getLayerWiseConfigurations().toJson()); - - //Check params + // Check params assertEquals(modelNow.params(), clonedModel.params()); - - MultiLayerNetwork notFrozen = new MultiLayerNetwork( - overallConf.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build(), - Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params())); - + MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build(), Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params())); int i = 0; while (i < 5) { notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); @@ -148,112 +100,49 @@ public class FrozenLayerTest extends BaseDL4JTest { clonedModel.fit(randomData); i++; } - - INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), - modelToFineTune.getLayer(1).params(), notFrozen.params()); + INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), notFrozen.params()); assertEquals(expectedParams, modelNow.params()); 
assertEquals(expectedParams, clonedModel.params()); - } - @Test - public void testFrozenCompGraph() { + @DisplayName("Test Frozen Comp Graph") + void testFrozenCompGraph() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .activation(Activation.IDENTITY); - - ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") - .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In") - .addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0") - .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1") - .addLayer("layer3", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build(), - "layer2") - .setOutputs("layer3").build()); - + NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY); + ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2").setOutputs("layer3").build()); modelToFineTune.init(); INDArray asFrozenFeatures = modelToFineTune.feedForward(randomData.getFeatures(), false).get("layer1"); - - ComputationGraph modelNow = - new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build(); - - ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") - .addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In") - .addLayer("layer1", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build(), - "layer0") - .setOutputs("layer1").build()); - + ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build(); + ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In").addLayer("layer1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0").setOutputs("layer1").build()); notFrozen.init(); - notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), - modelToFineTune.getLayer("layer3").params())); - + notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), modelToFineTune.getLayer("layer3").params())); int i = 0; while (i < 5) { notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); modelNow.fit(randomData); i++; } - - assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), - modelToFineTune.getLayer("layer1").params(), notFrozen.params()), modelNow.params()); + assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), modelToFineTune.getLayer("layer1").params(), notFrozen.params()), 
modelNow.params()); } @Test - public void cloneCompGraphFrozen() { - + @DisplayName("Clone Comp Graph Frozen") + void cloneCompGraphFrozen() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .activation(Activation.IDENTITY); - - ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") - .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In") - .addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0") - .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1") - .addLayer("layer3", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build(), - "layer2") - .setOutputs("layer3").build()); - + NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY); + ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2").setOutputs("layer3").build()); modelToFineTune.init(); INDArray asFrozenFeatures = modelToFineTune.feedForward(randomData.getFeatures(), false).get("layer1"); - ComputationGraph modelNow = - new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build(); - + ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build(); ComputationGraph clonedModel = modelNow.clone(); - - //Check json + // Check json assertEquals(clonedModel.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); - - //Check params + // Check params assertEquals(modelNow.params(), clonedModel.params()); - - ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") - .addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In") - .addLayer("layer1", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build(), - "layer0") - .setOutputs("layer1").build()); + ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In").addLayer("layer1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0").setOutputs("layer1").build()); notFrozen.init(); - notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), - modelToFineTune.getLayer("layer3").params())); - - + notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), modelToFineTune.getLayer("layer3").params())); int i = 0; while (i < 5) { notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); @@ -261,117 +150,54 @@ public class FrozenLayerTest extends BaseDL4JTest { clonedModel.fit(randomData); i++; } 
- - INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), - modelToFineTune.getLayer("layer1").params(), notFrozen.params()); + INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), modelToFineTune.getLayer("layer1").params(), notFrozen.params()); assertEquals(expectedParams, modelNow.params()); assertEquals(expectedParams, clonedModel.params()); } - @Test - public void testFrozenLayerInstantiation() { - //We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if + @DisplayName("Test Frozen Layer Instantiation") + void testFrozenLayerInstantiation() { + // We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if // they were initialized via the builder - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10) - .nOut(10).build()) - .build(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10) - .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())) - .layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer( - new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build())) - .layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10) - .nOut(10).build()) - .build(); - + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - assertEquals(net1.params(), net2.params()); - - String json = conf2.toJson(); MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); - assertEquals(conf2, fromJson); - MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson); net3.init(); - INDArray input = Nd4j.rand(10, 10); - INDArray out2 = net2.output(input); INDArray out3 = 
net3.output(input); - assertEquals(out2, out3); } @Test - public void testFrozenLayerInstantiationCompGraph() { - - //We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if + @DisplayName("Test Frozen Layer Instantiation Comp Graph") + void testFrozenLayerInstantiationCompGraph() { + // We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if // they were initialized via the builder - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() - .addInputs("in") - .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build(), "in") - .addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build(), "0") - .addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10) - .nOut(10).build(), - "1") - .setOutputs("2").build(); - - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() - .addInputs("in") - .addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()) - .build(), "in") - .addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()) - .build(), "0") - .addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10) - .nOut(10).build(), - "1") - .setOutputs("2").build(); - + ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "in").addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "0").addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build(); + ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder().layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).build(), "in").addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder().layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).build(), "0").addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build(); ComputationGraph net1 = new ComputationGraph(conf1); net1.init(); ComputationGraph net2 = new ComputationGraph(conf2); net2.init(); - assertEquals(net1.params(), net2.params()); - - String json = conf2.toJson(); ComputationGraphConfiguration fromJson = ComputationGraphConfiguration.fromJson(json); - assertEquals(conf2, fromJson); - ComputationGraph net3 = new ComputationGraph(fromJson); net3.init(); - 
INDArray input = Nd4j.rand(10, 10); - INDArray out2 = net2.outputSingle(input); INDArray out3 = net3.outputSingle(input); - assertEquals(out2, out3); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java index 40d0aed93..925645781 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import lombok.extern.slf4j.Slf4j; @@ -34,363 +33,194 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class FrozenLayerWithBackpropTest extends BaseDL4JTest { +@DisplayName("Frozen Layer With Backprop Test") +class FrozenLayerWithBackpropTest extends BaseDL4JTest { @Test - public void testFrozenWithBackpropLayerInstantiation() { - //We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if + @DisplayName("Test Frozen With Backprop Layer Instantiation") + void testFrozenWithBackpropLayerInstantiation() { + // We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if // they were initialized via the builder - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()) - .layer(2, new OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10) - .nOut(10).build()) - .build(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10) - .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())) - .layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build())) - .layer(2, new OutputLayer.Builder( - 
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10) - .nOut(10).build()) - .build(); - + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - assertEquals(net1.params(), net2.params()); - - String json = conf2.toJson(); MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); - assertEquals(conf2, fromJson); - MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson); net3.init(); - INDArray input = Nd4j.rand(10, 10); - INDArray out2 = net2.output(input); INDArray out3 = net3.output(input); - assertEquals(out2, out3); } @Test - public void testFrozenLayerInstantiationCompGraph() { - - //We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if + @DisplayName("Test Frozen Layer Instantiation Comp Graph") + void testFrozenLayerInstantiationCompGraph() { + // We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if // they were initialized via the builder - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() - .addInputs("in") - .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build(), "in") - .addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build(), "0") - .addLayer("2", new OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10) - .nOut(10).build(), - "1") - .setOutputs("2").build(); - - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() - .addInputs("in") - .addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()), "in") - .addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).build()), "0") - .addLayer("2", new OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10) - .nOut(10).build(), - "1") - .setOutputs("2").build(); - + ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new 
DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "in").addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "0").addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build(); + ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()), "in").addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()), "0").addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build(); ComputationGraph net1 = new ComputationGraph(conf1); net1.init(); ComputationGraph net2 = new ComputationGraph(conf2); net2.init(); - assertEquals(net1.params(), net2.params()); - - String json = conf2.toJson(); ComputationGraphConfiguration fromJson = ComputationGraphConfiguration.fromJson(json); - assertEquals(conf2, fromJson); - ComputationGraph net3 = new ComputationGraph(fromJson); net3.init(); - INDArray input = Nd4j.rand(10, 10); - INDArray out2 = net2.outputSingle(input); INDArray out3 = net3.outputSingle(input); - assertEquals(out2, out3); } @Test - public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() { + @DisplayName("Test Multi Layer Network Frozen Layer Params After Backprop") + void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() { Nd4j.getRandom().setSeed(12345); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); - - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() - .seed(12345) - .weightInit(WeightInit.XAVIER) - .updater(new Sgd(2)) - .list() - .layer(new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder().nIn(3).nOut(4).build())) - .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder().nIn(4).nOut(2).build())) - .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())) - .build(); - + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Sgd(2)).list().layer(new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build())).layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build())).layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf1); network.init(); INDArray unfrozenLayerParams = network.getLayer(0).params().dup(); INDArray frozenLayerParams1 = network.getLayer(1).params().dup(); INDArray frozenLayerParams2 = network.getLayer(2).params().dup(); INDArray 
frozenOutputLayerParams = network.getLayer(3).params().dup(); - for (int i = 0; i < 100; i++) { network.fit(randomData); } - assertNotEquals(unfrozenLayerParams, network.getLayer(0).params()); assertEquals(frozenLayerParams1, network.getLayer(1).params()); assertEquals(frozenLayerParams2, network.getLayer(2).params()); assertEquals(frozenOutputLayerParams, network.getLayer(3).params()); - } @Test - public void testComputationGraphFrozenLayerParamsAfterBackprop() { + @DisplayName("Test Computation Graph Frozen Layer Params After Backprop") + void testComputationGraphFrozenLayerParamsAfterBackprop() { Nd4j.getRandom().setSeed(12345); - DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); String frozenBranchName = "B1-"; String unfrozenBranchName = "B2-"; - String initialLayer = "initial"; - String frozenBranchUnfrozenLayer0 = frozenBranchName + "0"; String frozenBranchFrozenLayer1 = frozenBranchName + "1"; String frozenBranchFrozenLayer2 = frozenBranchName + "2"; String frozenBranchOutput = frozenBranchName + "Output"; - - String unfrozenLayer0 = unfrozenBranchName + "0"; String unfrozenLayer1 = unfrozenBranchName + "1"; String unfrozenBranch2 = unfrozenBranchName + "Output"; - - ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder() - .updater(new Sgd(2.0)) - .seed(12345) - .graphBuilder() - .addInputs("input") - .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input") - .addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer) - .addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0) - .addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1) - .addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer) - .addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0) - .addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1) - .addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2) - .addLayer(frozenBranchOutput,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge") - .setOutputs(frozenBranchOutput) - .build(); - + ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder().updater(new Sgd(2.0)).seed(12345).graphBuilder().addInputs("input").addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input").addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer).addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()), frozenBranchUnfrozenLayer0).addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()), frozenBranchFrozenLayer1).addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer).addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0).addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1).addVertex("merge", new 
MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2).addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()), "merge").setOutputs(frozenBranchOutput).build(); ComputationGraph computationGraph = new ComputationGraph(computationGraphConf); computationGraph.init(); INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).params().dup(); - for (int i = 0; i < 100; i++) { computationGraph.fit(randomData); } - assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).params()); assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).params()); assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).params()); assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).params()); - } /** * Frozen layer should have same results as a layer with Sgd updater with learning rate set to 0 */ @Test - public void testFrozenLayerVsSgd() { + @DisplayName("Test Frozen Layer Vs Sgd") + void testFrozenLayerVsSgd() { Nd4j.getRandom().setSeed(12345); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); - - MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder() - .seed(12345) - .weightInit(WeightInit.XAVIER) - .updater(new Sgd(2)) - .list() - .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build()) - .layer(2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build()) - .layer(3,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build()) - .build(); - - MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder() - .seed(12345) - .weightInit(WeightInit.XAVIER) - .updater(new Sgd(2)) - .list() - .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build())) - .layer(2,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build())) - .layer(3,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())) - .build(); + MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Sgd(2)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build()).layer(2, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build()).build(); + MultiLayerConfiguration confFrozen = new 
NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Sgd(2)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build())).layer(2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build())).layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())).build(); MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen); frozenNetwork.init(); INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).params().dup(); INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).params().dup(); INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).params().dup(); INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).params().dup(); - MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd); sgdNetwork.init(); INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).params().dup(); INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).params().dup(); INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).params().dup(); INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).params().dup(); - for (int i = 0; i < 100; i++) { frozenNetwork.fit(randomData); } for (int i = 0; i < 100; i++) { sgdNetwork.fit(randomData); } - assertEquals(frozenNetwork.getLayer(0).params(), sgdNetwork.getLayer(0).params()); assertEquals(frozenNetwork.getLayer(1).params(), sgdNetwork.getLayer(1).params()); assertEquals(frozenNetwork.getLayer(2).params(), sgdNetwork.getLayer(2).params()); assertEquals(frozenNetwork.getLayer(3).params(), sgdNetwork.getLayer(3).params()); - } @Test - public void testComputationGraphVsSgd() { + @DisplayName("Test Computation Graph Vs Sgd") + void testComputationGraphVsSgd() { Nd4j.getRandom().setSeed(12345); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); String frozenBranchName = "B1-"; String unfrozenBranchName = "B2-"; - String initialLayer = "initial"; - String frozenBranchUnfrozenLayer0 = frozenBranchName + "0"; String frozenBranchFrozenLayer1 = frozenBranchName + "1"; String frozenBranchFrozenLayer2 = frozenBranchName + "2"; String frozenBranchOutput = frozenBranchName + "Output"; - - String unfrozenLayer0 = unfrozenBranchName + "0"; String unfrozenLayer1 = unfrozenBranchName + "1"; String unfrozenBranch2 = unfrozenBranchName + "Output"; - - ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder() - .updater(new Sgd(2.0)) - .seed(12345) - .graphBuilder() - .addInputs("input") - .addLayer(initialLayer,new DenseLayer.Builder().nIn(4).nOut(4).build(),"input") - .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer) - .addLayer(frozenBranchFrozenLayer1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0) - .addLayer(frozenBranchFrozenLayer2, - new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1) - .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer) - .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0) - .addLayer(unfrozenBranch2,new 
DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1) - .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2) - .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge") - .setOutputs(frozenBranchOutput) - .build(); - - ComputationGraphConfiguration computationSgdGraphConf = new NeuralNetConfiguration.Builder() - .updater(new Sgd(2.0)) - .seed(12345) - .graphBuilder() - .addInputs("input") - .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input") - .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer) - .addLayer(frozenBranchFrozenLayer1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(),frozenBranchUnfrozenLayer0) - .addLayer(frozenBranchFrozenLayer2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(),frozenBranchFrozenLayer1) - .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer) - .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0) - .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1) - .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2) - .addLayer(frozenBranchOutput,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(),"merge") - .setOutputs(frozenBranchOutput) - .build(); - + ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder().updater(new Sgd(2.0)).seed(12345).graphBuilder().addInputs("input").addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input").addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer).addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()), frozenBranchUnfrozenLayer0).addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()), frozenBranchFrozenLayer1).addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer).addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0).addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1).addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2).addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()), "merge").setOutputs(frozenBranchOutput).build(); + ComputationGraphConfiguration computationSgdGraphConf = new NeuralNetConfiguration.Builder().updater(new Sgd(2.0)).seed(12345).graphBuilder().addInputs("input").addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input").addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer).addLayer(frozenBranchFrozenLayer1, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(), frozenBranchUnfrozenLayer0).addLayer(frozenBranchFrozenLayer2, new 
DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(), frozenBranchFrozenLayer1).addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer).addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0).addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1).addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2).addLayer(frozenBranchOutput, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(), "merge").setOutputs(frozenBranchOutput).build(); ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf); frozenComputationGraph.init(); INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).params().dup(); - ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf); sgdComputationGraph.init(); INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).params().dup(); - for (int i = 0; i < 100; i++) { frozenComputationGraph.fit(randomData); } for (int i = 0; i < 100; i++) { sgdComputationGraph.fit(randomData); } - assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params()); assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).params(), sgdComputationGraph.getLayer(frozenBranchOutput).params()); - } - - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java index 03e48b169..9827c350e 100755 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import lombok.extern.slf4j.Slf4j; @@ -36,7 +35,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import 
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -46,123 +45,88 @@ import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; - import java.util.Collections; import java.util.Random; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class OutputLayerTest extends BaseDL4JTest { +@DisplayName("Output Layer Test") +class OutputLayerTest extends BaseDL4JTest { @Test - public void testSetParams() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) - .updater(new Sgd(1e-1)) - .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3) - .weightInit(WeightInit.ZERO).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .build(); - + @DisplayName("Test Set Params") + void testSetParams() { + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).updater(new Sgd(1e-1)).layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.ZERO).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).build(); long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf, - Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType()); + OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf, Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType()); params = l.params(); l.setParams(params); assertEquals(params, l.params()); } @Test - public void testOutputLayersRnnForwardPass() { - //Test output layer with RNNs ( - //Expect all outputs etc. to be 2d + @DisplayName("Test Output Layers Rnn Forward Pass") + void testOutputLayersRnnForwardPass() { + // Test output layer with RNNs ( + // Expect all outputs etc. 
to be 2d int nIn = 2; int nOut = 5; int layerSize = 4; int timeSeriesLength = 6; int miniBatchSize = 3; - Random r = new Random(12345L); INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength); for (int i = 0; i < miniBatchSize; i++) { for (int j = 0; j < nIn; j++) { for (int k = 0; k < timeSeriesLength; k++) { - input.putScalar(new int[] {i, j, k}, r.nextDouble() - 0.5); + input.putScalar(new int[] { i, j, k }, r.nextDouble() - 0.5); } } } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list() - .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) - .dist(new NormalDistribution(0, 1)).activation(Activation.TANH) - .updater(new NoOp()).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut) - .dist(new NormalDistribution(0, 1)) - .updater(new NoOp()).build()) - .inputPreProcessor(1, new RnnToFeedForwardPreProcessor()).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).inputPreProcessor(1, new RnnToFeedForwardPreProcessor()).build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - INDArray out2d = mln.feedForward(input).get(2); - assertArrayEquals(out2d.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); - + assertArrayEquals(out2d.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut }); INDArray out = mln.output(input); - assertArrayEquals(out.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); - + assertArrayEquals(out.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut }); INDArray preout = mln.output(input); - assertArrayEquals(preout.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); - - //As above, but for RnnOutputLayer. Expect all activations etc. to be 3d - - MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list() - .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) - .dist(new NormalDistribution(0, 1)).activation(Activation.TANH) - .updater(new NoOp()).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut) - .dist(new NormalDistribution(0, 1)) - .updater(new NoOp()).build()) - .build(); - + assertArrayEquals(preout.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut }); + // As above, but for RnnOutputLayer. Expect all activations etc. 
to be 3d + MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).build(); MultiLayerNetwork mlnRnn = new MultiLayerNetwork(confRnn); mln.init(); - INDArray out3d = mlnRnn.feedForward(input).get(2); - assertArrayEquals(out3d.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); - + assertArrayEquals(out3d.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength }); INDArray outRnn = mlnRnn.output(input); - assertArrayEquals(outRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); - + assertArrayEquals(outRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength }); INDArray preoutRnn = mlnRnn.output(input); - assertArrayEquals(preoutRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); + assertArrayEquals(preoutRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength }); } @Test - public void testRnnOutputLayerIncEdgeCases() { - //Basic test + test edge cases: timeSeriesLength==1, miniBatchSize==1, both - int[] tsLength = {5, 1, 5, 1}; - int[] miniBatch = {7, 7, 1, 1}; + @DisplayName("Test Rnn Output Layer Inc Edge Cases") + void testRnnOutputLayerIncEdgeCases() { + // Basic test + test edge cases: timeSeriesLength==1, miniBatchSize==1, both + int[] tsLength = { 5, 1, 5, 1 }; + int[] miniBatch = { 7, 7, 1, 1 }; int nIn = 3; int nOut = 6; int layerSize = 4; - FeedForwardToRnnPreProcessor proc = new FeedForwardToRnnPreProcessor(); - for (int t = 0; t < tsLength.length; t++) { Nd4j.getRandom().setSeed(12345); int timeSeriesLength = tsLength[t]; int miniBatchSize = miniBatch[t]; - Random r = new Random(12345L); INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength); for (int i = 0; i < miniBatchSize; i++) { for (int j = 0; j < nIn; j++) { for (int k = 0; k < timeSeriesLength; k++) { - input.putScalar(new int[] {i, j, k}, r.nextDouble() - 0.5); + input.putScalar(new int[] { i, j, k }, r.nextDouble() - 0.5); } } } @@ -170,406 +134,200 @@ public class OutputLayerTest extends BaseDL4JTest { for (int i = 0; i < miniBatchSize; i++) { for (int j = 0; j < timeSeriesLength; j++) { int idx = r.nextInt(nOut); - labels3d.putScalar(new int[] {i, idx, j}, 1.0f); + labels3d.putScalar(new int[] { i, idx, j }, 1.0f); } } INDArray labels2d = proc.backprop(labels3d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list() - .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) - .dist(new NormalDistribution(0, 1)) - .activation(Activation.TANH).updater(new NoOp()).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut) - .dist(new NormalDistribution(0, 1)) - .updater(new NoOp()).build()) - .inputPreProcessor(1, new RnnToFeedForwardPreProcessor()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new 
org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).inputPreProcessor(1, new RnnToFeedForwardPreProcessor()).build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - INDArray out2d = mln.feedForward(input).get(2); INDArray out3d = proc.preProcess(out2d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); - - MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list() - .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) - .dist(new NormalDistribution(0, 1)) - .activation(Activation.TANH).updater(new NoOp()).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut) - .dist(new NormalDistribution(0, 1)) - .updater(new NoOp()).build()) - .build(); - + MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).build(); MultiLayerNetwork mlnRnn = new MultiLayerNetwork(confRnn); mlnRnn.init(); - INDArray outRnn = mlnRnn.feedForward(input).get(2); - mln.setLabels(labels2d); mlnRnn.setLabels(labels3d); - - mln.computeGradientAndScore(); mlnRnn.computeGradientAndScore(); - - //score is average over all examples. - //However: OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping) - //RnnOutputLayer has miniBatch examples - //Hence: expect difference in scores by factor of timeSeriesLength + // score is average over all examples. + // However: OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping) + // RnnOutputLayer has miniBatch examples + // Hence: expect difference in scores by factor of timeSeriesLength double score = mln.score() * timeSeriesLength; double scoreRNN = mlnRnn.score(); - assertTrue(!Double.isNaN(score)); assertTrue(!Double.isNaN(scoreRNN)); - double relError = Math.abs(score - scoreRNN) / (Math.abs(score) + Math.abs(scoreRNN)); System.out.println(relError); assertTrue(relError < 1e-6); - - //Check labels and inputs for output layer: + // Check labels and inputs for output layer: OutputLayer ol = (OutputLayer) mln.getOutputLayer(); - assertArrayEquals(ol.getInput().shape(), new long[] {miniBatchSize * timeSeriesLength, layerSize}); - assertArrayEquals(ol.getLabels().shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); - + assertArrayEquals(ol.getInput().shape(), new long[] { miniBatchSize * timeSeriesLength, layerSize }); + assertArrayEquals(ol.getLabels().shape(), new long[] { miniBatchSize * timeSeriesLength, nOut }); RnnOutputLayer rnnol = (RnnOutputLayer) mlnRnn.getOutputLayer(); - //assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength}); - //Input may be set by BaseLayer methods. Thus input may end up as reshaped 2d version instead of original 3d version. - //Not ideal, but everything else works. 
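
The score check in this hunk is a relative-error comparison rather than a direct assertEquals on doubles: the OutputLayer network averages its loss over miniBatchSize * timeSeriesLength reshaped rows, while the RnnOutputLayer network averages over miniBatchSize sequences, so the raw scores are expected to differ by a factor of timeSeriesLength. A small sketch of that arithmetic, using the same variable names as the test (the helper name is hypothetical and not part of the patch):

    // Sketch of the relative-error comparison performed above (helper is illustrative only).
    static double relativeError(double a, double b) {
        double denom = Math.abs(a) + Math.abs(b);
        return denom == 0.0 ? 0.0 : Math.abs(a - b) / denom;
    }

    // In the test body: rescale the per-row average before comparing to the per-sequence average.
    double score = mln.score() * timeSeriesLength;
    double scoreRNN = mlnRnn.score();
    assertTrue(relativeError(score, scoreRNN) < 1e-6);
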
- assertArrayEquals(rnnol.getLabels().shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); - - //Check shapes of output for both: - assertArrayEquals(out2d.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); - + // assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength}); + // Input may be set by BaseLayer methods. Thus input may end up as reshaped 2d version instead of original 3d version. + // Not ideal, but everything else works. + assertArrayEquals(rnnol.getLabels().shape(), new long[] { miniBatchSize, nOut, timeSeriesLength }); + // Check shapes of output for both: + assertArrayEquals(out2d.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut }); INDArray out = mln.output(input); - assertArrayEquals(out.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); - + assertArrayEquals(out.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut }); INDArray preout = mln.output(input); - assertArrayEquals(preout.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut}); - - + assertArrayEquals(preout.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut }); INDArray outFFRnn = mlnRnn.feedForward(input).get(2); - assertArrayEquals(outFFRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); - + assertArrayEquals(outFFRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength }); INDArray outRnn2 = mlnRnn.output(input); - assertArrayEquals(outRnn2.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); - + assertArrayEquals(outRnn2.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength }); INDArray preoutRnn = mlnRnn.output(input); - assertArrayEquals(preoutRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); + assertArrayEquals(preoutRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength }); } } - @Test - public void testCompareRnnOutputRnnLoss(){ + @DisplayName("Test Compare Rnn Output Rnn Loss") + void testCompareRnnOutputRnnLoss() { Nd4j.getRandom().setSeed(12345); - int timeSeriesLength = 4; int nIn = 5; int layerSize = 6; int nOut = 6; int miniBatchSize = 3; - - MultiLayerConfiguration conf1 = - new NeuralNetConfiguration.Builder().seed(12345L) - .updater(new NoOp()) - .list() - .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) - .dist(new NormalDistribution(0, 1.0)) - .updater(new NoOp()).build()) - .layer(new DenseLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.IDENTITY).build()) - .layer(new RnnLossLayer.Builder(LossFunction.MCXENT) - .activation(Activation.SOFTMAX) - .build()) - .build(); - + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new DenseLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.IDENTITY).build()).layer(new RnnLossLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf1); mln.init(); - - - MultiLayerConfiguration conf2 = - new NeuralNetConfiguration.Builder().seed(12345L) - .updater(new NoOp()) - .list() - .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) - .dist(new NormalDistribution(0, 1.0)) - .updater(new NoOp()).build()) - .layer(new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT) - .activation(Activation.SOFTMAX) - 
.nIn(layerSize).nOut(nOut) - .build()) - .build(); - + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build()).build(); MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2); mln2.init(); - mln2.setParams(mln.params()); - - INDArray in = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}); - + INDArray in = Nd4j.rand(new int[] { miniBatchSize, nIn, timeSeriesLength }); INDArray out1 = mln.output(in); INDArray out2 = mln.output(in); - assertEquals(out1, out2); - Random r = new Random(12345); INDArray labels = Nd4j.create(miniBatchSize, nOut, timeSeriesLength); - for( int i=0; i= 0 && max <= 1.0); - INDArray sum = out.sum(1); - assertEquals(Nd4j.ones(DataType.FLOAT,2,4,5), sum); + assertEquals(Nd4j.ones(DataType.FLOAT, 2, 4, 5), sum); } @Test - public void testOutputLayerDefaults(){ - - new NeuralNetConfiguration.Builder().list() - .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(10).nOut(10).build()) - .build(); - - new NeuralNetConfiguration.Builder().list() - .layer(new org.deeplearning4j.nn.conf.layers.LossLayer.Builder().build()) - .build(); - - new NeuralNetConfiguration.Builder().list() - .layer(new org.deeplearning4j.nn.conf.layers.CnnLossLayer.Builder().build()) - .build(); - - new NeuralNetConfiguration.Builder().list() - .layer(new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder().build()) - .build(); - + @DisplayName("Test Output Layer Defaults") + void testOutputLayerDefaults() { + new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(10).nOut(10).build()).build(); + new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.LossLayer.Builder().build()).build(); + new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.CnnLossLayer.Builder().build()).build(); + new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder().build()).build(); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java index 5f4696b89..fddc7c150 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import org.deeplearning4j.BaseDL4JTest; @@ -26,47 +25,41 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; - import java.util.Arrays; +import static 
org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class RepeatVectorTest extends BaseDL4JTest { +@DisplayName("Repeat Vector Test") +class RepeatVectorTest extends BaseDL4JTest { private int REPEAT = 4; - private Layer getRepeatVectorLayer() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) - .dataType(DataType.DOUBLE) - .layer(new RepeatVector.Builder(REPEAT).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, - null, false, DataType.DOUBLE); + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).dataType(DataType.DOUBLE).layer(new RepeatVector.Builder(REPEAT).build()).build(); + return conf.getLayer().instantiate(conf, null, 0, null, false, DataType.DOUBLE); } @Test - public void testRepeatVector() { - - double[] arr = new double[] {1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.}; - INDArray expectedOut = Nd4j.create(arr, new long[] {1, 3, REPEAT}, 'f'); - INDArray input = Nd4j.create(new double[] {1., 2., 3.}, new long[] {1, 3}); + @DisplayName("Test Repeat Vector") + void testRepeatVector() { + double[] arr = new double[] { 1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3. }; + INDArray expectedOut = Nd4j.create(arr, new long[] { 1, 3, REPEAT }, 'f'); + INDArray input = Nd4j.create(new double[] { 1., 2., 3. }, new long[] { 1, 3 }); Layer layer = getRepeatVectorLayer(); - INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); assertTrue(Arrays.equals(expectedOut.shape(), output.shape())); assertEquals(expectedOut, output); - - INDArray epsilon = Nd4j.ones(1,3,4); - + INDArray epsilon = Nd4j.ones(1, 3, 4); Pair out = layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); INDArray outEpsilon = out.getSecond(); - INDArray expectedEpsilon = Nd4j.create(new double[] {4., 4., 4.}, new long[] {1, 3}); + INDArray expectedEpsilon = Nd4j.create(new double[] { 4., 4., 4. 
}, new long[] { 1, 3 }); assertEquals(expectedEpsilon, outEpsilon); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java index 88afce166..c30f867d2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers; import org.deeplearning4j.BaseDL4JTest; @@ -25,45 +24,41 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.AutoEncoder; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** */ - -public class SeedTest extends BaseDL4JTest { +@DisplayName("Seed Test") +class SeedTest extends BaseDL4JTest { private DataSetIterator irisIter = new IrisDataSetIterator(50, 50); + private DataSet data = irisIter.next(); - @Test - public void testAutoEncoderSeed() { - AutoEncoder layerType = new AutoEncoder.Builder().nIn(4).nOut(3).corruptionLevel(0.0) - .activation(Activation.SIGMOID).build(); - - NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder().layer(layerType).seed(123).build(); - + @DisplayName("Test Auto Encoder Seed") + void testAutoEncoderSeed() { + AutoEncoder layerType = new AutoEncoder.Builder().nIn(4).nOut(3).corruptionLevel(0.0).activation(Activation.SIGMOID).build(); + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(layerType).seed(123).build(); long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams)); layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); - layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); double score = layer.score(); INDArray parameters = layer.params(); layer.setParams(parameters); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - double score2 = layer.score(); assertEquals(parameters, layer.params()); assertEquals(score, score2, 1e-4); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsNetMNISTTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsNetMNISTTest.java index 83597dba3..5a5d08b8f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsNetMNISTTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsNetMNISTTest.java @@ -17,11 +17,9 @@ * * SPDX-License-Identifier: 
Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.capsule; -import static org.junit.Assert.assertTrue; - +import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.IOException; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; @@ -35,64 +33,44 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.PrimaryCapsules; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -@Ignore("AB - ignored due to excessive runtime. Keep for manual debugging when required") -public class CapsNetMNISTTest extends BaseDL4JTest { +@Disabled("AB - ignored due to excessive runtime. Keep for manual debugging when required") +@DisplayName("Caps Net MNIST Test") +class CapsNetMNISTTest extends BaseDL4JTest { @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testCapsNetOnMNIST(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(123) - .updater(new Adam()) - .list() - .layer(new ConvolutionLayer.Builder() - .nOut(16) - .kernelSize(9, 9) - .stride(3, 3) - .build()) - .layer(new PrimaryCapsules.Builder(8, 8) - .kernelSize(7, 7) - .stride(2, 2) - .build()) - .layer(new CapsuleLayer.Builder(10, 16, 3).build()) - .layer(new CapsuleStrengthLayer.Builder().build()) - .layer(new ActivationLayer.Builder(new ActivationSoftmax()).build()) - .layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); - + @DisplayName("Test Caps Net On MNIST") + void testCapsNetOnMNIST() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).updater(new Adam()).list().layer(new ConvolutionLayer.Builder().nOut(16).kernelSize(9, 9).stride(3, 3).build()).layer(new PrimaryCapsules.Builder(8, 8).kernelSize(7, 7).stride(2, 2).build()).layer(new CapsuleLayer.Builder(10, 16, 3).build()).layer(new CapsuleStrengthLayer.Builder().build()).layer(new ActivationLayer.Builder(new ActivationSoftmax()).build()).layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - int rngSeed = 12345; try { MnistDataSetIterator mnistTrain = new MnistDataSetIterator(64, true, rngSeed); MnistDataSetIterator mnistTest = new MnistDataSetIterator(64, false, rngSeed); - for (int i = 0; i < 2; i++) { model.fit(mnistTrain); } - Evaluation eval = model.evaluate(mnistTest); - - assertTrue("Accuracy not over 95%", eval.accuracy() > 0.95); - assertTrue("Precision not over 95%", eval.precision() > 0.95); - assertTrue("Recall not over 95%", eval.recall() > 0.95); - assertTrue("F1-score not over 95%", eval.f1() > 0.95); - - } catch (IOException e){ + assertTrue(eval.accuracy() > 0.95, "Accuracy not over 
95%"); + assertTrue(eval.precision() > 0.95, "Precision not over 95%"); + assertTrue(eval.recall() > 0.95, "Recall not over 95%"); + assertTrue(eval.f1() > 0.95, "F1-score not over 95%"); + } catch (IOException e) { System.out.println("Could not load MNIST."); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleLayerTest.java index f5502170f..9a131f49a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleLayerTest.java @@ -17,84 +17,71 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.capsule; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.CapsuleLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -public class CapsuleLayerTest extends BaseDL4JTest { +@DisplayName("Capsule Layer Test") +class CapsuleLayerTest extends BaseDL4JTest { @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testOutputType(){ + @DisplayName("Test Output Type") + void testOutputType() { CapsuleLayer layer = new CapsuleLayer.Builder(10, 16, 5).build(); - InputType in1 = InputType.recurrent(5, 8); - assertEquals(InputType.recurrent(10, 16), layer.getOutputType(0, in1)); } @Test - public void testInputType(){ + @DisplayName("Test Input Type") + void testInputType() { CapsuleLayer layer = new CapsuleLayer.Builder(10, 16, 5).build(); - InputType in1 = InputType.recurrent(5, 8); - layer.setNIn(in1, true); - assertEquals(5, layer.getInputCapsules()); assertEquals(8, layer.getInputCapsuleDimensions()); } @Test - public void testConfig(){ + @DisplayName("Test Config") + void testConfig() { CapsuleLayer layer1 = new CapsuleLayer.Builder(10, 16, 5).build(); - assertEquals(10, layer1.getCapsules()); assertEquals(16, layer1.getCapsuleDimensions()); assertEquals(5, layer1.getRoutings()); assertFalse(layer1.isHasBias()); - CapsuleLayer layer2 = new CapsuleLayer.Builder(10, 16, 5).hasBias(true).build(); - assertTrue(layer2.isHasBias()); - } @Test - public void testLayer(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(123) - .list() - .layer(new CapsuleLayer.Builder(10, 16, 3).build()) - .setInputType(InputType.recurrent(10, 8)) - .build(); - + @DisplayName("Test Layer") + 
void testLayer() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).list().layer(new CapsuleLayer.Builder(10, 16, 3).build()).setInputType(InputType.recurrent(10, 8)).build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - INDArray emptyFeatures = Nd4j.zeros(64, 10, 8); - long[] shape = model.output(emptyFeatures).shape(); - - assertArrayEquals(new long[]{64, 10, 16}, shape); + assertArrayEquals(new long[] { 64, 10, 16 }, shape); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleStrengthLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleStrengthLayerTest.java index 739d32fdb..e9276da71 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleStrengthLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleStrengthLayerTest.java @@ -17,55 +17,47 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.capsule; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; - +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.CapsuleStrengthLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -public class CapsuleStrengthLayerTest extends BaseDL4JTest { +@DisplayName("Capsule Strength Layer Test") +class CapsuleStrengthLayerTest extends BaseDL4JTest { @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testOutputType(){ + @DisplayName("Test Output Type") + void testOutputType() { CapsuleStrengthLayer layer = new CapsuleStrengthLayer.Builder().build(); - InputType in1 = InputType.recurrent(5, 8); - assertEquals(InputType.feedForward(5), layer.getOutputType(0, in1)); } @Test - public void testLayer(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(123) - .list() - .layer(new CapsuleStrengthLayer.Builder().build()) - .setInputType(InputType.recurrent(5, 8)) - .build(); - + @DisplayName("Test Layer") + void testLayer() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).list().layer(new CapsuleStrengthLayer.Builder().build()).setInputType(InputType.recurrent(5, 8)).build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - INDArray emptyFeatures = Nd4j.zeros(64, 5, 10); - long[] shape = model.output(emptyFeatures).shape(); - - assertArrayEquals(new long[]{64, 5}, shape); + assertArrayEquals(new long[] { 64, 5 }, shape); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/PrimaryCapsulesTest.java 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/PrimaryCapsulesTest.java index 8c5262358..0a4e03add 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/PrimaryCapsulesTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/capsule/PrimaryCapsulesTest.java @@ -17,113 +17,78 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.capsule; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.PrimaryCapsules; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -public class PrimaryCapsulesTest extends BaseDL4JTest { +@DisplayName("Primary Capsules Test") +class PrimaryCapsulesTest extends BaseDL4JTest { @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testOutputType(){ - PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8) - .kernelSize(7, 7) - .stride(2, 2) - .build(); - - + @DisplayName("Test Output Type") + void testOutputType() { + PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8).kernelSize(7, 7).stride(2, 2).build(); InputType in1 = InputType.convolutional(7, 7, 16); assertEquals(InputType.recurrent(8, 8), layer.getOutputType(0, in1)); - } @Test - public void testInputType(){ - PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8) - .kernelSize(7, 7) - .stride(2, 2) - .build(); + @DisplayName("Test Input Type") + void testInputType() { + PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8).kernelSize(7, 7).stride(2, 2).build(); InputType in1 = InputType.convolutional(7, 7, 16); - - layer.setNIn(in1, true); - assertEquals(8, layer.getCapsules()); assertEquals(8, layer.getCapsuleDimensions()); } @Test - public void testConfig(){ - PrimaryCapsules layer1 = new PrimaryCapsules.Builder(8, 10) - .kernelSize(5, 5) - .stride(4, 4) - .useLeakyReLU(0.5) - .build(); - + @DisplayName("Test Config") + void testConfig() { + PrimaryCapsules layer1 = new PrimaryCapsules.Builder(8, 10).kernelSize(5, 5).stride(4, 4).useLeakyReLU(0.5).build(); assertEquals(8, layer1.getCapsuleDimensions()); assertEquals(10, layer1.getChannels()); - assertArrayEquals(new int[]{5, 5}, layer1.getKernelSize()); - assertArrayEquals(new int[]{4, 4}, layer1.getStride()); - assertArrayEquals(new int[]{0, 0}, layer1.getPadding()); - assertArrayEquals(new int[]{1, 1}, layer1.getDilation()); + assertArrayEquals(new int[] { 5, 5 }, layer1.getKernelSize()); + assertArrayEquals(new int[] { 4, 4 }, 
layer1.getStride()); + assertArrayEquals(new int[] { 0, 0 }, layer1.getPadding()); + assertArrayEquals(new int[] { 1, 1 }, layer1.getDilation()); assertTrue(layer1.isUseRelu()); assertEquals(0.5, layer1.getLeak(), 0.001); - - PrimaryCapsules layer2 = new PrimaryCapsules.Builder(8, 10) - .kernelSize(5, 5) - .stride(4, 4) - .build(); + PrimaryCapsules layer2 = new PrimaryCapsules.Builder(8, 10).kernelSize(5, 5).stride(4, 4).build(); assertFalse(layer2.isUseRelu()); - - PrimaryCapsules layer3 = new PrimaryCapsules.Builder(8, 10) - .kernelSize(5, 5) - .stride(4, 4) - .useReLU() - .build(); + PrimaryCapsules layer3 = new PrimaryCapsules.Builder(8, 10).kernelSize(5, 5).stride(4, 4).useReLU().build(); assertTrue(layer3.isUseRelu()); assertEquals(0, layer3.getLeak(), 0.001); - } @Test - public void testLayer(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(123) - .list() - .layer(new PrimaryCapsules.Builder(8, 10) - .kernelSize(5, 5) - .stride(4, 4) - .useLeakyReLU(0.5) - .build()) - .setInputType(InputType.convolutional(20, 20, 20)) - .build(); - + @DisplayName("Test Layer") + void testLayer() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).list().layer(new PrimaryCapsules.Builder(8, 10).kernelSize(5, 5).stride(4, 4).useLeakyReLU(0.5).build()).setInputType(InputType.convolutional(20, 20, 20)).build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - INDArray emptyFeatures = Nd4j.zeros(64, 20, 20, 20); - long[] shape = model.output(emptyFeatures).shape(); - - assertArrayEquals(new long[]{64, 160, 8}, shape); + assertArrayEquals(new long[] { 64, 160, 8 }, shape); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java index 4615c95a2..31c0e8d5d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.convolution; import org.deeplearning4j.BaseDL4JTest; @@ -28,72 +27,67 @@ import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; - import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class Convolution3DTest extends BaseDL4JTest { +@DisplayName("Convolution 3 D Test") +class Convolution3DTest extends BaseDL4JTest { private int nExamples = 1; + private int nChannelsOut = 1; + private int nChannelsIn = 1; + private int inputDepth = 2 * 2; + private 
int inputWidth = 28 / 2; + private int inputHeight = 28 / 2; - private int[] kernelSize = new int[]{2, 2, 2}; + private int[] kernelSize = new int[] { 2, 2, 2 }; + private int outputDepth = inputDepth - kernelSize[0] + 1; + private int outputHeight = inputHeight - kernelSize[1] + 1; + private int outputWidth = inputWidth - kernelSize[2] + 1; private INDArray epsilon = Nd4j.ones(nExamples, nChannelsOut, outputDepth, outputHeight, outputWidth); - @Test - public void testConvolution3dForwardSameMode() { - + @DisplayName("Test Convolution 3 d Forward Same Mode") + void testConvolution3dForwardSameMode() { INDArray containedInput = getContainedData(); Convolution3DLayer layer = (Convolution3DLayer) getConvolution3DLayer(ConvolutionMode.Same); - assertTrue(layer.convolutionMode == ConvolutionMode.Same); - INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(containedInput.shape(), containedOutput.shape())); - } @Test - public void testConvolution3dForwardValidMode() throws Exception { - + @DisplayName("Test Convolution 3 d Forward Valid Mode") + void testConvolution3dForwardValidMode() throws Exception { Convolution3DLayer layer = (Convolution3DLayer) getConvolution3DLayer(ConvolutionMode.Strict); - assertTrue(layer.convolutionMode == ConvolutionMode.Strict); - INDArray input = getData(); INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - - assertTrue(Arrays.equals(new long[]{nExamples, nChannelsOut, outputDepth, outputWidth, outputHeight}, - output.shape())); + assertTrue(Arrays.equals(new long[] { nExamples, nChannelsOut, outputDepth, outputWidth, outputHeight }, output.shape())); } private Layer getConvolution3DLayer(ConvolutionMode mode) { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) - .layer(new Convolution3D.Builder().kernelSize(kernelSize).nIn(nChannelsIn).nOut(nChannelsOut) - .dataFormat(Convolution3D.DataFormat.NCDHW).convolutionMode(mode).hasBias(false) - .build()) - .build(); + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123).layer(new Convolution3D.Builder().kernelSize(kernelSize).nIn(nChannelsIn).nOut(nChannelsOut).dataFormat(Convolution3D.DataFormat.NCDHW).convolutionMode(mode).hasBias(false).build()).build(); long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.ones(1, numParams); return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); @@ -107,7 +101,6 @@ public class Convolution3DTest extends BaseDL4JTest { } private INDArray getContainedData() { - return Nd4j.create(new double[]{1., 2., 3., 4., 5., 6., 7., 8}, new int[]{1, 1, 2, 2, 2}); + return Nd4j.create(new double[] { 1., 2., 3., 4., 5., 6., 7., 8 }, new int[] { 1, 1, 2, 2, 2 }); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java index d49028f43..3f30c3ade 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java @@ -17,7 +17,6 @@ * * 
SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.convolution; import org.datavec.api.records.reader.RecordReader; @@ -37,9 +36,8 @@ import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -49,209 +47,122 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.util.FeatureUtil; - import java.io.File; import java.util.ArrayList; import java.util.Arrays; import java.util.List; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Adam Gibson */ -public class ConvolutionLayerSetupTest extends BaseDL4JTest { +@DisplayName("Convolution Layer Setup Test") +class ConvolutionLayerSetupTest extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testConvolutionLayerSetup() { + @DisplayName("Test Convolution Layer Setup") + void testConvolutionLayerSetup() { MultiLayerConfiguration.Builder builder = inComplete(); builder.setInputType(InputType.convolutionalFlat(28, 28, 1)); MultiLayerConfiguration completed = complete().build(); MultiLayerConfiguration test = builder.build(); assertEquals(completed, test); - } - @Test - public void testDenseToOutputLayer() { + @DisplayName("Test Dense To Output Layer") + void testDenseToOutputLayer() { Nd4j.getRandom().setSeed(12345); final int numRows = 76; final int numColumns = 76; int nChannels = 3; int outputNum = 6; int seed = 123; - - //setup the network - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) - .l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .build()) - .layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .build()) - .layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) - .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - - .setInputType(InputType.convolutional(numRows, numColumns, nChannels)); - - DataSet d = new DataSet(Nd4j.rand(new int[]{10, nChannels, numRows, numColumns}), - FeatureUtil.toOutcomeMatrix(new int[] {1, 1, 
1, 1, 1, 1, 1, 1, 1, 1}, 6)); + // setup the network + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list().layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(numRows, numColumns, nChannels)); + DataSet d = new DataSet(Nd4j.rand(new int[] { 10, nChannels, numRows, numColumns }), FeatureUtil.toOutcomeMatrix(new int[] { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, 6)); MultiLayerNetwork network = new MultiLayerNetwork(builder.build()); network.init(); network.fit(d); - } - @Test - public void testMnistLenet() throws Exception { + @DisplayName("Test Mnist Lenet") + void testMnistLenet() throws Exception { MultiLayerConfiguration.Builder incomplete = incompleteMnistLenet(); incomplete.setInputType(InputType.convolutionalFlat(28, 28, 1)); - MultiLayerConfiguration testConf = incomplete.build(); assertEquals(800, ((FeedForwardLayer) testConf.getConf(4).getLayer()).getNIn()); assertEquals(500, ((FeedForwardLayer) testConf.getConf(5).getLayer()).getNIn()); - - //test instantiation + // test instantiation DataSetIterator iter = new MnistDataSetIterator(10, 10); MultiLayerNetwork network = new MultiLayerNetwork(testConf); network.init(); network.fit(iter.next()); } - - @Test - public void testMultiChannel() throws Exception { - INDArray in = Nd4j.rand(new int[] {10, 3, 28, 28}); + @DisplayName("Test Multi Channel") + void testMultiChannel() throws Exception { + INDArray in = Nd4j.rand(new int[] { 10, 3, 28, 28 }); INDArray labels = Nd4j.rand(10, 2); DataSet next = new DataSet(in, labels); - NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLFW(); builder.setInputType(InputType.convolutional(28, 28, 3)); MultiLayerConfiguration conf = builder.build(); ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(2).getLayer(); assertEquals(6, layer2.getNIn()); - MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); network.fit(next); } @Test - public void testLRN() throws Exception { + @DisplayName("Test LRN") + void testLRN(@TempDir Path testFolder) throws Exception { List labels = new ArrayList<>(Arrays.asList("Zico", "Ziwang_Xu")); - File dir = testDir.newFolder(); + File dir = testFolder.toFile(); new ClassPathResource("lfwtest/").copyDirectory(dir); String rootDir = dir.getAbsolutePath(); - RecordReader reader = new ImageRecordReader(28, 28, 3); reader.initialize(new FileSplit(new File(rootDir))); DataSetIterator recordReader = new RecordReaderDataSetIterator(reader, 10, 1, labels.size()); labels.remove("lfwtest"); NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLRN(); builder.setInputType(InputType.convolutional(28, 28, 3)); - MultiLayerConfiguration conf = 
builder.build(); - ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(3).getLayer(); assertEquals(6, layer2.getNIn()); - } - public MultiLayerConfiguration.Builder incompleteLRN() { - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder().seed(3) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nOut(6).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}).build()) - .layer(2, new LocalResponseNormalization.Builder().build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nOut(6).build()) - .layer(4, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}).build()) - .layer(5, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(2) - .activation(Activation.SOFTMAX).build()); + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(3).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list().layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 5, 5 }).nOut(6).build()).layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder(new int[] { 2, 2 }).build()).layer(2, new LocalResponseNormalization.Builder().build()).layer(3, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 5, 5 }).nOut(6).build()).layer(4, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder(new int[] { 2, 2 }).build()).layer(5, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(2).activation(Activation.SOFTMAX).build()); return builder; } - public MultiLayerConfiguration.Builder incompleteLFW() { - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder().seed(3) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nOut(6).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nOut(6).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}).build()) - .layer(4, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX) - .nOut(2).build()); + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(3).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list().layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 5, 5 }).nOut(6).build()).layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder(new int[] { 2, 2 }).build()).layer(2, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 5, 5 }).nOut(6).build()).layer(3, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder(new int[] { 2, 2 }).build()).layer(4, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nOut(2).build()); return builder; } - - public MultiLayerConfiguration.Builder incompleteMnistLenet() { - MultiLayerConfiguration.Builder builder = - new 
NeuralNetConfiguration.Builder().seed(3) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nIn(1).nOut(20).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}, new int[] {2, 2}).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nIn(20).nOut(50).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}, new int[] {2, 2}).build()) - .layer(4, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nOut(500) - .build()) - .layer(5, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .activation(Activation.SOFTMAX).nOut(10) - .build()); + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(3).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list().layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 5, 5 }).nIn(1).nOut(20).build()).layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder(new int[] { 2, 2 }, new int[] { 2, 2 }).build()).layer(2, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 5, 5 }).nIn(20).nOut(50).build()).layer(3, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder(new int[] { 2, 2 }, new int[] { 2, 2 }).build()).layer(4, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nOut(500).build()).layer(5, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nOut(10).build()); return builder; } public MultiLayerConfiguration mnistLenet() { - MultiLayerConfiguration builder = - new NeuralNetConfiguration.Builder().seed(3) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nIn(1).nOut(6).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {5, 5}, new int[] {2, 2}).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nIn(1).nOut(6).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {5, 5}, new int[] {2, 2}).build()) - .layer(4, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(150) - .nOut(10).build()) - .build(); + MultiLayerConfiguration builder = new NeuralNetConfiguration.Builder().seed(3).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list().layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 5, 5 }).nIn(1).nOut(6).build()).layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder(new int[] { 5, 5 }, new int[] { 2, 2 }).build()).layer(2, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 5, 5 }).nIn(1).nOut(6).build()).layer(3, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder(new int[] { 5, 5 }, new int[] { 2, 2 }).build()).layer(4, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(150).nOut(10).build()).build(); return builder; } @@ -259,124 +170,75 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { int nChannels = 1; int outputNum = 10; int seed = 
123; - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] {10, 10}, - new int[] {2, 2}).nIn(nChannels).nOut(6).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - ; - + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list().layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 10, 10 }, new int[] { 2, 2 }).nIn(nChannels).nOut(6).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()); return builder; } - public MultiLayerConfiguration.Builder complete() { final int numRows = 28; final int numColumns = 28; int nChannels = 1; int outputNum = 10; int seed = 123; - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] {10, 10}, - new int[] {2, 2}).nIn(nChannels).nOut(6).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nIn(5 * 5 * 1 * 6) //216 - .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - .inputPreProcessor(0, new FeedForwardToCnnPreProcessor(numRows, numColumns, nChannels)) - .inputPreProcessor(2, new CnnToFeedForwardPreProcessor(5, 5, 6)); - + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list().layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] { 10, 10 }, new int[] { 2, 2 }).nIn(nChannels).nOut(6).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(// 216 + 5 * 5 * 1 * 6).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).inputPreProcessor(0, new FeedForwardToCnnPreProcessor(numRows, numColumns, nChannels)).inputPreProcessor(2, new CnnToFeedForwardPreProcessor(5, 5, 6)); return builder; } - @Test - public void testDeconvolution() { - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() - //out = stride * (in-1) + filter - 2*pad -> 2 * (28-1) + 2 - 0 = 56 -> 56x56x3 - .layer(0, new Deconvolution2D.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) - //(56-2+2*1)/2+1 = 29 -> 29x29x3 - .layer(1, new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()) - .layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); - + @DisplayName("Test Deconvolution") + void 
testDeconvolution() { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list().layer(0, new Deconvolution2D.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()).layer(1, new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()).layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)); MultiLayerConfiguration conf = builder.build(); - assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); CnnToFeedForwardPreProcessor proc = (CnnToFeedForwardPreProcessor) conf.getInputPreProcess(2); assertEquals(29, proc.getInputHeight()); assertEquals(29, proc.getInputWidth()); assertEquals(3, proc.getNumChannels()); - assertEquals(29 * 29 * 3, ((FeedForwardLayer) conf.getConf(2).getLayer()).getNIn()); } @Test - public void testSubSamplingWithPadding() { - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() - .layer(0, new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14 - .layer(1, new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()) //(14-2+2)/2+1 = 8 -> 8x8x3 - .layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); - + @DisplayName("Test Sub Sampling With Padding") + void testSubSamplingWithPadding() { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list().layer(0, // (28-2+0)/2+1 = 14 + new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()).layer(1, // (14-2+2)/2+1 = 8 -> 8x8x3 + new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()).layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)); MultiLayerConfiguration conf = builder.build(); - assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); CnnToFeedForwardPreProcessor proc = (CnnToFeedForwardPreProcessor) conf.getInputPreProcess(2); assertEquals(8, proc.getInputHeight()); assertEquals(8, proc.getInputWidth()); assertEquals(3, proc.getNumChannels()); - assertEquals(8 * 8 * 3, ((FeedForwardLayer) conf.getConf(2).getLayer()).getNIn()); } @Test - public void testUpsampling() { - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() - .layer(new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14 - .layer(new Upsampling2D.Builder().size(3).build()) // 14 * 3 = 42! - .layer(new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); - + @DisplayName("Test Upsampling") + void testUpsampling() { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list().layer(// (28-2+0)/2+1 = 14 + new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()).layer(// 14 * 3 = 42! 
+ new Upsampling2D.Builder().size(3).build()).layer(new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)); MultiLayerConfiguration conf = builder.build(); - assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); CnnToFeedForwardPreProcessor proc = (CnnToFeedForwardPreProcessor) conf.getInputPreProcess(2); assertEquals(42, proc.getInputHeight()); assertEquals(42, proc.getInputWidth()); assertEquals(3, proc.getNumChannels()); - assertEquals(42 * 42 * 3, ((FeedForwardLayer) conf.getConf(2).getLayer()).getNIn()); } @Test - public void testSpaceToBatch() { - - int[] blocks = new int[] {2, 2}; - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() - .layer(new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14 - .layer(new SpaceToBatchLayer.Builder(blocks).build()) // Divide space dimensions by blocks, i.e. 14/2 = 7 - .layer(new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); - + @DisplayName("Test Space To Batch") + void testSpaceToBatch() { + int[] blocks = new int[] { 2, 2 }; + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list().layer(// (28-2+0)/2+1 = 14 + new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()).layer(// Divide space dimensions by blocks, i.e. 14/2 = 7 + new SpaceToBatchLayer.Builder(blocks).build()).layer(new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)); MultiLayerConfiguration conf = builder.build(); - assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); CnnToFeedForwardPreProcessor proc = (CnnToFeedForwardPreProcessor) conf.getInputPreProcess(2); @@ -386,58 +248,32 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { } @Test - public void testSpaceToDepth() { - + @DisplayName("Test Space To Depth") + void testSpaceToDepth() { int blocks = 2; - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() - //(28-2+0)/2+1 = 14 -> 14x14x3 out - .layer(new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) - // Divide space dimensions by blocks, i.e. 14/2 = 7 -> 7x7x12 out (3x2x2 depth) - .layer(new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW).build()) - .layer(new OutputLayer.Builder().nIn(3 * 2 * 2).nOut(3).activation(Activation.SOFTMAX).build()) // nIn of the next layer gets multiplied by 2*2. - .setInputType(InputType.convolutional(28, 28, 1)); - + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list().layer(new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()).layer(new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW).build()).layer(// nIn of the next layer gets multiplied by 2*2. 
+ new OutputLayer.Builder().nIn(3 * 2 * 2).nOut(3).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)); MultiLayerConfiguration conf = builder.build(); - assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); CnnToFeedForwardPreProcessor proc = (CnnToFeedForwardPreProcessor) conf.getInputPreProcess(2); assertEquals(7, proc.getInputHeight()); assertEquals(7, proc.getInputWidth()); assertEquals(12, proc.getNumChannels()); - } - @Test - public void testCNNDBNMultiLayer() throws Exception { + @DisplayName("Test CNNDBN Multi Layer") + void testCNNDBNMultiLayer() throws Exception { DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSet next = iter.next(); - // Run with separate activation layer - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .weightInit(WeightInit.XAVIER).list() - .layer(0, new ConvolutionLayer.Builder(new int[] {1, 1}, new int[] {1, 1}).nIn(1).nOut(6) - .activation(Activation.IDENTITY).build()) - .layer(1, new BatchNormalization.Builder().build()) - .layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build()) - .layer(3, new DenseLayer.Builder().nIn(28 * 28 * 6).nOut(10).activation(Activation.IDENTITY) - .build()) - .layer(4, new BatchNormalization.Builder().nOut(10).build()) - .layer(5, new ActivationLayer.Builder().activation(Activation.RELU).build()) - .layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).weightInit(WeightInit.XAVIER).list().layer(0, new ConvolutionLayer.Builder(new int[] { 1, 1 }, new int[] { 1, 1 }).nIn(1).nOut(6).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().build()).layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build()).layer(3, new DenseLayer.Builder().nIn(28 * 28 * 6).nOut(10).activation(Activation.IDENTITY).build()).layer(4, new BatchNormalization.Builder().nOut(10).build()).layer(5, new ActivationLayer.Builder().activation(Activation.RELU).build()).layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - network.setInput(next.getFeatures()); INDArray activationsActual = network.output(next.getFeatures()); assertEquals(10, activationsActual.shape()[1], 1e-2); - network.fit(next); INDArray actualGammaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.GAMMA); INDArray actualBetaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.BETA); @@ -446,52 +282,31 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { } @Test - public void testSeparableConv2D() { - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() - .layer( new SeparableConvolution2D.Builder(2, 2) - .depthMultiplier(2) - .padding(0, 0) - .stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14 - .layer( new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()) //(14-2+2)/2+1 = 8 -> 8x8x3 - .layer(2, new 
OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); - + @DisplayName("Test Separable Conv 2 D") + void testSeparableConv2D() { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list().layer(new SeparableConvolution2D.Builder(2, 2).depthMultiplier(2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()).layer(// (14-2+2)/2+1 = 8 -> 8x8x3 + new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()).layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)); MultiLayerConfiguration conf = builder.build(); - assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); CnnToFeedForwardPreProcessor proc = (CnnToFeedForwardPreProcessor) conf.getInputPreProcess(2); assertEquals(8, proc.getInputHeight()); assertEquals(8, proc.getInputWidth()); assertEquals(3, proc.getNumChannels()); - assertEquals(8 * 8 * 3, ((FeedForwardLayer) conf.getConf(2).getLayer()).getNIn()); } @Test - public void testDeconv2D() { - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() - //out = stride * (in-1) + filter - 2*pad -> 2 * (28-1) + 2 - 0 = 56 -> 56x56x3 - .layer( new Deconvolution2D.Builder(2, 2) - .padding(0, 0) - .stride(2, 2).nIn(1).nOut(3).build()) - //(56-2+2*1)/2+1 = 29 -> 29x29x3 - .layer( new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()) - .layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); - + @DisplayName("Test Deconv 2 D") + void testDeconv2D() { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list().layer(new Deconvolution2D.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()).layer(new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()).layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)); MultiLayerConfiguration conf = builder.build(); - assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); CnnToFeedForwardPreProcessor proc = (CnnToFeedForwardPreProcessor) conf.getInputPreProcess(2); assertEquals(29, proc.getInputHeight()); assertEquals(29, proc.getInputWidth()); assertEquals(3, proc.getNumChannels()); - assertEquals(29 * 29 * 3, ((FeedForwardLayer) conf.getConf(2).getLayer()).getNIn()); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java index 5b93f9fb1..76ee15bf9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.convolution; import lombok.val; @@ -41,7 +40,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import 
org.deeplearning4j.nn.weights.WeightInitNormal; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.enums.RnnDataFormat; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationSoftmax; @@ -58,281 +57,197 @@ import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; - import java.io.File; import java.util.Arrays; import java.util.List; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; /** * @author Adam Gibson */ -public class ConvolutionLayerTest extends BaseDL4JTest { +@DisplayName("Convolution Layer Test") +class ConvolutionLayerTest extends BaseDL4JTest { @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testTwdFirstLayer() throws Exception { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4) - .updater(new Nesterovs(0.9)).dropOut(0.5) - .list().layer(0, - new ConvolutionLayer.Builder(8, 8) //16 filters kernel size 8 stride 4 - .stride(4, 4).nOut(16).dropOut(0.5) - .activation(Activation.RELU).weightInit( - WeightInit.XAVIER) - .build()) - .layer(1, new ConvolutionLayer.Builder(4, 4) //32 filters kernel size 4 stride 2 - .stride(2, 2).nOut(32).dropOut(0.5).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(2, new DenseLayer.Builder() //fully connected with 256 rectified units - .nOut(256).activation(Activation.RELU).weightInit(WeightInit.XAVIER) - .dropOut(0.5).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS) //output layer - .nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)); - + @DisplayName("Test Twd First Layer") + void testTwdFirstLayer() throws Exception { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4).updater(new Nesterovs(0.9)).dropOut(0.5).list().layer(0, // 16 filters kernel size 8 stride 4 + new ConvolutionLayer.Builder(8, 8).stride(4, 4).nOut(16).dropOut(0.5).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, // 32 filters kernel size 4 stride 2 + new ConvolutionLayer.Builder(4, 4).stride(2, 2).nOut(32).dropOut(0.5).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(2, // fully connected with 256 rectified units + new DenseLayer.Builder().nOut(256).activation(Activation.RELU).weightInit(WeightInit.XAVIER).dropOut(0.5).build()).layer(3, // output layer + new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)); DataSetIterator iter = new MnistDataSetIterator(10, 10); MultiLayerConfiguration conf = builder.build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); DataSet ds = iter.next(); - for( int i=0; i<5; i++ ) { + for (int i = 0; i < 5; i++) { 
network.fit(ds); } } @Test - public void testCNNSubComboWithMixedHW() { + @DisplayName("Test CNN Sub Combo With Mixed HW") + void testCNNSubComboWithMixedHW() { int imageHeight = 20; int imageWidth = 23; int nChannels = 1; int classes = 2; int numSamples = 200; - int kernelHeight = 3; int kernelWidth = 3; - DataSet trainInput; - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder() - .seed(123) - .list() - .layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 1) - .nOut(2).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new SubsamplingLayer.Builder() - .poolingType(SubsamplingLayer.PoolingType.MAX) - .kernelSize(imageHeight - kernelHeight, 1).stride(1, 1).build()) - .layer(2, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(imageHeight, imageWidth, nChannels)); - + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 1).nOut(2).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new SubsamplingLayer.Builder().poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(imageHeight - kernelHeight, 1).stride(1, 1).build()).layer(2, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(imageHeight, imageWidth, nChannels)); MultiLayerConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - INDArray emptyFeatures = Nd4j.zeros(numSamples, imageWidth * imageHeight * nChannels); INDArray emptyLables = Nd4j.zeros(numSamples, classes); - trainInput = new DataSet(emptyFeatures, emptyLables); model.fit(trainInput); } @Test - public void testCausal1d() { + @DisplayName("Test Causal 1 d") + void testCausal1d() { Nd4j.getEnvironment().setVerbose(true); Nd4j.getEnvironment().setDebug(true); - //See: Fixes: https://github.com/eclipse/deeplearning4j/issues/9060 + // See: Fixes: https://github.com/eclipse/deeplearning4j/issues/9060 double learningRate = 1e-3; long seed = 123; long timeSteps = 72; long vectorLength = 64; long batchSize = 1; - INDArray arr = Nd4j.randn(batchSize,vectorLength,timeSteps); - - MultiLayerConfiguration build = new NeuralNetConfiguration.Builder().seed(seed) - .activation(Activation.RELU) - .weightInit(new WeightInitNormal()) // better init - .updater(new Adam(learningRate)) - .list() - // block 1 - .layer(new Convolution1D.Builder() - .kernelSize(2) - .rnnDataFormat(RNNFormat.NCW) - .stride(1) - .nOut(14) - .convolutionMode(ConvolutionMode.Causal) - .dilation(4) - .build()) - .layer(new RnnLossLayer.Builder().dataFormat(RNNFormat.NCW) - .activation(new ActivationSoftmax()) - .lossFunction(new LossMCXENT()).build()) - .setInputType(InputType.recurrent(vectorLength,timeSteps,RNNFormat.NCW)) - .build(); - + INDArray arr = Nd4j.randn(batchSize, vectorLength, timeSteps); + MultiLayerConfiguration build = new NeuralNetConfiguration.Builder().seed(seed).activation(Activation.RELU).weightInit(// better init + new WeightInitNormal()).updater(new Adam(learningRate)).list().layer(new Convolution1D.Builder().kernelSize(2).rnnDataFormat(RNNFormat.NCW).stride(1).nOut(14).convolutionMode(ConvolutionMode.Causal).dilation(4).build()).layer(new RnnLossLayer.Builder().dataFormat(RNNFormat.NCW).activation(new 
ActivationSoftmax()).lossFunction(new LossMCXENT()).build()).setInputType(InputType.recurrent(vectorLength, timeSteps, RNNFormat.NCW)).build(); MultiLayerNetwork network = new MultiLayerNetwork(build); network.init(); INDArray output = network.output(arr); - assertArrayEquals(new long[]{1,14,72},output.shape()); + assertArrayEquals(new long[] { 1, 14, 72 }, output.shape()); System.out.println(output); } - @Test(expected = DL4JException.class) - public void testCNNTooLargeKernel() { - int imageHeight = 20; - int imageWidth = 23; - int nChannels = 1; - int classes = 2; - int numSamples = 200; - - int kernelHeight = imageHeight; - int kernelWidth = imageWidth + 1; - - DataSet trainInput; - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder() - .seed(123) - .list() - .layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth) //(img-kernel+2*padding)/stride + 1: must be >= 1. Therefore: with p=0, kernel <= img size - .stride(1, 1).nOut(2).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(imageHeight, imageWidth, nChannels)) - ; - - MultiLayerConfiguration conf = builder.build(); - MultiLayerNetwork model = new MultiLayerNetwork(conf); - model.init(); - - INDArray emptyFeatures = Nd4j.zeros(numSamples, imageWidth * imageHeight * nChannels); - INDArray emptyLables = Nd4j.zeros(numSamples, classes); - - trainInput = new DataSet(emptyFeatures, emptyLables); - model.fit(trainInput); - } - - @Test(expected = Exception.class) - public void testCNNZeroStride() { - int imageHeight = 20; - int imageWidth = 23; - int nChannels = 1; - int classes = 2; - int numSamples = 200; - - int kernelHeight = imageHeight; - int kernelWidth = imageWidth; - - DataSet trainInput; - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder() - .seed(123) - .list() - .layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 0) - .nOut(2).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - - .setInputType(InputType.convolutional(imageHeight, imageWidth, nChannels)); - - MultiLayerConfiguration conf = builder.build(); - MultiLayerNetwork model = new MultiLayerNetwork(conf); - model.init(); - - INDArray emptyFeatures = Nd4j.zeros(numSamples, imageWidth * imageHeight * nChannels); - INDArray emptyLables = Nd4j.zeros(numSamples, classes); - - trainInput = new DataSet(emptyFeatures, emptyLables); - model.fit(trainInput); + @Test + @DisplayName("Test CNN Too Large Kernel") + void testCNNTooLargeKernel() { + assertThrows(DL4JException.class, () -> { + int imageHeight = 20; + int imageWidth = 23; + int nChannels = 1; + int classes = 2; + int numSamples = 200; + int kernelHeight = imageHeight; + int kernelWidth = imageWidth + 1; + DataSet trainInput; + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, // (img-kernel+2*padding)/stride + 1: must be >= 1. 
Therefore: with p=0, kernel <= img size + new ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 1).nOut(2).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(imageHeight, imageWidth, nChannels)); + MultiLayerConfiguration conf = builder.build(); + MultiLayerNetwork model = new MultiLayerNetwork(conf); + model.init(); + INDArray emptyFeatures = Nd4j.zeros(numSamples, imageWidth * imageHeight * nChannels); + INDArray emptyLables = Nd4j.zeros(numSamples, classes); + trainInput = new DataSet(emptyFeatures, emptyLables); + model.fit(trainInput); + }); } @Test - public void testCNNBiasInit() { + @DisplayName("Test CNN Zero Stride") + void testCNNZeroStride() { + assertThrows(Exception.class, () -> { + int imageHeight = 20; + int imageWidth = 23; + int nChannels = 1; + int classes = 2; + int numSamples = 200; + int kernelHeight = imageHeight; + int kernelWidth = imageWidth; + DataSet trainInput; + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 0).nOut(2).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(imageHeight, imageWidth, nChannels)); + MultiLayerConfiguration conf = builder.build(); + MultiLayerNetwork model = new MultiLayerNetwork(conf); + model.init(); + INDArray emptyFeatures = Nd4j.zeros(numSamples, imageWidth * imageHeight * nChannels); + INDArray emptyLables = Nd4j.zeros(numSamples, classes); + trainInput = new DataSet(emptyFeatures, emptyLables); + model.fit(trainInput); + }); + } + + @Test + @DisplayName("Test CNN Bias Init") + void testCNNBiasInit() { ConvolutionLayer cnn = new ConvolutionLayer.Builder().nIn(1).nOut(3).biasInit(1).build(); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(cnn).build(); - val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); - assertEquals(1, layer.getParam("b").size(0)); } @Test - public void testCNNInputSetupMNIST() throws Exception { + @DisplayName("Test CNN Input Setup MNIST") + void testCNNInputSetupMNIST() throws Exception { INDArray input = getMnistData(); Layer layer = getMNISTConfig(); layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(input, layer.input()); assertArrayEquals(input.shape(), layer.input().shape()); } @Test - public void testFeatureMapShapeMNIST() throws Exception { + @DisplayName("Test Feature Map Shape MNIST") + void testFeatureMapShapeMNIST() throws Exception { int inputWidth = 28; - int[] stride = new int[] {1, 1}; - int[] padding = new int[] {0, 0}; - int[] kernelSize = new int[] {9, 9}; + int[] stride = new int[] { 1, 1 }; + int[] padding = new int[] { 0, 0 }; + int[] kernelSize = new int[] { 9, 9 }; int nChannelsIn = 1; int depth = 20; int featureMapWidth = (inputWidth + padding[1] * 2 - kernelSize[1]) / stride[1] + 1; - INDArray input = getMnistData(); - Layer layer = getCNNConfig(nChannelsIn, depth, kernelSize, stride, padding); INDArray convActivations = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - 
assertEquals(featureMapWidth, convActivations.size(2)); assertEquals(depth, convActivations.size(1)); } @Test - public void testActivateResultsContained() { + @DisplayName("Test Activate Results Contained") + void testActivateResultsContained() { Layer layer = getContainedConfig(); INDArray input = getContainedData(); - INDArray expectedOutput = Nd4j.create(new float[] {0.98201379f, 0.98201379f, 0.98201379f, 0.98201379f, 0.99966465f, - 0.99966465f, 0.99966465f, 0.99966465f, 0.98201379f, 0.98201379f, 0.98201379f, 0.98201379f, 0.99966465f, - 0.99966465f, 0.99966465f, 0.99966465f, 0.98201379f, 0.98201379f, 0.98201379f, 0.98201379f, 0.99966465f, - 0.99966465f, 0.99966465f, 0.99966465f, 0.98201379f, 0.98201379f, 0.98201379f, 0.98201379f, 0.99966465f, - 0.99966465f, 0.99966465f, 0.99966465f}, new int[] {1, 2, 4, 4}); - + INDArray expectedOutput = Nd4j.create(new float[] { 0.98201379f, 0.98201379f, 0.98201379f, 0.98201379f, 0.99966465f, 0.99966465f, 0.99966465f, 0.99966465f, 0.98201379f, 0.98201379f, 0.98201379f, 0.98201379f, 0.99966465f, 0.99966465f, 0.99966465f, 0.99966465f, 0.98201379f, 0.98201379f, 0.98201379f, 0.98201379f, 0.99966465f, 0.99966465f, 0.99966465f, 0.99966465f, 0.98201379f, 0.98201379f, 0.98201379f, 0.98201379f, 0.99966465f, 0.99966465f, 0.99966465f, 0.99966465f }, new int[] { 1, 2, 4, 4 }); INDArray convActivations = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(expectedOutput.shape(), convActivations.shape()); assertEquals(expectedOutput, convActivations); } - ////////////////////////////////////////////////////////////////////////////////// - + // //////////////////////////////////////////////////////////////////////////////// private static Layer getCNNConfig(int nIn, int nOut, int[] kernelSize, int[] stride, int[] padding) { - - ConvolutionLayer layer = new ConvolutionLayer.Builder(kernelSize, stride, padding).nIn(nIn).nOut(nOut) - .activation(Activation.SIGMOID).build(); - + ConvolutionLayer layer = new ConvolutionLayer.Builder(kernelSize, stride, padding).nIn(nIn).nOut(nOut).activation(Activation.SIGMOID).build(); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(layer).build(); - val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); } public Layer getMNISTConfig() { - int[] kernelSize = new int[] {9, 9}; - int[] stride = new int[] {1, 1}; - int[] padding = new int[] {1, 1}; + int[] kernelSize = new int[] { 9, 9 }; + int[] stride = new int[] { 1, 1 }; + int[] padding = new int[] { 1, 1 }; int nChannelsIn = 1; int depth = 20; - return getCNNConfig(nChannelsIn, depth, kernelSize, stride, padding); - } public INDArray getMnistData() throws Exception { @@ -340,7 +255,6 @@ public class ConvolutionLayerTest extends BaseDL4JTest { int inputHeight = 28; int nChannelsIn = 1; int nExamples = 5; - DataSetIterator data = new MnistDataSetIterator(nExamples, nExamples); DataSet mnist = data.next(); nExamples = mnist.numExamples(); @@ -348,131 +262,108 @@ public class ConvolutionLayerTest extends BaseDL4JTest { } public Layer getContainedConfig() { - int[] kernelSize = new int[] {2, 2}; - int[] stride = new int[] {2, 2}; - int[] padding = new int[] {0, 0}; + int[] kernelSize = new int[] { 2, 2 }; + int[] stride = new int[] { 2, 2 }; + int[] padding = new int[] { 0, 0 }; int nChannelsIn = 1; int depth = 2; - - INDArray W = Nd4j.create(new double[] {0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, new 
int[] {2, 1, 2, 2}); - INDArray b = Nd4j.create(new double[] {1, 1}); + INDArray W = Nd4j.create(new double[] { 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5 }, new int[] { 2, 1, 2, 2 }); + INDArray b = Nd4j.create(new double[] { 1, 1 }); Layer layer = getCNNConfig(nChannelsIn, depth, kernelSize, stride, padding); layer.setParam("W", W); layer.setParam("b", b); - return layer; - } public INDArray getContainedData() { - INDArray ret = Nd4j.create(new float[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8}); + INDArray ret = Nd4j.create(new float[] { 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4 }, new int[] { 1, 1, 8, 8 }); return ret; } public INDArray getContainedCol() { - return Nd4j.create(new float[] {1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, - 1, 1, 3, 3, 3, 3, 2, 2, 2, 2, 4, 4, 4, 4, 2, 2, 2, 2, 4, 4, 4, 4, 2, 2, 2, 2, 4, 4, 4, 4, 2, 2, - 2, 2, 4, 4, 4, 4}, new int[] {1, 1, 2, 2, 4, 4}); + return Nd4j.create(new float[] { 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3, 2, 2, 2, 2, 4, 4, 4, 4, 2, 2, 2, 2, 4, 4, 4, 4, 2, 2, 2, 2, 4, 4, 4, 4, 2, 2, 2, 2, 4, 4, 4, 4 }, new int[] { 1, 1, 2, 2, 4, 4 }); } - - - ////////////////////////////////////////////////////////////////////////////////// - - + // //////////////////////////////////////////////////////////////////////////////// @Test - public void testCNNMLNPretrain() throws Exception { + @DisplayName("Test CNNMLN Pretrain") + void testCNNMLNPretrain() throws Exception { // Note CNN does not do pretrain int numSamples = 10; int batchSize = 10; DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true); - MultiLayerNetwork model = getCNNMLNConfig(false, true); model.fit(mnistIter); - mnistIter.reset(); - MultiLayerNetwork model2 = getCNNMLNConfig(false, true); model2.fit(mnistIter); mnistIter.reset(); - DataSet test = mnistIter.next(); - Evaluation eval = new Evaluation(); INDArray output = model.output(test.getFeatures()); eval.eval(test.getLabels(), output); double f1Score = eval.f1(); - Evaluation eval2 = new Evaluation(); INDArray output2 = model2.output(test.getFeatures()); eval2.eval(test.getLabels(), output2); double f1Score2 = eval2.f1(); - assertEquals(f1Score, f1Score2, 1e-4); - - } - @Test - public void testCNNMLNBackprop() throws Exception { + @DisplayName("Test CNNMLN Backprop") + void testCNNMLNBackprop() throws Exception { int numSamples = 10; int batchSize = 10; DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true); - MultiLayerNetwork model = getCNNMLNConfig(true, false); model.fit(mnistIter); - MultiLayerNetwork model2 = getCNNMLNConfig(true, false); model2.fit(mnistIter); - mnistIter.reset(); DataSet test = mnistIter.next(); - Evaluation eval = new Evaluation(); INDArray output = model.output(test.getFeatures()); eval.eval(test.getLabels(), output); double f1Score = eval.f1(); - Evaluation eval2 = new Evaluation(); INDArray output2 = model2.output(test.getFeatures()); eval2.eval(test.getLabels(), output2); double f1Score2 = eval2.f1(); - assertEquals(f1Score, f1Score2, 1e-4); - } @Test - public void testGetSetParams() { - + @DisplayName("Test Get Set Params") + void 
testGetSetParams() { MultiLayerNetwork net = getCNNMLNConfig(true, false); - INDArray paramsOrig = net.params().dup(); net.setParams(paramsOrig); - INDArray params2 = net.params(); - assertEquals(paramsOrig, params2); } private static final int kH = 2; + private static final int kW = 2; - private static final int[] strides = {1, 1}; - private static final int[] pad = {0, 0}; + + private static final int[] strides = { 1, 1 }; + + private static final int[] pad = { 0, 0 }; private static final int miniBatch = 2; + private static final int inDepth = 2; + private static final int height = 3; + private static final int width = 3; private static final int outW = 2; + private static final int outH = 2; private static INDArray getInput() { - /* ----- Input images ----- example 0: @@ -485,34 +376,27 @@ public class ConvolutionLayerTest extends BaseDL4JTest { 21 22 23 30 31 32 24 25 26] 33 34 35] */ - - INDArray input = Nd4j.create(new int[] {miniBatch, inDepth, height, width}, 'c'); - input.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})); - input.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}})); - input.put(new INDArrayIndex[] {NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}})); - input.put(new INDArrayIndex[] {NDArrayIndex.point(1), NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}})); - + INDArray input = Nd4j.create(new int[] { miniBatch, inDepth, height, width }, 'c'); + input.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } })); + input.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } })); + input.put(new INDArrayIndex[] { NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } })); + input.put(new INDArrayIndex[] { NDArrayIndex.point(1), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 27, 28, 29 }, { 30, 31, 32 }, { 33, 34, 35 } })); return input; } @Test - public void testCnnIm2ColReshaping() { - //This test: a bit unusual in that it tests the *assumptions* of the CNN implementation rather than the implementation itself - //Specifically, it tests the row and column orders after reshaping on im2col is reshaped (both forward and backward pass) + @DisplayName("Test Cnn Im 2 Col Reshaping") + void testCnnIm2ColReshaping() { + // This test: a bit unusual in that it tests the *assumptions* of the CNN implementation rather than the implementation itself + // Specifically, it tests the row and column orders after reshaping on im2col is reshaped (both forward and backward pass) INDArray input = getInput(); - - //im2col in the required order: want [outW,outH,miniBatch,depthIn,kH,kW], but need to input [miniBatch,channels,kH,kW,outH,outW] + // im2col in the required order: want [outW,outH,miniBatch,depthIn,kH,kW], but need to input 
[miniBatch,channels,kH,kW,outH,outW] // given the current im2col implementation - //To get this: create an array of the order we want, permute it to the order required by im2col implementation, and then do im2col on that - //to get old order from required order: permute(2,3,4,5,1,2) - INDArray col = Nd4j.create(new int[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c'); + // To get this: create an array of the order we want, permute it to the order required by im2col implementation, and then do im2col on that + // to get old order from required order: permute(2,3,4,5,1,2) + INDArray col = Nd4j.create(new int[] { miniBatch, outH, outW, inDepth, kH, kW }, 'c'); INDArray col2 = col.permute(0, 3, 4, 5, 1, 2); Convolution.im2col(input, kH, kW, strides[0], strides[1], pad[0], pad[1], false, col2); - /* Expected Output, im2col - example 0 - @@ -535,63 +419,67 @@ public class ConvolutionLayerTest extends BaseDL4JTest { 21 22 22 23 30 31 31 32 24 25 25 26 33 34 34 35 */ - - //Now, after reshaping im2col to 2d, we expect: - //Rows with order (wOut0,hOut0,mb0), (wOut1,hOut0,mb0), (wOut0,hOut1,mb0), (wOut1,hOut1,mb0), (wOut0,hOut0,mb1), ... - //Columns with order (d0,kh0,kw0), (d0,kh0,kw1), (d0,kh1,kw0), (d0,kh1,kw1), (d1,kh0,kw0), ... - - INDArray reshapedCol = Shape.newShapeNoCopy(col, new int[] {miniBatch * outH * outW, inDepth * kH * kW}, false); - + // Now, after reshaping im2col to 2d, we expect: + // Rows with order (wOut0,hOut0,mb0), (wOut1,hOut0,mb0), (wOut0,hOut1,mb0), (wOut1,hOut1,mb0), (wOut0,hOut0,mb1), ... + // Columns with order (d0,kh0,kw0), (d0,kh0,kw1), (d0,kh1,kw0), (d0,kh1,kw1), (d1,kh0,kw0), ... + INDArray reshapedCol = Shape.newShapeNoCopy(col, new int[] { miniBatch * outH * outW, inDepth * kH * kW }, false); INDArray exp2d = Nd4j.create(outW * outH * miniBatch, inDepth * kH * kW); - exp2d.putRow(0, Nd4j.create(new double[] {0, 1, 3, 4, 9, 10, 12, 13})); //wOut0,hOut0,mb0 -> both depths, in order (d0,kh0,kw0), (d0,kh0,kw1), (d0,kh1,kw0), (d0,kh1,kw1), (d1,kh0,kw0), (d1,kh0,kw1), (d1,kh1,kw0), (d1,kh1,kw1) - exp2d.putRow(1, Nd4j.create(new double[] {1, 2, 4, 5, 10, 11, 13, 14})); //wOut1,hOut0,mb0 - exp2d.putRow(2, Nd4j.create(new double[] {3, 4, 6, 7, 12, 13, 15, 16})); //wOut0,hOut1,mb0 - exp2d.putRow(3, Nd4j.create(new double[] {4, 5, 7, 8, 13, 14, 16, 17})); //wOut1,hOut1,mb0 - exp2d.putRow(4, Nd4j.create(new double[] {18, 19, 21, 22, 27, 28, 30, 31})); //wOut0,hOut0,mb1 - exp2d.putRow(5, Nd4j.create(new double[] {19, 20, 22, 23, 28, 29, 31, 32})); //wOut1,hOut0,mb1 - exp2d.putRow(6, Nd4j.create(new double[] {21, 22, 24, 25, 30, 31, 33, 34})); //wOut0,hOut1,mb1 - exp2d.putRow(7, Nd4j.create(new double[] {22, 23, 25, 26, 31, 32, 34, 35})); //wOut1,hOut1,mb1 - + // wOut0,hOut0,mb0 -> both depths, in order (d0,kh0,kw0), (d0,kh0,kw1), (d0,kh1,kw0), (d0,kh1,kw1), (d1,kh0,kw0), (d1,kh0,kw1), (d1,kh1,kw0), (d1,kh1,kw1) + exp2d.putRow(0, Nd4j.create(new double[] { 0, 1, 3, 4, 9, 10, 12, 13 })); + // wOut1,hOut0,mb0 + exp2d.putRow(1, Nd4j.create(new double[] { 1, 2, 4, 5, 10, 11, 13, 14 })); + // wOut0,hOut1,mb0 + exp2d.putRow(2, Nd4j.create(new double[] { 3, 4, 6, 7, 12, 13, 15, 16 })); + // wOut1,hOut1,mb0 + exp2d.putRow(3, Nd4j.create(new double[] { 4, 5, 7, 8, 13, 14, 16, 17 })); + // wOut0,hOut0,mb1 + exp2d.putRow(4, Nd4j.create(new double[] { 18, 19, 21, 22, 27, 28, 30, 31 })); + // wOut1,hOut0,mb1 + exp2d.putRow(5, Nd4j.create(new double[] { 19, 20, 22, 23, 28, 29, 31, 32 })); + // wOut0,hOut1,mb1 + exp2d.putRow(6, Nd4j.create(new double[] { 21, 22, 24, 25, 30, 31, 33, 34 })); + 
// wOut1,hOut1,mb1 + exp2d.putRow(7, Nd4j.create(new double[] { 22, 23, 25, 26, 31, 32, 34, 35 })); assertEquals(exp2d, reshapedCol); - - //Check the same thing for the backprop im2col (different order) - INDArray colBackprop = Nd4j.create(new int[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c'); + // Check the same thing for the backprop im2col (different order) + INDArray colBackprop = Nd4j.create(new int[] { miniBatch, outH, outW, inDepth, kH, kW }, 'c'); INDArray colBackprop2 = colBackprop.permute(0, 3, 4, 5, 1, 2); - Convolution.im2col(input, kH, kW, strides[0], strides[1], pad[0], pad[1], false, colBackprop2); - - INDArray reshapedColBackprop = Shape.newShapeNoCopy(colBackprop, - new int[] {miniBatch * outH * outW, inDepth * kH * kW}, false); - - //Rows with order (mb0,h0,w0), (mb0,h0,w1), (mb0,h1,w0), (mb0,h1,w1), (mb1,h0,w0), (mb1,h0,w1), (mb1,h1,w0), (mb1,h1,w1) - //Columns with order (d0,kh0,kw0), (d0,kh0,kw1), (d0,kh1,kw0), (d0,kh1,kw1), (d1,kh0,kw0), ... - + INDArray reshapedColBackprop = Shape.newShapeNoCopy(colBackprop, new int[] { miniBatch * outH * outW, inDepth * kH * kW }, false); + // Rows with order (mb0,h0,w0), (mb0,h0,w1), (mb0,h1,w0), (mb0,h1,w1), (mb1,h0,w0), (mb1,h0,w1), (mb1,h1,w0), (mb1,h1,w1) + // Columns with order (d0,kh0,kw0), (d0,kh0,kw1), (d0,kh1,kw0), (d0,kh1,kw1), (d1,kh0,kw0), ... INDArray exp2dv2 = Nd4j.create(outW * outH * miniBatch, inDepth * kH * kW); - exp2dv2.putRow(0, Nd4j.create(new double[] {0, 1, 3, 4, 9, 10, 12, 13})); //wOut0,hOut0,mb0 -> both depths, in order (d0,kh0,kw0), (d0,kh0,kw1), (d0,kh1,kw0), (d0,kh1,kw1), (d1,kh0,kw0), (d1,kh0,kw1), (d1,kh1,kw0), (d1,kh1,kw1) - exp2dv2.putRow(1, Nd4j.create(new double[] {1, 2, 4, 5, 10, 11, 13, 14})); //wOut1,hOut0,mb0 - exp2dv2.putRow(2, Nd4j.create(new double[] {3, 4, 6, 7, 12, 13, 15, 16})); //wOut0,hOut1,mb0 - exp2dv2.putRow(3, Nd4j.create(new double[] {4, 5, 7, 8, 13, 14, 16, 17})); //wOut1,hOut1,mb0 - exp2dv2.putRow(4, Nd4j.create(new double[] {18, 19, 21, 22, 27, 28, 30, 31})); //wOut0,hOut0,mb1 - exp2dv2.putRow(5, Nd4j.create(new double[] {19, 20, 22, 23, 28, 29, 31, 32})); //wOut1,hOut0,mb1 - exp2dv2.putRow(6, Nd4j.create(new double[] {21, 22, 24, 25, 30, 31, 33, 34})); //wOut0,hOut1,mb1 - exp2dv2.putRow(7, Nd4j.create(new double[] {22, 23, 25, 26, 31, 32, 34, 35})); //wOut1,hOut1,mb1 - + // wOut0,hOut0,mb0 -> both depths, in order (d0,kh0,kw0), (d0,kh0,kw1), (d0,kh1,kw0), (d0,kh1,kw1), (d1,kh0,kw0), (d1,kh0,kw1), (d1,kh1,kw0), (d1,kh1,kw1) + exp2dv2.putRow(0, Nd4j.create(new double[] { 0, 1, 3, 4, 9, 10, 12, 13 })); + // wOut1,hOut0,mb0 + exp2dv2.putRow(1, Nd4j.create(new double[] { 1, 2, 4, 5, 10, 11, 13, 14 })); + // wOut0,hOut1,mb0 + exp2dv2.putRow(2, Nd4j.create(new double[] { 3, 4, 6, 7, 12, 13, 15, 16 })); + // wOut1,hOut1,mb0 + exp2dv2.putRow(3, Nd4j.create(new double[] { 4, 5, 7, 8, 13, 14, 16, 17 })); + // wOut0,hOut0,mb1 + exp2dv2.putRow(4, Nd4j.create(new double[] { 18, 19, 21, 22, 27, 28, 30, 31 })); + // wOut1,hOut0,mb1 + exp2dv2.putRow(5, Nd4j.create(new double[] { 19, 20, 22, 23, 28, 29, 31, 32 })); + // wOut0,hOut1,mb1 + exp2dv2.putRow(6, Nd4j.create(new double[] { 21, 22, 24, 25, 30, 31, 33, 34 })); + // wOut1,hOut1,mb1 + exp2dv2.putRow(7, Nd4j.create(new double[] { 22, 23, 25, 26, 31, 32, 34, 35 })); assertEquals(exp2dv2, reshapedColBackprop); } @Test - public void testDeltaReshaping() { - //As per above test: testing assumptions of cnn implementation... 
- - //Delta: initially shape [miniBatch,dOut,outH,outW] - //permute to [dOut,miniB,outH,outW] - //then reshape to [dOut,miniB*outH*outW] - //Expect columns of delta2d to be like: (mb0,h0,w0), (mb0,h0,w1), (mb1,h0,w2), (mb0,h1,w0), ... (mb1,...), ..., (mb2,...) + @DisplayName("Test Delta Reshaping") + void testDeltaReshaping() { + // As per above test: testing assumptions of cnn implementation... + // Delta: initially shape [miniBatch,dOut,outH,outW] + // permute to [dOut,miniB,outH,outW] + // then reshape to [dOut,miniB*outH*outW] + // Expect columns of delta2d to be like: (mb0,h0,w0), (mb0,h0,w1), (mb1,h0,w2), (mb0,h1,w0), ... (mb1,...), ..., (mb2,...) int miniBatch = 3; int depth = 2; int outW = 3; int outH = 3; - /* ----- Input delta ----- example 0: @@ -608,46 +496,31 @@ public class ConvolutionLayerTest extends BaseDL4JTest { 39 40 41 48 49 50 42 43 44] 51 52 53] */ - - INDArray deltaOrig = Nd4j.create(new int[] {miniBatch, depth, outH, outW}, 'c'); - deltaOrig.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})); - deltaOrig.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}})); - deltaOrig.put(new INDArrayIndex[] {NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}})); - deltaOrig.put(new INDArrayIndex[] {NDArrayIndex.point(1), NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}})); - deltaOrig.put(new INDArrayIndex[] {NDArrayIndex.point(2), NDArrayIndex.point(0), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{36, 37, 38}, {39, 40, 41}, {42, 43, 44}})); - deltaOrig.put(new INDArrayIndex[] {NDArrayIndex.point(2), NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}})); - - + INDArray deltaOrig = Nd4j.create(new int[] { miniBatch, depth, outH, outW }, 'c'); + deltaOrig.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } })); + deltaOrig.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } })); + deltaOrig.put(new INDArrayIndex[] { NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } })); + deltaOrig.put(new INDArrayIndex[] { NDArrayIndex.point(1), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 27, 28, 29 }, { 30, 31, 32 }, { 33, 34, 35 } })); + deltaOrig.put(new INDArrayIndex[] { NDArrayIndex.point(2), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 36, 37, 38 }, { 39, 40, 41 }, { 42, 43, 44 } })); + deltaOrig.put(new INDArrayIndex[] { NDArrayIndex.point(2), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 45, 46, 47 }, { 48, 49, 50 }, { 51, 52, 53 } })); INDArray deltaPermute = deltaOrig.permute(1, 0, 2, 3).dup('c'); - INDArray 
delta2d = Shape.newShapeNoCopy(deltaPermute, new int[] {depth, miniBatch * outW * outH}, false); - - INDArray exp = Nd4j.create(new double[][] { - {0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 19, 20, 21, 22, 23, 24, 25, 26, 36, 37, 38, 39, 40, 41, 42, 43, - 44}, //depth0 - {9, 10, 11, 12, 13, 14, 15, 16, 17, 27, 28, 29, 30, 31, 32, 33, 34, 35, 45, 46, 47, 48, 49, 50, - 51, 52, 53} //depth1 - }).castTo(delta2d.dataType()); - + INDArray delta2d = Shape.newShapeNoCopy(deltaPermute, new int[] { depth, miniBatch * outW * outH }, false); + INDArray exp = Nd4j.create(new double[][] { { 0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 19, 20, 21, 22, 23, 24, 25, 26, 36, 37, 38, 39, 40, 41, 42, 43, // depth0 + 44 }, { 9, 10, 11, 12, 13, 14, 15, 16, 17, 27, 28, 29, 30, 31, 32, 33, 34, 35, 45, 46, 47, 48, 49, 50, 51, 52, // depth1 + 53 } }).castTo(delta2d.dataType()); assertEquals(exp, delta2d); } @Test - public void testWeightReshaping() { - //Test assumptions of weight reshaping - //Weights: originally c order, shape [outDepth, inDepth, kH, kw] - //permute (3,2,1,0) - + @DisplayName("Test Weight Reshaping") + void testWeightReshaping() { + // Test assumptions of weight reshaping + // Weights: originally c order, shape [outDepth, inDepth, kH, kw] + // permute (3,2,1,0) int depthOut = 2; int depthIn = 3; int kH = 2; int kW = 2; - /* ----- Weights ----- - dOut 0 - @@ -658,177 +531,130 @@ public class ConvolutionLayerTest extends BaseDL4JTest { [12 13 [16 17 [20 21 14 15] 18 19] 22 23] */ - - INDArray weightOrig = Nd4j.create(new int[] {depthOut, depthIn, kH, kW}, 'c'); - weightOrig.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{0, 1}, {2, 3}})); - weightOrig.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{4, 5}, {6, 7}})); - weightOrig.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{8, 9}, {10, 11}})); - weightOrig.put(new INDArrayIndex[] {NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{12, 13}, {14, 15}})); - weightOrig.put(new INDArrayIndex[] {NDArrayIndex.point(1), NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{16, 17}, {18, 19}})); - weightOrig.put(new INDArrayIndex[] {NDArrayIndex.point(1), NDArrayIndex.point(2), NDArrayIndex.all(), - NDArrayIndex.all()}, Nd4j.create(new double[][] {{20, 21}, {22, 23}})); - + INDArray weightOrig = Nd4j.create(new int[] { depthOut, depthIn, kH, kW }, 'c'); + weightOrig.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 0, 1 }, { 2, 3 } })); + weightOrig.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 4, 5 }, { 6, 7 } })); + weightOrig.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 8, 9 }, { 10, 11 } })); + weightOrig.put(new INDArrayIndex[] { NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 12, 13 }, { 14, 15 } })); + weightOrig.put(new INDArrayIndex[] { NDArrayIndex.point(1), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all() }, 
Nd4j.create(new double[][] { { 16, 17 }, { 18, 19 } })); + weightOrig.put(new INDArrayIndex[] { NDArrayIndex.point(1), NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all() }, Nd4j.create(new double[][] { { 20, 21 }, { 22, 23 } })); INDArray weightPermute = weightOrig.permute(3, 2, 1, 0); - INDArray w2d = Shape.newShapeNoCopy(weightPermute, new int[] {depthIn * kH * kW, depthOut}, true); - + INDArray w2d = Shape.newShapeNoCopy(weightPermute, new int[] { depthIn * kH * kW, depthOut }, true); assertNotNull(w2d); - - //Expected order of weight rows, after reshaping: (kw0,kh0,din0), (kw1,kh0,din0), (kw0,kh1,din0), (kw1,kh1,din0), (kw0,kh0,din1), ... - INDArray wExp = Nd4j.create(new double[][] {{0, 12}, {1, 13}, {2, 14}, {3, 15}, {4, 16}, {5, 17}, {6, 18}, - {7, 19}, {8, 20}, {9, 21}, {10, 22}, {11, 23}}).castTo(DataType.FLOAT); - + // Expected order of weight rows, after reshaping: (kw0,kh0,din0), (kw1,kh0,din0), (kw0,kh1,din0), (kw1,kh1,din0), (kw0,kh0,din1), ... + INDArray wExp = Nd4j.create(new double[][] { { 0, 12 }, { 1, 13 }, { 2, 14 }, { 3, 15 }, { 4, 16 }, { 5, 17 }, { 6, 18 }, { 7, 19 }, { 8, 20 }, { 9, 21 }, { 10, 22 }, { 11, 23 } }).castTo(DataType.FLOAT); assertEquals(wExp, w2d); } - ////////////////////////////////////////////////////////////////////////////////// - + // //////////////////////////////////////////////////////////////////////////////// private static MultiLayerNetwork getCNNMLNConfig(boolean backprop, boolean pretrain) { int outputNum = 10; int seed = 123; - - MultiLayerConfiguration.Builder conf = - new NeuralNetConfiguration.Builder().seed(seed) - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() - .layer(0, new ConvolutionLayer.Builder(new int[] {10, 10}).nOut(6).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, - new int[] {2, 2}).stride(1, 1).build()) - .layer(2, new OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)); - + MultiLayerConfiguration.Builder conf = new NeuralNetConfiguration.Builder().seed(seed).optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list().layer(0, new ConvolutionLayer.Builder(new int[] { 10, 10 }).nOut(6).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).stride(1, 1).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)); MultiLayerNetwork model = new MultiLayerNetwork(conf.build()); model.init(); - return model; } - - @Test - public void test1dInputType(){ - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .convolutionMode(ConvolutionMode.Same) - .list() - .layer(new Convolution1DLayer.Builder().nOut(3).kernelSize(2).activation(Activation.TANH).build()) - .layer(new Subsampling1DLayer.Builder().kernelSize(2).stride(2).build()) - .layer(new Upsampling1D.Builder().size(2).build()) - .layer(new RnnOutputLayer.Builder().nOut(7).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(10)) - .build(); - + @DisplayName("Test 1 d Input Type") + void test1dInputType() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().convolutionMode(ConvolutionMode.Same).list().layer(new 
Convolution1DLayer.Builder().nOut(3).kernelSize(2).activation(Activation.TANH).build()).layer(new Subsampling1DLayer.Builder().kernelSize(2).stride(2).build()).layer(new Upsampling1D.Builder().size(2).build()).layer(new RnnOutputLayer.Builder().nOut(7).activation(Activation.SOFTMAX).build()).setInputType(InputType.recurrent(10)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - List<InputType> l = conf.getLayerActivationTypes(InputType.recurrent(10)); assertEquals(InputType.recurrent(3, -1), l.get(0)); assertEquals(InputType.recurrent(3, -1), l.get(1)); assertEquals(InputType.recurrent(3, -1), l.get(2)); assertEquals(InputType.recurrent(7, -1), l.get(3)); - List<InputType> l2 = conf.getLayerActivationTypes(InputType.recurrent(10, 6)); assertEquals(InputType.recurrent(3, 6), l2.get(0)); assertEquals(InputType.recurrent(3, 3), l2.get(1)); assertEquals(InputType.recurrent(3, 6), l2.get(2)); assertEquals(InputType.recurrent(7, 6), l2.get(3)); - - INDArray in = Nd4j.create(2, 10, 6); INDArray out = net.output(in); - assertArrayEquals(new long[]{2,7,6}, out.shape()); + assertArrayEquals(new long[] { 2, 7, 6 }, out.shape()); } @Test - public void testDeconvBadInput(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new Deconvolution2D.Builder().nIn(5).nOut(3).build()) - .build(); + @DisplayName("Test Deconv Bad Input") + void testDeconvBadInput() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new Deconvolution2D.Builder().nIn(5).nOut(3).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray badInput = Nd4j.create(DataType.FLOAT, 1, 10, 5, 5); try { net.output(badInput); - } catch (DL4JInvalidInputException e){ + } catch (DL4JInvalidInputException e) { String msg = e.getMessage(); - assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels")); + assertTrue(msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"), msg); } } @Test - public void testConv1dCausalAllowed(){ + @DisplayName("Test Conv 1 d Causal Allowed") + void testConv1dCausalAllowed() { new Convolution1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build(); new Subsampling1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build(); } @Test - public void testConv2dNoCausalAllowed(){ - - try{ + @DisplayName("Test Conv 2 d No Causal Allowed") + void testConv2dNoCausalAllowed() { + try { new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); fail("Expected exception"); - } catch (Throwable t){ + } catch (Throwable t) { String m = t.getMessage().toLowerCase(); - assertTrue(m, m.contains("causal") && m.contains("1d")); + assertTrue(m.contains("causal") && m.contains("1d"), m); } - - try{ + try { new Deconvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); fail("Expected exception"); - } catch (Throwable t){ + } catch (Throwable t) { String m = t.getMessage().toLowerCase(); - assertTrue(m, m.contains("causal") && m.contains("1d")); + assertTrue(m.contains("causal") && m.contains("1d"), m); } - - try{ + try { new DepthwiseConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); fail("Expected exception"); - } catch (Throwable t){ + } catch (Throwable t) { String m = t.getMessage().toLowerCase(); - assertTrue(m, m.contains("causal") && m.contains("1d")); + assertTrue(m.contains("causal") && m.contains("1d"), m); } - - try{ + try { new
SeparableConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); fail("Expected exception"); - } catch (Throwable t){ + } catch (Throwable t) { String m = t.getMessage().toLowerCase(); - assertTrue(m, m.contains("causal") && m.contains("1d")); + assertTrue(m.contains("causal") && m.contains("1d"), m); } - - try{ + try { new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); fail("Expected exception"); - } catch (Throwable t){ + } catch (Throwable t) { String m = t.getMessage().toLowerCase(); - assertTrue(m, m.contains("causal") && m.contains("1d")); + assertTrue(m.contains("causal") && m.contains("1d"), m); } } @Test - public void testConv3dNoCausalAllowed(){ - try{ + @DisplayName("Test Conv 3 d No Causal Allowed") + void testConv3dNoCausalAllowed() { + try { new Convolution3D.Builder().convolutionMode(ConvolutionMode.Causal).build(); fail("Expected exception"); - } catch (Throwable t){ + } catch (Throwable t) { String m = t.getMessage().toLowerCase(); - assertTrue(m, m.contains("causal") && m.contains("1d")); + assertTrue(m.contains("causal") && m.contains("1d"), m); } - - try{ + try { new Subsampling3DLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); fail("Expected exception"); - } catch (Throwable t){ + } catch (Throwable t) { String m = t.getMessage().toLowerCase(); - assertTrue(m, m.contains("causal") && m.contains("1d")); + assertTrue(m.contains("causal") && m.contains("1d"), m); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java index 37644c322..e3a2886b6 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.convolution; import org.deeplearning4j.BaseDL4JTest; @@ -35,8 +34,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; @@ -47,150 +46,100 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.Arrays; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class LocallyConnectedLayerTest extends BaseDL4JTest { +@DisplayName("Locally Connected Layer Test") +class LocallyConnectedLayerTest extends BaseDL4JTest { - @Before - public void before()
{ + @BeforeEach + void before() { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); Nd4j.factory().setDType(DataType.DOUBLE); Nd4j.EPS_THRESHOLD = 1e-4; } @Test - public void test2dForward(){ - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4) - .updater(new Nesterovs(0.9)).dropOut(0.5) - .list() - .layer(new LocallyConnected2D.Builder().kernelSize(8, 8).nIn(3) - .stride(4, 4).nOut(16).dropOut(0.5) - .convolutionMode(ConvolutionMode.Strict) - .setInputSize(28, 28) - .activation(Activation.RELU).weightInit( - WeightInit.XAVIER) - .build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS) //output layer - .nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 3)); - + @DisplayName("Test 2 d Forward") + void test2dForward() { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4).updater(new Nesterovs(0.9)).dropOut(0.5).list().layer(new LocallyConnected2D.Builder().kernelSize(8, 8).nIn(3).stride(4, 4).nOut(16).dropOut(0.5).convolutionMode(ConvolutionMode.Strict).setInputSize(28, 28).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(// output layer + new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 3)); MultiLayerConfiguration conf = builder.build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - INDArray input = Nd4j.ones(10, 3, 28, 28); INDArray output = network.output(input, false); - - assertArrayEquals(new long[] {10, 10}, output.shape()); + assertArrayEquals(new long[] { 10, 10 }, output.shape()); } @Test - public void test1dForward(){ - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4) - .updater(new Nesterovs(0.9)).dropOut(0.5) - .list() - .layer(new LocallyConnected1D.Builder().kernelSize(4).nIn(3) - .stride(1).nOut(16).dropOut(0.5) - .convolutionMode(ConvolutionMode.Strict) - .setInputSize(28) - .activation(Activation.RELU).weightInit( - WeightInit.XAVIER) - .build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS) //output layer - .nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(3, 8)); - + @DisplayName("Test 1 d Forward") + void test1dForward() { + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4).updater(new Nesterovs(0.9)).dropOut(0.5).list().layer(new LocallyConnected1D.Builder().kernelSize(4).nIn(3).stride(1).nOut(16).dropOut(0.5).convolutionMode(ConvolutionMode.Strict).setInputSize(28).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(// output layer + new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.recurrent(3, 8)); MultiLayerConfiguration conf = builder.build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - INDArray input = Nd4j.ones(10, 3, 8); - INDArray output = 
network.output(input, false);; - for (int i = 0; i < 100; i++) { // TODO: this falls flat for 1000 iterations on my machine + INDArray output = network.output(input, false); + for (int i = 0; i < 100; i++) { + // TODO: this falls flat for 1000 iterations on my machine output = network.output(input, false); } - - assertArrayEquals(new long[] {(8 - 4 + 1) * 10, 10}, output.shape()); + assertArrayEquals(new long[] { (8 - 4 + 1) * 10, 10 }, output.shape()); network.fit(input, output); - } @Test - public void testLocallyConnected(){ - for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + @DisplayName("Test Locally Connected") + void testLocallyConnected() { + for (DataType globalDtype : new DataType[] { DataType.DOUBLE, DataType.FLOAT, DataType.HALF }) { Nd4j.setDefaultDataTypes(globalDtype, globalDtype); - for (DataType networkDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + for (DataType networkDtype : new DataType[] { DataType.DOUBLE, DataType.FLOAT, DataType.HALF }) { assertEquals(globalDtype, Nd4j.dataType()); assertEquals(globalDtype, Nd4j.defaultFloatingPointType()); - for (int test = 0; test < 2; test++) { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test; - - ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder() - .dataType(networkDtype) - .seed(123) - .updater(new NoOp()) - .weightInit(WeightInit.XAVIER) - .convolutionMode(ConvolutionMode.Same) - .graphBuilder(); - + ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder().dataType(networkDtype).seed(123).updater(new NoOp()).weightInit(WeightInit.XAVIER).convolutionMode(ConvolutionMode.Same).graphBuilder(); INDArray[] in; INDArray label; - switch (test){ + switch(test) { case 0: - b.addInputs("in") - .addLayer("1", new LSTM.Builder().nOut(5).build(), "in") - .addLayer("2", new LocallyConnected1D.Builder().kernelSize(2).nOut(4).build(), "1") - .addLayer("out", new RnnOutputLayer.Builder().nOut(10).build(), "2") - .setOutputs("out") - .setInputTypes(InputType.recurrent(5, 4)); - in = new INDArray[]{Nd4j.rand(networkDtype, 2, 5, 4)}; + b.addInputs("in").addLayer("1", new LSTM.Builder().nOut(5).build(), "in").addLayer("2", new LocallyConnected1D.Builder().kernelSize(2).nOut(4).build(), "1").addLayer("out", new RnnOutputLayer.Builder().nOut(10).build(), "2").setOutputs("out").setInputTypes(InputType.recurrent(5, 4)); + in = new INDArray[] { Nd4j.rand(networkDtype, 2, 5, 4) }; label = TestUtils.randomOneHotTimeSeries(2, 10, 4).castTo(networkDtype); break; case 1: - b.addInputs("in") - .addLayer("1", new ConvolutionLayer.Builder().kernelSize(2,2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in") - .addLayer("2", new LocallyConnected2D.Builder().kernelSize(2,2).nOut(5).build(), "1") - .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2") - .setOutputs("out") -// .setInputTypes(InputType.convolutional(28, 28, 1)); -// in = new INDArray[]{Nd4j.rand(networkDtype, 2, 1, 28, 28)}; - .setInputTypes(InputType.convolutional(8, 8, 1)); - in = new INDArray[]{Nd4j.rand(networkDtype, 2, 1, 8, 8)}; + b.addInputs("in").addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in").addLayer("2", new LocallyConnected2D.Builder().kernelSize(2, 2).nOut(5).build(), "1").addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2").setOutputs("out").setInputTypes(InputType.convolutional(8, 8, 1)); +
in = new INDArray[] { Nd4j.rand(networkDtype, 2, 1, 8, 8) }; label = TestUtils.randomOneHot(2, 10).castTo(networkDtype); break; default: throw new RuntimeException(); } - ComputationGraph net = new ComputationGraph(b.build()); net.init(); - INDArray out = net.outputSingle(in); - assertEquals(msg, networkDtype, out.dataType()); + assertEquals(networkDtype, out.dataType(), msg); Map<String, INDArray> ff = net.feedForward(in, false); for (Map.Entry<String, INDArray> e : ff.entrySet()) { if (e.getKey().equals("in")) continue; String s = msg + " - layer: " + e.getKey(); - assertEquals(s, networkDtype, e.getValue().dataType()); + assertEquals(networkDtype, e.getValue().dataType(), s); } - net.setInputs(in); net.setLabels(label); net.computeGradientAndScore(); - - net.fit(new MultiDataSet(in, new INDArray[]{label})); + net.fit(new MultiDataSet(in, new INDArray[] { label })); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java index 296cb66d6..14259e0bb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java @@ -17,79 +17,77 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.convolution; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.layers.SpaceToDepthLayer; - import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; - import java.util.Arrays; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class SpaceToDepthTest extends BaseDL4JTest { +@DisplayName("Space To Depth Test") +class SpaceToDepthTest extends BaseDL4JTest { private int mb = 1; + private int inDepth = 2; + private int inputWidth = 2; + private int inputHeight = 2; private int blockSize = 2; + private SpaceToDepthLayer.DataFormat dataFormat = SpaceToDepthLayer.DataFormat.NCHW; private int outDepth = inDepth * blockSize * blockSize; + private int outputHeight = inputHeight / blockSize; + private int outputWidth = inputWidth / blockSize; - private INDArray getContainedData() { - return Nd4j.create(new double[] {1., 2., 3., 4., 5., 6., 7., 8.}, - new int[] {mb, inDepth, inputHeight, inputWidth}, 'c'); + return Nd4j.create(new double[] { 1., 2., 3., 4., 5., 6., 7., 8. }, new int[] { mb, inDepth, inputHeight, inputWidth }, 'c'); } private INDArray getContainedOutput() { - return Nd4j.create(new double[] {1., 5., 2., 6., 3., 7., 4., 8.}, - new int[] {mb, outDepth, outputHeight, outputWidth}, 'c'); + return Nd4j.create(new double[] { 1., 5., 2., 6., 3., 7., 4., 8.
}, new int[] { mb, outDepth, outputHeight, outputWidth }, 'c'); } private Layer getSpaceToDepthLayer() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) - .layer(new SpaceToDepthLayer.Builder(blockSize, dataFormat).build()).build(); + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123).layer(new SpaceToDepthLayer.Builder(blockSize, dataFormat).build()).build(); return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } @Test - public void testSpaceToDepthForward() throws Exception { + @DisplayName("Test Space To Depth Forward") + void testSpaceToDepthForward() throws Exception { INDArray containedInput = getContainedData(); INDArray containedExpectedOut = getContainedOutput(); Layer std = getSpaceToDepthLayer(); INDArray containedOutput = std.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); assertEquals(containedExpectedOut, containedOutput); } @Test - public void testSpaceToDepthBackward() throws Exception { + @DisplayName("Test Space To Depth Backward") + void testSpaceToDepthBackward() throws Exception { INDArray containedInputEpsilon = getContainedOutput(); - INDArray containedExpectedOut = getContainedData(); Layer std = getSpaceToDepthLayer(); - std.setInput(getContainedData(), LayerWorkspaceMgr.noWorkspaces()); INDArray containedOutput = std.backpropGradient(containedInputEpsilon, LayerWorkspaceMgr.noWorkspaces()).getRight(); - assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); assertEquals(containedExpectedOut, containedOutput); } -} \ No newline at end of file +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java index cde5b25cc..d16aeda08 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.convolution; import org.deeplearning4j.BaseDL4JTest; @@ -34,7 +33,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -43,137 +42,127 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; - import java.util.Arrays; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; /** * @author 
Adam Gibson */ -public class SubsamplingLayerTest extends BaseDL4JTest { +@DisplayName("Subsampling Layer Test") +class SubsamplingLayerTest extends BaseDL4JTest { private int nExamples = 1; - private int depth = 20; //channels & nOut + + // channels & nOut + private int depth = 20; + private int nChannelsIn = 1; + private int inputWidth = 28; + private int inputHeight = 28; - private int[] kernelSize = new int[] {2, 2}; - private int[] stride = new int[] {2, 2}; + + private int[] kernelSize = new int[] { 2, 2 }; + + private int[] stride = new int[] { 2, 2 }; int featureMapWidth = (inputWidth - kernelSize[0]) / stride[0] + 1; + int featureMapHeight = (inputHeight - kernelSize[1]) / stride[0] + 1; + private INDArray epsilon = Nd4j.ones(nExamples, depth, featureMapHeight, featureMapWidth); @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testSubSampleMaxActivate() throws Exception { - INDArray containedExpectedOut = - Nd4j.create(new double[] {5., 7., 6., 8., 4., 7., 5., 9.}, new long[] {1, 2, 2, 2}).castTo(Nd4j.defaultFloatingPointType()); + @DisplayName("Test Sub Sample Max Activate") + void testSubSampleMaxActivate() throws Exception { + INDArray containedExpectedOut = Nd4j.create(new double[] { 5., 7., 6., 8., 4., 7., 5., 9. }, new long[] { 1, 2, 2, 2 }).castTo(Nd4j.defaultFloatingPointType()); INDArray containedInput = getContainedData(); INDArray input = getData(); Layer layer = getSubsamplingLayer(SubsamplingLayer.PoolingType.MAX); - INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); assertEquals(containedExpectedOut, containedOutput); - INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(new long[] {nExamples, nChannelsIn, featureMapWidth, featureMapHeight}, - output.shape())); - assertEquals(nChannelsIn, output.size(1), 1e-4); // channels retained + assertTrue(Arrays.equals(new long[] { nExamples, nChannelsIn, featureMapWidth, featureMapHeight }, output.shape())); + // channels retained + assertEquals(nChannelsIn, output.size(1), 1e-4); } @Test - public void testSubSampleMeanActivate() throws Exception { - INDArray containedExpectedOut = - Nd4j.create(new double[] {2., 4., 3., 5., 3.5, 6.5, 4.5, 8.5}, new int[] {1, 2, 2, 2}).castTo(Nd4j.defaultFloatingPointType()); + @DisplayName("Test Sub Sample Mean Activate") + void testSubSampleMeanActivate() throws Exception { + INDArray containedExpectedOut = Nd4j.create(new double[] { 2., 4., 3., 5., 3.5, 6.5, 4.5, 8.5 }, new int[] { 1, 2, 2, 2 }).castTo(Nd4j.defaultFloatingPointType()); INDArray containedInput = getContainedData(); INDArray input = getData(); Layer layer = getSubsamplingLayer(SubsamplingLayer.PoolingType.AVG); - INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); assertEquals(containedExpectedOut, containedOutput); - INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(new long[] {nExamples, nChannelsIn, featureMapWidth, featureMapHeight}, - output.shape())); - assertEquals(nChannelsIn, output.size(1), 1e-4); // channels retained + assertTrue(Arrays.equals(new long[] { nExamples, nChannelsIn, featureMapWidth, featureMapHeight }, output.shape())); + // channels retained + 
assertEquals(nChannelsIn, output.size(1), 1e-4); } - ////////////////////////////////////////////////////////////////////////////////// - + // //////////////////////////////////////////////////////////////////////////////// @Test - public void testSubSampleLayerMaxBackprop() throws Exception { - INDArray expectedContainedEpsilonInput = - Nd4j.create(new double[] {1., 1., 1., 1., 1., 1., 1., 1.}, new int[] {1, 2, 2, 2}).castTo(Nd4j.defaultFloatingPointType()); - - INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] {0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., - 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.}, - new int[] {1, 2, 4, 4}).castTo(Nd4j.defaultFloatingPointType()); - + @DisplayName("Test Sub Sample Layer Max Backprop") + void testSubSampleLayerMaxBackprop() throws Exception { + INDArray expectedContainedEpsilonInput = Nd4j.create(new double[] { 1., 1., 1., 1., 1., 1., 1., 1. }, new int[] { 1, 2, 2, 2 }).castTo(Nd4j.defaultFloatingPointType()); + INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] { 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0. }, new int[] { 1, 2, 4, 4 }).castTo(Nd4j.defaultFloatingPointType()); INDArray input = getContainedData(); - Layer layer = getSubsamplingLayer(SubsamplingLayer.PoolingType.MAX); layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - Pair<Gradient, INDArray> containedOutput = layer.backpropGradient(expectedContainedEpsilonInput, LayerWorkspaceMgr.noWorkspaces()); assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond()); assertEquals(null, containedOutput.getFirst().getGradientFor("W")); assertEquals(expectedContainedEpsilonResult.shape().length, containedOutput.getSecond().shape().length); - INDArray input2 = getData(); layer.activate(input2, false, LayerWorkspaceMgr.noWorkspaces()); long depth = input2.size(1); - epsilon = Nd4j.ones(5, depth, featureMapHeight, featureMapWidth); - Pair<Gradient, INDArray> out = layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); assertEquals(input.shape().length, out.getSecond().shape().length); - assertEquals(depth, out.getSecond().size(1)); // channels retained + // channels retained + assertEquals(depth, out.getSecond().size(1)); } @Test - public void testSubSampleLayerAvgBackprop() throws Exception { - INDArray expectedContainedEpsilonInput = - Nd4j.create(new double[] {1., 2., 3., 4., 5., 6., 7., 8.}, new int[] {1, 2, 2, 2}).castTo(Nd4j.defaultFloatingPointType()); - - INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] {0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.5, 0.5, - 0.75, 0.75, 1., 1., 0.75, 0.75, 1., 1., 1.25, 1.25, 1.5, 1.5, 1.25, 1.25, 1.5, 1.5, 1.75, 1.75, - 2., 2., 1.75, 1.75, 2., 2.}, new int[] {1, 2, 4, 4}).castTo(Nd4j.defaultFloatingPointType()); + @DisplayName("Test Sub Sample Layer Avg Backprop") + void testSubSampleLayerAvgBackprop() throws Exception { + INDArray expectedContainedEpsilonInput = Nd4j.create(new double[] { 1., 2., 3., 4., 5., 6., 7., 8. }, new int[] { 1, 2, 2, 2 }).castTo(Nd4j.defaultFloatingPointType()); + INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] { 0.25, 0.25, 0.5, 0.5, 0.25, 0.25, 0.5, 0.5, 0.75, 0.75, 1., 1., 0.75, 0.75, 1., 1., 1.25, 1.25, 1.5, 1.5, 1.25, 1.25, 1.5, 1.5, 1.75, 1.75, 2., 2., 1.75, 1.75, 2., 2.
}, new int[] { 1, 2, 4, 4 }).castTo(Nd4j.defaultFloatingPointType()); INDArray input = getContainedData(); - Layer layer = getSubsamplingLayer(SubsamplingLayer.PoolingType.AVG); layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - Pair<Gradient, INDArray> containedOutput = layer.backpropGradient(expectedContainedEpsilonInput, LayerWorkspaceMgr.noWorkspaces()); assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond()); assertEquals(null, containedOutput.getFirst().getGradientFor("W")); assertArrayEquals(expectedContainedEpsilonResult.shape(), containedOutput.getSecond().shape()); - } - - @Test(expected = UnsupportedOperationException.class) - public void testSubSampleLayerSumBackprop() throws Exception { - Layer layer = getSubsamplingLayer(SubsamplingLayer.PoolingType.SUM); - INDArray input = getData(); - layer.setInput(input, LayerWorkspaceMgr.noWorkspaces()); - layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); + @Test + @DisplayName("Test Sub Sample Layer Sum Backprop") + void testSubSampleLayerSumBackprop() { + assertThrows(UnsupportedOperationException.class, () -> { + Layer layer = getSubsamplingLayer(SubsamplingLayer.PoolingType.SUM); + INDArray input = getData(); + layer.setInput(input, LayerWorkspaceMgr.noWorkspaces()); + layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); + }); } - ////////////////////////////////////////////////////////////////////////////////// - + // //////////////////////////////////////////////////////////////////////////////// private Layer getSubsamplingLayer(SubsamplingLayer.PoolingType pooling) { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) - .layer(new SubsamplingLayer.Builder(pooling, new int[] {2, 2}).build()).build(); - + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123).layer(new SubsamplingLayer.Builder(pooling, new int[] { 2, 2 }).build()).build(); return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } @@ -185,61 +174,40 @@ public class SubsamplingLayerTest extends BaseDL4JTest { } public INDArray getContainedData() { - INDArray ret = Nd4j.create(new double[] {1., 1., 3., 7., 5., 1., 3., 3., 2., 2., 8., 4., 2., 6., 4., 4., 3., 3., - 6., 7., 4., 4., 6., 7., 5., 5., 9., 8., 4., 4., 9., 8.}, new int[] {1, 2, 4, 4}).castTo(Nd4j.defaultFloatingPointType()); + INDArray ret = Nd4j.create(new double[] { 1., 1., 3., 7., 5., 1., 3., 3., 2., 2., 8., 4., 2., 6., 4., 4., 3., 3., 6., 7., 4., 4., 6., 7., 5., 5., 9., 8., 4., 4., 9., 8.
}, new int[] { 1, 2, 4, 4 }).castTo(Nd4j.defaultFloatingPointType()); return ret; } private Gradient createPrevGradient() { Gradient gradient = new DefaultGradient(); INDArray pseudoGradients = Nd4j.ones(nExamples, nChannelsIn, inputHeight, inputWidth); - gradient.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, pseudoGradients); gradient.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, pseudoGradients); return gradient; } - ////////////////////////////////////////////////////////////////////////////////// - - @Test(expected = Exception.class) - public void testSubTooLargeKernel() { - int imageHeight = 20; - int imageWidth = 23; - int nChannels = 1; - int classes = 2; - int numSamples = 200; - - int kernelHeight = 3; - int kernelWidth = 3; - - DataSet trainInput; - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder().seed(123).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - kernelHeight, kernelWidth).stride(1, 1).nOut(2) - .activation(Activation.RELU).weightInit( - WeightInit.XAVIER) - .build()) - .layer(1, new SubsamplingLayer.Builder() - .poolingType(SubsamplingLayer.PoolingType.MAX) - .kernelSize(imageHeight - kernelHeight + 2, 1) //imageHeight-kernelHeight+1 is ok: full height - .stride(1, 1).build()) - .layer(2, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - - .setInputType(InputType.convolutional(imageHeight, imageWidth, nChannels)); - - MultiLayerConfiguration conf = builder.build(); - MultiLayerNetwork model = new MultiLayerNetwork(conf); - model.init(); - - INDArray emptyFeatures = Nd4j.zeros(numSamples, imageWidth * imageHeight * nChannels); - INDArray emptyLables = Nd4j.zeros(numSamples, classes); - - trainInput = new DataSet(emptyFeatures, emptyLables); - model.fit(trainInput); + // //////////////////////////////////////////////////////////////////////////////// + @Test + @DisplayName("Test Sub Too Large Kernel") + void testSubTooLargeKernel() { + assertThrows(Exception.class, () -> { + int imageHeight = 20; + int imageWidth = 23; + int nChannels = 1; + int classes = 2; + int numSamples = 200; + int kernelHeight = 3; + int kernelWidth = 3; + DataSet trainInput; + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 1).nOut(2).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new SubsamplingLayer.Builder().poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(imageHeight - kernelHeight + 2, // imageHeight-kernelHeight+1 is ok: full height + 1).stride(1, 1).build()).layer(2, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(imageHeight, imageWidth, nChannels)); + MultiLayerConfiguration conf = builder.build(); + MultiLayerNetwork model = new MultiLayerNetwork(conf); + model.init(); + INDArray emptyFeatures = Nd4j.zeros(numSamples, imageWidth * imageHeight * nChannels); + INDArray emptyLables = Nd4j.zeros(numSamples, classes); + trainInput = new DataSet(emptyFeatures, emptyLables); + model.fit(trainInput); + }); } - - - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java index 2e307b1db..cea528a36 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.convolution; import lombok.val; @@ -28,91 +27,79 @@ import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Upsampling1D; import org.deeplearning4j.nn.gradient.Gradient; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; - import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class Upsampling1DTest extends BaseDL4JTest { +@DisplayName("Upsampling 1 D Test") +class Upsampling1DTest extends BaseDL4JTest { private int nExamples = 1; + private int depth = 20; + private int nChannelsIn = 1; + private int inputLength = 28; + private int size = 2; + private int outputLength = inputLength * size; + private INDArray epsilon = Nd4j.ones(nExamples, depth, outputLength); - @Test - public void testUpsampling1D() throws Exception { - - double[] outArray = new double[] {1., 1., 2., 2., 3., 3., 4., 4.}; - INDArray containedExpectedOut = Nd4j.create(outArray, new int[] {1, 1, 8}); + @DisplayName("Test Upsampling 1 D") + void testUpsampling1D() throws Exception { + double[] outArray = new double[] { 1., 1., 2., 2., 3., 3., 4., 4. 
}; + INDArray containedExpectedOut = Nd4j.create(outArray, new int[] { 1, 1, 8 }); INDArray containedInput = getContainedData(); INDArray input = getData(); - Layer layer = getUpsampling1DLayer(); - + Layer layer = getUpsampling1DLayer(); INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); assertEquals(containedExpectedOut, containedOutput); - INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(new long[] {nExamples, nChannelsIn, outputLength}, - output.shape())); + assertTrue(Arrays.equals(new long[] { nExamples, nChannelsIn, outputLength }, output.shape())); assertEquals(nChannelsIn, output.size(1), 1e-4); } - @Test - public void testUpsampling1DBackprop() throws Exception { - INDArray expectedContainedEpsilonInput = - Nd4j.create(new double[] {1., 3., 2., 6., 7., 2., 5., 5.}, - new int[] {1, 1, 8}); - - INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] {4., 8., 9., 10.}, - new int[] {1, 1, 4}); - + @DisplayName("Test Upsampling 1 D Backprop") + void testUpsampling1DBackprop() throws Exception { + INDArray expectedContainedEpsilonInput = Nd4j.create(new double[] { 1., 3., 2., 6., 7., 2., 5., 5. }, new int[] { 1, 1, 8 }); + INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] { 4., 8., 9., 10. }, new int[] { 1, 1, 4 }); INDArray input = getContainedData(); - Layer layer = getUpsampling1DLayer(); layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - Pair containedOutput = layer.backpropGradient(expectedContainedEpsilonInput, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond()); assertEquals(null, containedOutput.getFirst().getGradientFor("W")); assertEquals(expectedContainedEpsilonResult.shape().length, containedOutput.getSecond().shape().length); - INDArray input2 = getData(); layer.activate(input2, false, LayerWorkspaceMgr.noWorkspaces()); val depth = input2.size(1); - epsilon = Nd4j.ones(5, depth, outputLength); - Pair out = layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); assertEquals(input.shape().length, out.getSecond().shape().length); assertEquals(depth, out.getSecond().size(1)); } - private Layer getUpsampling1DLayer() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) - .layer(new Upsampling1D.Builder(size).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, - null, true, Nd4j.defaultFloatingPointType()); + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123).layer(new Upsampling1D.Builder(size).build()).build(); + return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } public INDArray getData() throws Exception { @@ -124,10 +111,7 @@ public class Upsampling1DTest extends BaseDL4JTest { } private INDArray getContainedData() { - INDArray ret = Nd4j.create - (new double[] {1., 2., 3., 4.}, - new int[] {1, 1, 4}); + INDArray ret = Nd4j.create(new double[] { 1., 2., 3., 4. 
}, new int[] { 1, 1, 4 }); return ret; } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java index cc3d38c42..cb424b780 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.convolution; import lombok.val; @@ -28,92 +27,81 @@ import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Upsampling2D; import org.deeplearning4j.nn.gradient.Gradient; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; - import java.util.Arrays; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class Upsampling2DTest extends BaseDL4JTest { +@DisplayName("Upsampling 2 D Test") +class Upsampling2DTest extends BaseDL4JTest { private int nExamples = 1; + private int depth = 20; + private int nChannelsIn = 1; + private int inputWidth = 28; + private int inputHeight = 28; private int size = 2; + private int outputWidth = inputWidth * size; + private int outputHeight = inputHeight * size; private INDArray epsilon = Nd4j.ones(nExamples, depth, outputHeight, outputWidth); - @Test - public void testUpsampling() throws Exception { - - double[] outArray = new double[] {1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4.}; - INDArray containedExpectedOut = Nd4j.create(outArray, new int[] {1, 1, 4, 4}); + @DisplayName("Test Upsampling") + void testUpsampling() throws Exception { + double[] outArray = new double[] { 1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4. 
}; + INDArray containedExpectedOut = Nd4j.create(outArray, new int[] { 1, 1, 4, 4 }); INDArray containedInput = getContainedData(); INDArray input = getData(); Layer layer = getUpsamplingLayer(); - INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); assertEquals(containedExpectedOut, containedOutput); - INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(new long[] {nExamples, nChannelsIn, outputWidth, outputHeight}, - output.shape())); + assertTrue(Arrays.equals(new long[] { nExamples, nChannelsIn, outputWidth, outputHeight }, output.shape())); assertEquals(nChannelsIn, output.size(1), 1e-4); } - @Test - public void testUpsampling2DBackprop() throws Exception { - INDArray expectedContainedEpsilonInput = - Nd4j.create(new double[] {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - new int[] {1, 1, 4, 4}); - - INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] {4., 4., 4., 4.}, - new int[] {1, 1, 2, 2}); - + @DisplayName("Test Upsampling 2 D Backprop") + void testUpsampling2DBackprop() throws Exception { + INDArray expectedContainedEpsilonInput = Nd4j.create(new double[] { 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. }, new int[] { 1, 1, 4, 4 }); + INDArray expectedContainedEpsilonResult = Nd4j.create(new double[] { 4., 4., 4., 4. }, new int[] { 1, 1, 2, 2 }); INDArray input = getContainedData(); - Layer layer = getUpsamplingLayer(); layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - Pair containedOutput = layer.backpropGradient(expectedContainedEpsilonInput, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond()); assertEquals(null, containedOutput.getFirst().getGradientFor("W")); assertEquals(expectedContainedEpsilonResult.shape().length, containedOutput.getSecond().shape().length); - INDArray input2 = getData(); layer.activate(input2, false, LayerWorkspaceMgr.noWorkspaces()); val depth = input2.size(1); - epsilon = Nd4j.ones(5, depth, outputHeight, outputWidth); - Pair out = layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); assertEquals(input.shape().length, out.getSecond().shape().length); assertEquals(depth, out.getSecond().size(1)); } - private Layer getUpsamplingLayer() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) - .layer(new Upsampling2D.Builder(size).build()).build(); + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123).layer(new Upsampling2D.Builder(size).build()).build(); return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } @@ -125,10 +113,7 @@ public class Upsampling2DTest extends BaseDL4JTest { } private INDArray getContainedData() { - INDArray ret = Nd4j.create - (new double[] {1., 2., 3., 4.}, - new int[] {1, 1, 2, 2}); + INDArray ret = Nd4j.create(new double[] { 1., 2., 3., 4. 
}, new int[] { 1, 1, 2, 2 }); return ret; } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java index ec9ed319a..c07b50fe8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.feedforward.dense; import org.deeplearning4j.BaseDL4JTest; @@ -30,7 +29,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -38,105 +37,83 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class DenseTest extends BaseDL4JTest { +@DisplayName("Dense Test") +class DenseTest extends BaseDL4JTest { private int numSamples = 150; + private int batchSize = 150; + private DataSetIterator iter = new IrisDataSetIterator(batchSize, numSamples); + private DataSet data; @Test - public void testDenseBiasInit() { + @DisplayName("Test Dense Bias Init") + void testDenseBiasInit() { DenseLayer build = new DenseLayer.Builder().nIn(1).nOut(3).biasInit(1).build(); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(build).build(); - long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, Nd4j.defaultFloatingPointType()); - assertEquals(1, layer.getParam("b").size(0)); } @Test - public void testMLPMultiLayerPretrain() { + @DisplayName("Test MLP Multi Layer Pretrain") + void testMLPMultiLayerPretrain() { // Note CNN does not do pretrain MultiLayerNetwork model = getDenseMLNConfig(false, true); model.fit(iter); - MultiLayerNetwork model2 = getDenseMLNConfig(false, true); model2.fit(iter); iter.reset(); - DataSet test = iter.next(); - assertEquals(model.params(), model2.params()); - Evaluation eval = new Evaluation(); INDArray output = model.output(test.getFeatures()); eval.eval(test.getLabels(), output); double f1Score = eval.f1(); - Evaluation eval2 = new Evaluation(); INDArray output2 = model2.output(test.getFeatures()); eval2.eval(test.getLabels(), output2); double f1Score2 = eval2.f1(); - assertEquals(f1Score, f1Score2, 1e-4); - } @Test - public void testMLPMultiLayerBackprop() { + @DisplayName("Test MLP Multi Layer Backprop") + void testMLPMultiLayerBackprop() { MultiLayerNetwork model = getDenseMLNConfig(true, false); model.fit(iter); - MultiLayerNetwork model2 = getDenseMLNConfig(true, false); model2.fit(iter); iter.reset(); - DataSet 
test = iter.next(); - assertEquals(model.params(), model2.params()); - Evaluation eval = new Evaluation(); INDArray output = model.output(test.getFeatures()); eval.eval(test.getLabels(), output); double f1Score = eval.f1(); - Evaluation eval2 = new Evaluation(); INDArray output2 = model2.output(test.getFeatures()); eval2.eval(test.getLabels(), output2); double f1Score2 = eval2.f1(); - assertEquals(f1Score, f1Score2, 1e-4); - } - - ////////////////////////////////////////////////////////////////////////////////// - + // //////////////////////////////////////////////////////////////////////////////// private static MultiLayerNetwork getDenseMLNConfig(boolean backprop, boolean pretrain) { int numInputs = 4; int outputNum = 3; long seed = 6; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed) - .updater(new Sgd(1e-3)).l1(0.3).l2(1e-3).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(numInputs).nOut(3) - .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(3).nOut(2) - .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).nIn(2).nOut(outputNum).activation(Activation.SOFTMAX).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).updater(new Sgd(1e-3)).l1(0.3).l2(1e-3).list().layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(numInputs).nOut(3).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(3).nOut(2).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).nIn(2).nOut(outputNum).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); return model; - } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java index 8dfae43ca..940aa4e1b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.feedforward.embedding; import lombok.EqualsAndHashCode; @@ -38,7 +37,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; @@ -46,191 +45,136 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; - import 
java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Random; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - -public class EmbeddingLayerTest extends BaseDL4JTest { +@DisplayName("Embedding Layer Test") +class EmbeddingLayerTest extends BaseDL4JTest { @Test - public void testEmbeddingLayerConfig() { - - for (boolean hasBias : new boolean[]{true, false}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(0, new EmbeddingLayer.Builder().hasBias(hasBias).nIn(10).nOut(5).build()) - .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) - .build(); - + @DisplayName("Test Embedding Layer Config") + void testEmbeddingLayerConfig() { + for (boolean hasBias : new boolean[] { true, false }) { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(0, new EmbeddingLayer.Builder().hasBias(hasBias).nIn(10).nOut(5).build()).layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - Layer l0 = net.getLayer(0); - assertEquals(org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer.class, l0.getClass()); assertEquals(10, ((FeedForwardLayer) l0.conf().getLayer()).getNIn()); assertEquals(5, ((FeedForwardLayer) l0.conf().getLayer()).getNOut()); - INDArray weights = l0.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray bias = l0.getParam(DefaultParamInitializer.BIAS_KEY); - assertArrayEquals(new long[]{10, 5}, weights.shape()); + assertArrayEquals(new long[] { 10, 5 }, weights.shape()); if (hasBias) { - assertArrayEquals(new long[]{1, 5}, bias.shape()); + assertArrayEquals(new long[] { 1, 5 }, bias.shape()); } } } @Test - public void testEmbeddingSequenceLayerConfig() { - + @DisplayName("Test Embedding Sequence Layer Config") + void testEmbeddingSequenceLayerConfig() { int inputLength = 6; int nIn = 10; int embeddingDim = 5; int nout = 4; - - for (boolean hasBias : new boolean[]{true, false}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(new EmbeddingSequenceLayer.Builder().hasBias(hasBias) - .inputLength(inputLength).nIn(nIn).nOut(embeddingDim).build()) - .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nout).activation(Activation.SOFTMAX).build()) - .build(); - + for (boolean hasBias : new boolean[] { true, false }) { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(new EmbeddingSequenceLayer.Builder().hasBias(hasBias).inputLength(inputLength).nIn(nIn).nOut(embeddingDim).build()).layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nout).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - Layer l0 = net.getLayer(0); - assertEquals(org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer.class, l0.getClass()); assertEquals(10, ((FeedForwardLayer) l0.conf().getLayer()).getNIn()); assertEquals(5, ((FeedForwardLayer) l0.conf().getLayer()).getNOut()); - INDArray weights = l0.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray bias = l0.getParam(DefaultParamInitializer.BIAS_KEY); - assertArrayEquals(new long[]{10, 5}, weights.shape()); + 
assertArrayEquals(new long[] { 10, 5 }, weights.shape()); if (hasBias) { - assertArrayEquals(new long[]{1, 5}, bias.shape()); + assertArrayEquals(new long[] { 1, 5 }, bias.shape()); } } } @Test - public void testEmbeddingLongerSequencesForwardPass() { - + @DisplayName("Test Embedding Longer Sequences Forward Pass") + void testEmbeddingLongerSequencesForwardPass() { int nClassesIn = 10; int inputLength = 6; int embeddingDim = 5; int nOut = 4; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(new EmbeddingSequenceLayer.Builder().inputLength(inputLength) - .hasBias(true).nIn(nClassesIn).nOut(embeddingDim).build()) - .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) - .build(); - - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(new EmbeddingSequenceLayer.Builder().inputLength(inputLength).hasBias(true).nIn(nClassesIn).nOut(embeddingDim).build()).layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - int batchSize = 3; - INDArray inEmbedding = Nd4j.create(batchSize, inputLength); - Random r = new Random(12345); for (int i = 0; i < batchSize; i++) { int classIdx = r.nextInt(nClassesIn); inEmbedding.putScalar(i, classIdx); } - INDArray output = net.output(inEmbedding); - - assertArrayEquals(new long[]{batchSize, nOut, inputLength}, output.shape()); + assertArrayEquals(new long[] { batchSize, nOut, inputLength }, output.shape()); } @Test - public void testEmbeddingSingleSequenceForwardPass() { + @DisplayName("Test Embedding Single Sequence Forward Pass") + void testEmbeddingSingleSequenceForwardPass() { int nClassesIn = 10; int embeddingDim = 5; int nOut = 4; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(new EmbeddingSequenceLayer.Builder().inputLength(1) - .hasBias(true).nIn(nClassesIn).nOut(embeddingDim).build()) - .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) - .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) - .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) - .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(new EmbeddingSequenceLayer.Builder().inputLength(1).hasBias(true).nIn(nClassesIn).nOut(embeddingDim).build()).layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()).layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()).inputPreProcessor(0, new RnnToFeedForwardPreProcessor()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net.init(); net2.init(); - net2.setParams(net.params().dup()); - int batchSize = 3; INDArray inEmbedding = Nd4j.create(batchSize, 1); 
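Migration note: every test class in the hunks above and below follows the same mechanical JUnit 4 to JUnit 5 conversion — the org.junit.Test and org.junit.Assert imports are replaced by their org.junit.jupiter.api equivalents, classes and test methods drop the public modifier, and class- and method-level @DisplayName annotations are added. A minimal illustrative sketch of the resulting style (the class below is hypothetical and not part of this patch):

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;

@DisplayName("Example Migration Test")
class ExampleMigrationTest {

    @Test
    @DisplayName("Test Sum")
    void testSum() {
        // same assertion name as JUnit 4, now imported from org.junit.jupiter.api.Assertions
        assertEquals(4, 2 + 2);
    }
}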
INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, 1); - Random r = new Random(12345); for (int i = 0; i < batchSize; i++) { int classIdx = r.nextInt(nClassesIn); inEmbedding.putScalar(i, classIdx); - inOneHot.putScalar(new int[]{i, classIdx, 0}, 1.0); + inOneHot.putScalar(new int[] { i, classIdx, 0 }, 1.0); } - List activationsDense = net2.feedForward(inOneHot, false); List activationEmbedding = net.feedForward(inEmbedding, false); - INDArray actD1 = activationsDense.get(1); INDArray actE1 = activationEmbedding.get(1).reshape(batchSize, embeddingDim); assertEquals(actD1, actE1); - - INDArray actD2 = activationsDense.get(2); INDArray actE2 = activationEmbedding.get(2).reshape(batchSize, nOut); assertEquals(actD2, actE2); } @Test - public void testEmbeddingForwardPass() { - //With the same parameters, embedding layer should have same activations as the equivalent one-hot representation + @DisplayName("Test Embedding Forward Pass") + void testEmbeddingForwardPass() { + // With the same parameters, embedding layer should have same activations as the equivalent one-hot representation // input with a DenseLayer - int nClassesIn = 10; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(0, new EmbeddingLayer.Builder().hasBias(true).nIn(nClassesIn).nOut(5).build()) - .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) - .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) - .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(0, new EmbeddingLayer.Builder().hasBias(true).nIn(nClassesIn).nOut(5).build()).layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()).layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net.init(); net2.init(); - net2.setParams(net.params().dup()); - int batchSize = 3; INDArray inEmbedding = Nd4j.create(batchSize, 1); INDArray inOneHot = Nd4j.create(batchSize, nClassesIn); - Random r = new Random(12345); for (int i = 0; i < batchSize; i++) { int classIdx = r.nextInt(nClassesIn); inEmbedding.putScalar(i, classIdx); - inOneHot.putScalar(new int[]{i, classIdx}, 1.0); + inOneHot.putScalar(new int[] { i, classIdx }, 1.0); } - List activationsEmbedding = net.feedForward(inEmbedding, false); List activationsDense = net2.feedForward(inOneHot, false); for (int i = 1; i < 3; i++) { @@ -241,277 +185,168 @@ public class EmbeddingLayerTest extends BaseDL4JTest { } @Test - public void testEmbeddingBackwardPass() { - //With the same parameters, embedding layer should have same activations as the equivalent one-hot representation + @DisplayName("Test Embedding Backward Pass") + void testEmbeddingBackwardPass() { + // With the same parameters, embedding layer should have same activations as the equivalent one-hot representation // input with a DenseLayer - int 
nClassesIn = 10; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(0, new EmbeddingLayer.Builder().hasBias(true).nIn(nClassesIn).nOut(5).build()).layer(1, - new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4) - .activation(Activation.SOFTMAX).build()) - .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH) - .weightInit(WeightInit.XAVIER).list() - .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4) - .activation(Activation.SOFTMAX).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(0, new EmbeddingLayer.Builder().hasBias(true).nIn(nClassesIn).nOut(5).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4).activation(Activation.SOFTMAX).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).weightInit(WeightInit.XAVIER).list().layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net.init(); net2.init(); - net2.setParams(net.params().dup()); - int batchSize = 3; INDArray inEmbedding = Nd4j.create(batchSize, 1); INDArray inOneHot = Nd4j.create(batchSize, nClassesIn); INDArray outLabels = Nd4j.create(batchSize, 4); - Random r = new Random(12345); for (int i = 0; i < batchSize; i++) { int classIdx = r.nextInt(nClassesIn); inEmbedding.putScalar(i, classIdx); - inOneHot.putScalar(new int[]{i, classIdx}, 1.0); - + inOneHot.putScalar(new int[] { i, classIdx }, 1.0); int labelIdx = r.nextInt(4); - outLabels.putScalar(new int[]{i, labelIdx}, 1.0); + outLabels.putScalar(new int[] { i, labelIdx }, 1.0); } - net.setInput(inEmbedding); net2.setInput(inOneHot); net.setLabels(outLabels); net2.setLabels(outLabels); - net.computeGradientAndScore(); net2.computeGradientAndScore(); - assertEquals(net2.score(), net.score(), 1e-6); - Map gradient = net.gradient().gradientForVariable(); Map gradient2 = net2.gradient().gradientForVariable(); assertEquals(gradient.size(), gradient2.size()); - for (String s : gradient.keySet()) { assertEquals(gradient2.get(s), gradient.get(s)); } } - @Test - public void testEmbeddingSequenceBackwardPass() { + @DisplayName("Test Embedding Sequence Backward Pass") + void testEmbeddingSequenceBackwardPass() { int nClassesIn = 10; int embeddingDim = 5; int nOut = 4; int inputLength = 1; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(new EmbeddingSequenceLayer.Builder().inputLength(inputLength) - .hasBias(true).nIn(nClassesIn).nOut(embeddingDim).build()) - .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(nClassesIn,inputLength,RNNFormat.NCW)) - .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() - .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).activation(Activation.IDENTITY).build()) - .layer(new 
RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(nClassesIn,inputLength,RNNFormat.NCW)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(new EmbeddingSequenceLayer.Builder().inputLength(inputLength).hasBias(true).nIn(nClassesIn).nOut(embeddingDim).build()).layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()).setInputType(InputType.recurrent(nClassesIn, inputLength, RNNFormat.NCW)).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list().layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).activation(Activation.IDENTITY).build()).layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()).setInputType(InputType.recurrent(nClassesIn, inputLength, RNNFormat.NCW)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net.init(); net2.init(); - net2.setParams(net.params().dup()); - int batchSize = 3; INDArray inEmbedding = Nd4j.create(batchSize, 1); INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, 1); INDArray outLabels = Nd4j.create(batchSize, 4, 1); - Random r = new Random(1337); for (int i = 0; i < batchSize; i++) { int classIdx = r.nextInt(nClassesIn); inEmbedding.putScalar(i, classIdx); - inOneHot.putScalar(new int[]{i, classIdx, 0}, 1.0); - + inOneHot.putScalar(new int[] { i, classIdx, 0 }, 1.0); int labelIdx = r.nextInt(4); - outLabels.putScalar(new int[]{i, labelIdx, 0}, 1.0); + outLabels.putScalar(new int[] { i, labelIdx, 0 }, 1.0); } - net.setInput(inEmbedding); net2.setInput(inOneHot); net.setLabels(outLabels); net2.setLabels(outLabels); - net.computeGradientAndScore(); net2.computeGradientAndScore(); - -// System.out.println(net.score() + "\t" + net2.score()); + // System.out.println(net.score() + "\t" + net2.score()); assertEquals(net2.score(), net.score(), 1e-6); - Map gradient = net.gradient().gradientForVariable(); Map gradient2 = net2.gradient().gradientForVariable(); assertEquals(gradient.size(), gradient2.size()); - for (String s : gradient.keySet()) { assertEquals(gradient2.get(s), gradient.get(s)); } } @Test - public void testEmbeddingLayerRNN() { + @DisplayName("Test Embedding Layer RNN") + void testEmbeddingLayerRNN() { int nClassesIn = 10; int batchSize = 3; int timeSeriesLength = 8; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH) - .dataType(DataType.DOUBLE) - .list() - .layer(0, new EmbeddingLayer.Builder().hasBias(true).nIn(nClassesIn).nOut(5).build()) - .layer(1, new LSTM.Builder().nIn(5).nOut(7).activation(Activation.SOFTSIGN).build()) - .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(7).nOut(4) - .activation(Activation.SOFTMAX).build()) - .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) - .inputPreProcessor(1, new FeedForwardToRnnPreProcessor()) - .setInputType(InputType.recurrent(nClassesIn,timeSeriesLength, RNNFormat.NCW)) - .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .dataType(DataType.DOUBLE) - .list() - .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) - .layer(1, new LSTM.Builder().nIn(5).nOut(7).activation(Activation.SOFTSIGN).build()) - 
.layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(7).nOut(4) - .activation(Activation.SOFTMAX).build()) - .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) - .inputPreProcessor(1, new FeedForwardToRnnPreProcessor()) - .setInputType(InputType.recurrent(nClassesIn,timeSeriesLength, RNNFormat.NCW)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).dataType(DataType.DOUBLE).list().layer(0, new EmbeddingLayer.Builder().hasBias(true).nIn(nClassesIn).nOut(5).build()).layer(1, new LSTM.Builder().nIn(5).nOut(7).activation(Activation.SOFTSIGN).build()).layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(7).nOut(4).activation(Activation.SOFTMAX).build()).inputPreProcessor(0, new RnnToFeedForwardPreProcessor()).inputPreProcessor(1, new FeedForwardToRnnPreProcessor()).setInputType(InputType.recurrent(nClassesIn, timeSeriesLength, RNNFormat.NCW)).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).list().layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()).layer(1, new LSTM.Builder().nIn(5).nOut(7).activation(Activation.SOFTSIGN).build()).layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(7).nOut(4).activation(Activation.SOFTMAX).build()).inputPreProcessor(0, new RnnToFeedForwardPreProcessor()).inputPreProcessor(1, new FeedForwardToRnnPreProcessor()).setInputType(InputType.recurrent(nClassesIn, timeSeriesLength, RNNFormat.NCW)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net.init(); net2.init(); - net2.setParams(net.params().dup()); - - ; + ; INDArray inEmbedding = Nd4j.create(batchSize, 1, timeSeriesLength); INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, timeSeriesLength); INDArray outLabels = Nd4j.create(batchSize, 4, timeSeriesLength); - Random r = new Random(12345); for (int i = 0; i < batchSize; i++) { for (int j = 0; j < timeSeriesLength; j++) { int classIdx = r.nextInt(nClassesIn); - inEmbedding.putScalar(new int[]{i, 0, j}, classIdx); - inOneHot.putScalar(new int[]{i, classIdx, j}, 1.0); - + inEmbedding.putScalar(new int[] { i, 0, j }, classIdx); + inOneHot.putScalar(new int[] { i, classIdx, j }, 1.0); int labelIdx = r.nextInt(4); - outLabels.putScalar(new int[]{i, labelIdx, j}, 1.0); + outLabels.putScalar(new int[] { i, labelIdx, j }, 1.0); } } - net.setInput(inEmbedding); net2.setInput(inOneHot); net.setLabels(outLabels); net2.setLabels(outLabels); - net.computeGradientAndScore(); net2.computeGradientAndScore(); - -// System.out.println(net.score() + "\t" + net2.score()); + // System.out.println(net.score() + "\t" + net2.score()); assertEquals(net2.score(), net.score(), 1e-5); - Map gradient = net.gradient().gradientForVariable(); Map gradient2 = net2.gradient().gradientForVariable(); assertEquals(gradient.size(), gradient2.size()); - for (String s : gradient.keySet()) { assertEquals(gradient2.get(s), gradient.get(s)); } - } @Test - public void testEmbeddingLayerWithMasking() { - //Idea: have masking on the input with an embedding and dense layers on input - //Ensure that the parameter gradients for the inputs don't depend on the inputs when inputs are masked - - int[] miniBatchSizes = {1, 2, 5}; + @DisplayName("Test Embedding Layer With Masking") + void testEmbeddingLayerWithMasking() { + // Idea: have masking on the 
input with an embedding and dense layers on input + // Ensure that the parameter gradients for the inputs don't depend on the inputs when inputs are masked + int[] miniBatchSizes = { 1, 2, 5 }; int nIn = 2; Random r = new Random(12345); - int numInputClasses = 10; int timeSeriesLength = 5; - - for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) { + for (DataType maskDtype : new DataType[] { DataType.FLOAT, DataType.DOUBLE, DataType.INT }) { for (int nExamples : miniBatchSizes) { Nd4j.getRandom().setSeed(12345); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.1)).seed(12345).list() - .layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses) - .nOut(5).build()) - .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) - .layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) - .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) - .nOut(4).build()) - .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) - .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()) - .setInputType(InputType.recurrent(numInputClasses,timeSeriesLength, RNNFormat.NCW)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).seed(12345).list().layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses).nOut(5).build()).layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()).layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()).layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3).nOut(4).build()).inputPreProcessor(0, new RnnToFeedForwardPreProcessor()).inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).setInputType(InputType.recurrent(numInputClasses, timeSeriesLength, RNNFormat.NCW)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.1)).seed(12345).list() - .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5) - .build()) - .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) - .layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) - .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) - .nOut(4).build()) - .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) - .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()) - .setInputType(InputType.recurrent(numInputClasses,timeSeriesLength, RNNFormat.NCW)) - .build(); - + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).seed(12345).list().layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5).build()).layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()).layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()).layer(3, new 
RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3).nOut(4).build()).inputPreProcessor(0, new RnnToFeedForwardPreProcessor()).inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).setInputType(InputType.recurrent(numInputClasses, timeSeriesLength, RNNFormat.NCW)).build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - net2.setParams(net.params().dup()); - INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength); INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength); - INDArray labels = Nd4j.zeros(nExamples, 4, timeSeriesLength); - for (int i = 0; i < nExamples; i++) { for (int j = 0; j < timeSeriesLength; j++) { int inIdx = r.nextInt(numInputClasses); - inEmbedding.putScalar(new int[]{i, 0, j}, inIdx); - inDense.putScalar(new int[]{i, inIdx, j}, 1.0); - + inEmbedding.putScalar(new int[] { i, 0, j }, inIdx); + inDense.putScalar(new int[] { i, inIdx, j }, 1.0); int outIdx = r.nextInt(4); - labels.putScalar(new int[]{i, outIdx, j}, 1.0); + labels.putScalar(new int[] { i, outIdx, j }, 1.0); } } - INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength); for (int i = 0; i < nExamples; i++) { for (int j = 0; j < timeSeriesLength; j++) { - inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0)); + inputMask.putScalar(new int[] { i, j }, (r.nextBoolean() ? 1.0 : 0.0)); } } - net.setLayerMaskArrays(inputMask, null); net2.setLayerMaskArrays(inputMask, null); List actEmbedding = net.feedForward(inEmbedding, false); @@ -519,15 +354,12 @@ public class EmbeddingLayerTest extends BaseDL4JTest { for (int i = 1; i < actEmbedding.size(); i++) { assertEquals(actDense.get(i), actEmbedding.get(i)); } - net.setLabels(labels); net2.setLabels(labels); net.computeGradientAndScore(); net2.computeGradientAndScore(); - -// System.out.println(net.score() + "\t" + net2.score()); + // System.out.println(net.score() + "\t" + net2.score()); assertEquals(net2.score(), net.score(), 1e-5); - Map gradients = net.gradient().gradientForVariable(); Map gradients2 = net2.gradient().gradientForVariable(); assertEquals(gradients.keySet(), gradients2.keySet()); @@ -538,151 +370,93 @@ public class EmbeddingLayerTest extends BaseDL4JTest { } } - @Test - public void testW2VInits(){ + @DisplayName("Test W 2 V Inits") + void testW2VInits() { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - - for( int i=0; i<2; i++ ) { - - INDArray vectors = Nd4j.linspace(1,15,15, DataType.FLOAT).reshape(5,3); - + for (int i = 0; i < 2; i++) { + INDArray vectors = Nd4j.linspace(1, 15, 15, DataType.FLOAT).reshape(5, 3); EmbeddingLayer el; - if(i == 0){ + if (i == 0) { el = new EmbeddingLayer.Builder().weightInit(vectors).build(); } else { el = new EmbeddingLayer.Builder().weightInit(new WordVectorsMockup()).build(); } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345).list() - .layer(el) - .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()) - .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) - .nOut(4).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(el).layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()).layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3).nOut(4).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray w = net.getParam("0_W"); assertEquals(vectors, w); - 
TestUtils.testModelSerialization(net); - - //Test same thing for embedding sequence layer: + // Test same thing for embedding sequence layer: EmbeddingSequenceLayer esl; - if(i == 0){ + if (i == 0) { esl = new EmbeddingSequenceLayer.Builder().weightInit(vectors).build(); } else { esl = new EmbeddingSequenceLayer.Builder().weightInit(new WordVectorsMockup()).build(); } - - conf = new NeuralNetConfiguration.Builder() - .seed(12345).list() - .layer(esl) - .layer(new GlobalPoolingLayer()) - .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()) - .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) - .nOut(4).build()) - .build(); - + conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(esl).layer(new GlobalPoolingLayer()).layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()).layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3).nOut(4).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - w = net.getParam("0_W"); assertEquals(vectors, w); - TestUtils.testModelSerialization(net); } } @Test - public void testEmbeddingSequenceLayerWithMasking() { - //Idea: have masking on the input with an embedding and dense layers on input - //Ensure that the parameter gradients for the inputs don't depend on the inputs when inputs are masked - - int[] miniBatchSizes = {1, 3}; + @DisplayName("Test Embedding Sequence Layer With Masking") + void testEmbeddingSequenceLayerWithMasking() { + // Idea: have masking on the input with an embedding and dense layers on input + // Ensure that the parameter gradients for the inputs don't depend on the inputs when inputs are masked + int[] miniBatchSizes = { 1, 3 }; int nIn = 2; Random r = new Random(12345); - int numInputClasses = 10; int timeSeriesLength = 5; - - for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) { - for (DataType inLabelDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) { - for(int inputRank : new int[]{2, 3}) { + for (DataType maskDtype : new DataType[] { DataType.FLOAT, DataType.DOUBLE, DataType.INT }) { + for (DataType inLabelDtype : new DataType[] { DataType.FLOAT, DataType.DOUBLE, DataType.INT }) { + for (int inputRank : new int[] { 2, 3 }) { for (int nExamples : miniBatchSizes) { Nd4j.getRandom().setSeed(12345); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.1)).seed(12345).list() - .layer(0, new EmbeddingSequenceLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses) - .nOut(5).build()) - .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) - .layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) - .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) - .nOut(4).build()) - .setInputType(InputType.recurrent(numInputClasses,timeSeriesLength,RNNFormat.NCW)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).seed(12345).list().layer(0, new EmbeddingSequenceLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses).nOut(5).build()).layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()).layer(2, new 
LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()).layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3).nOut(4).build()).setInputType(InputType.recurrent(numInputClasses, timeSeriesLength, RNNFormat.NCW)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.1)).seed(12345).list() - .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5) - .build()) - .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) - .layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).dataFormat(RNNFormat.NCW).build()) - .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) - .nOut(4).build()) - .setInputType(InputType.recurrent(numInputClasses,1,RNNFormat.NCW)).build(); - + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).seed(12345).list().layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5).build()).layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()).layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).dataFormat(RNNFormat.NCW).build()).layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3).nOut(4).build()).setInputType(InputType.recurrent(numInputClasses, 1, RNNFormat.NCW)).build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - net2.setParams(net.params().dup()); - - INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength}); + INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[] { nExamples, timeSeriesLength } : new long[] { nExamples, 1, timeSeriesLength }); INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength); - INDArray labels = Nd4j.zeros(inLabelDtype, nExamples, 4, timeSeriesLength); - for (int i = 0; i < nExamples; i++) { for (int j = 0; j < timeSeriesLength; j++) { int inIdx = r.nextInt(numInputClasses); - inEmbedding.putScalar(inputRank == 2 ? new int[]{i, j} : new int[]{i, 0, j}, inIdx); - inDense.putScalar(new int[]{i, inIdx, j}, 1.0); - + inEmbedding.putScalar(inputRank == 2 ? new int[] { i, j } : new int[] { i, 0, j }, inIdx); + inDense.putScalar(new int[] { i, inIdx, j }, 1.0); int outIdx = r.nextInt(4); - labels.putScalar(new int[]{i, outIdx, j}, 1.0); + labels.putScalar(new int[] { i, outIdx, j }, 1.0); } } - INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength); for (int i = 0; i < nExamples; i++) { for (int j = 0; j < timeSeriesLength; j++) { - inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0)); + inputMask.putScalar(new int[] { i, j }, (r.nextBoolean() ? 
1.0 : 0.0)); } } - net.setLayerMaskArrays(inputMask, null); net2.setLayerMaskArrays(inputMask, null); List actEmbedding = net.feedForward(inEmbedding, false); List actDense = net2.feedForward(inDense, false); - for (int i = 2; i < actEmbedding.size(); i++) { //Start from layer 2: EmbeddingSequence is 3d, first dense is 2d (before reshape) + for (int i = 2; i < actEmbedding.size(); i++) { + // Start from layer 2: EmbeddingSequence is 3d, first dense is 2d (before reshape) assertEquals(actDense.get(i), actEmbedding.get(i)); } - net.setLabels(labels); net2.setLabels(labels); net.computeGradientAndScore(); net2.computeGradientAndScore(); - assertEquals(net2.score(), net.score(), 1e-5); - Map gradients = net.gradient().gradientForVariable(); Map gradients2 = net2.gradient().gradientForVariable(); assertEquals(gradients.keySet(), gradients2.keySet()); @@ -696,11 +470,12 @@ public class EmbeddingLayerTest extends BaseDL4JTest { } @EqualsAndHashCode + @DisplayName("Word Vectors Mockup") private static class WordVectorsMockup implements EmbeddingInitializer { @Override public void loadWeightsInto(INDArray array) { - INDArray vectors = Nd4j.linspace(1,15,15, DataType.FLOAT).reshape(5,3); + INDArray vectors = Nd4j.linspace(1, 15, 15, DataType.FLOAT).reshape(5, 3); array.assign(vectors); } @@ -721,94 +496,55 @@ public class EmbeddingLayerTest extends BaseDL4JTest { } @Test - public void testEmbeddingDefaultActivation(){ - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new EmbeddingLayer.Builder().nIn(10).nOut(10).build()) - .layer(new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build()) - .build(); - + @DisplayName("Test Embedding Default Activation") + void testEmbeddingDefaultActivation() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new EmbeddingLayer.Builder().nIn(10).nOut(10).build()).layer(new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build()).build(); EmbeddingLayer l = (EmbeddingLayer) conf.getConf(0).getLayer(); assertEquals(new ActivationIdentity(), l.getActivationFn()); - EmbeddingSequenceLayer l2 = (EmbeddingSequenceLayer) conf.getConf(1).getLayer(); assertEquals(new ActivationIdentity(), l2.getActivationFn()); - } - @Test - public void testEmbeddingWeightInit(){ + @DisplayName("Test Embedding Weight Init") + void testEmbeddingWeightInit() { // https://github.com/eclipse/deeplearning4j/issues/8663 - //The embedding layer weight initialization should be independent of the vocabulary size (nIn setting) - - for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL}) { - - for (boolean seq : new boolean[]{false, true}) { - + // The embedding layer weight initialization should be independent of the vocabulary size (nIn setting) + for (WeightInit wi : new WeightInit[] { WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL }) { + for (boolean seq : new boolean[] { false, true }) { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345) - .list() - .layer(seq ? - new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() : - new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(seq ? 
new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() : new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .seed(12345) - .list() - .layer(seq ? - new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() : - new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build()) - .build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(seq ? new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() : new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build()).build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder() - .seed(12345) - .list() - .layer(seq ? - new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build() : - new EmbeddingLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build()) - .build(); + MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(seq ? new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build() : new EmbeddingLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build()).build(); MultiLayerNetwork net3 = new MultiLayerNetwork(conf3); net3.init(); - INDArray p1 = net.params(); INDArray p2 = net2.params(); INDArray p3 = net3.params(); boolean eq = p1.equalsWithEps(p2, 1e-4); String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi; - assertTrue(str + " p1/p2 params not equal", eq); - + assertTrue(eq,str + " p1/p2 params not equal"); double m1 = p1.meanNumber().doubleValue(); double s1 = p1.stdNumber().doubleValue(); - double m3 = p3.meanNumber().doubleValue(); double s3 = p3.stdNumber().doubleValue(); - - - - assertEquals(str, m1, m3, 0.1); - assertEquals(str, s1, s3, 0.1); - + assertEquals( m1, m3, 0.1,str); + assertEquals(s1, s3, 0.1,str); double re = relErr(s1, s3); - assertTrue(str + " - " + re, re < 0.05); + assertTrue( re < 0.05,str + " - " + re); } } - } - public static double relErr(double d1, double d2){ - if(d1 == 0.0 && d2 == 0.0) + public static double relErr(double d1, double d2) { + if (d1 == 0.0 && d2 == 0.0) return 0.0; return Math.abs(d1 - d2) / (Math.abs(d1) + Math.abs(d2)); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java index 9896f05d4..b69ea4241 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.normalization; import lombok.extern.slf4j.Slf4j; @@ -43,8 +42,8 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.updater.MultiLayerUpdater; import org.deeplearning4j.nn.updater.UpdaterBlock; import org.deeplearning4j.nn.weights.WeightInit; -import 
org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; @@ -65,32 +64,35 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; - import java.util.ArrayList; import java.util.List; import java.util.Map; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** */ @Slf4j -public class BatchNormalizationTest extends BaseDL4JTest { +@DisplayName("Batch Normalization Test") +class BatchNormalizationTest extends BaseDL4JTest { static { - //Force Nd4j initialization, then set data type to double: + // Force Nd4j initialization, then set data type to double: Nd4j.zeros(1); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } protected INDArray dnnInput = Nd4j.linspace(0, 31, 32, Nd4j.dataType()).reshape(2, 16); + protected INDArray dnnEpsilon = Nd4j.linspace(0, 31, 32, Nd4j.dataType()).reshape(2, 16); protected INDArray cnnInput = Nd4j.linspace(0, 63, 64, Nd4j.dataType()).reshape(2, 2, 4, 4); + protected INDArray cnnEpsilon = Nd4j.linspace(0, 63, 64, Nd4j.dataType()).reshape(2, 2, 4, 4); - @Before - public void doBefore() { + @BeforeEach + void doBefore() { } @Override @@ -99,31 +101,28 @@ public class BatchNormalizationTest extends BaseDL4JTest { } @Test - public void testDnnForwardPass() { + @DisplayName("Test Dnn Forward Pass") + void testDnnForwardPass() { int nOut = 10; Layer l = getLayer(nOut, 0.0, false, -1, -1); - assertEquals(4 * nOut, l.numParams()); //Gamma, beta, global mean, global var - + // Gamma, beta, global mean, global var + assertEquals(4 * nOut, l.numParams()); INDArray randInput = Nd4j.rand(100, nOut); INDArray output = l.activate(randInput, true, LayerWorkspaceMgr.noWorkspaces()); - INDArray mean = output.mean(0); INDArray stdev = output.std(false, 0); - -// System.out.println(Arrays.toString(mean.data().asFloat())); - + // System.out.println(Arrays.toString(mean.data().asFloat())); assertArrayEquals(new float[nOut], mean.data().asFloat(), 1e-6f); assertEquals(Nd4j.ones(nOut), stdev); - - //If we fix gamma/beta: expect different mean and variance... + // If we fix gamma/beta: expect different mean and variance... 
double gamma = 2.0; double beta = 3.0; l = getLayer(nOut, 0.0, true, gamma, beta); - assertEquals(2 * nOut, l.numParams()); //Should have only global mean/var parameters + // Should have only global mean/var parameters + assertEquals(2 * nOut, l.numParams()); output = l.activate(randInput, true, LayerWorkspaceMgr.noWorkspaces()); mean = output.mean(0); stdev = output.std(false, 0); - assertEquals(Nd4j.valueArrayOf(mean.shape(), beta), mean); assertEquals(Nd4j.valueArrayOf(stdev.shape(), gamma), stdev); } @@ -135,7 +134,6 @@ public class BatchNormalizationTest extends BaseDL4JTest { } BatchNormalization bN = b.build(); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(bN).build(); - long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = null; if (numParams > 0) { @@ -149,136 +147,108 @@ public class BatchNormalizationTest extends BaseDL4JTest { } @Test - public void testDnnForwardBackward() { + @DisplayName("Test Dnn Forward Backward") + void testDnnForwardBackward() { double eps = 1e-5; int nIn = 4; int minibatch = 2; Nd4j.getRandom().setSeed(12345); - INDArray input = Nd4j.rand('c', new int[]{minibatch, nIn}); - - //TODO: other values for gamma/beta + INDArray input = Nd4j.rand('c', new int[] { minibatch, nIn }); + // TODO: other values for gamma/beta INDArray gamma = Nd4j.ones(1, nIn); INDArray beta = Nd4j.zeros(1, nIn); - Layer l = getLayer(nIn, eps, false, -1, -1); - INDArray mean = input.mean(0); INDArray var = input.var(false, 0); INDArray xHat = input.subRowVector(mean).divRowVector(Transforms.sqrt(var.add(eps), true)); INDArray outExpected = xHat.mulRowVector(gamma).addRowVector(beta); - INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces()); - -// System.out.println(Arrays.toString(outExpected.data().asDouble())); -// System.out.println(Arrays.toString(out.data().asDouble())); - + // System.out.println(Arrays.toString(outExpected.data().asDouble())); + // System.out.println(Arrays.toString(out.data().asDouble())); assertEquals(outExpected, out); - - //------------------------------------------------------------- - //Check backprop - INDArray epsilon = Nd4j.rand(minibatch, nIn); //dL/dy - + // ------------------------------------------------------------- + // Check backprop + // dL/dy + INDArray epsilon = Nd4j.rand(minibatch, nIn); INDArray dldgammaExp = epsilon.mul(xHat).sum(true, 0); INDArray dldbetaExp = epsilon.sum(true, 0); - INDArray dldxhat = epsilon.mulRowVector(gamma); - INDArray dldvar = dldxhat.mul(input.subRowVector(mean)).mul(-0.5) - .mulRowVector(Transforms.pow(var.add(eps), -3.0 / 2.0, true)).sum(0); - INDArray dldmu = dldxhat.mulRowVector(Transforms.pow(var.add(eps), -1.0 / 2.0, true)).neg().sum(0) - .add(dldvar.mul(input.subRowVector(mean).mul(-2.0).sum(0).div(minibatch))); - INDArray dldinExp = dldxhat.mulRowVector(Transforms.pow(var.add(eps), -1.0 / 2.0, true)) - .add(input.subRowVector(mean).mul(2.0 / minibatch).mulRowVector(dldvar)) - .addRowVector(dldmu.mul(1.0 / minibatch)); - + INDArray dldvar = dldxhat.mul(input.subRowVector(mean)).mul(-0.5).mulRowVector(Transforms.pow(var.add(eps), -3.0 / 2.0, true)).sum(0); + INDArray dldmu = dldxhat.mulRowVector(Transforms.pow(var.add(eps), -1.0 / 2.0, true)).neg().sum(0).add(dldvar.mul(input.subRowVector(mean).mul(-2.0).sum(0).div(minibatch))); + INDArray dldinExp = dldxhat.mulRowVector(Transforms.pow(var.add(eps), -1.0 / 2.0, true)).add(input.subRowVector(mean).mul(2.0 / minibatch).mulRowVector(dldvar)).addRowVector(dldmu.mul(1.0 / minibatch)); Pair p 
= l.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); - INDArray dldgamma = p.getFirst().getGradientFor("gamma"); INDArray dldbeta = p.getFirst().getGradientFor("beta"); - assertEquals(dldgammaExp, dldgamma); assertEquals(dldbetaExp, dldbeta); - -// System.out.println("EPSILONS"); -// System.out.println(Arrays.toString(dldinExp.data().asDouble())); -// System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble())); + // System.out.println("EPSILONS"); + // System.out.println(Arrays.toString(dldinExp.data().asDouble())); + // System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble())); assertEquals(dldinExp, p.getSecond()); } @Test - public void testCnnForwardPass() { + @DisplayName("Test Cnn Forward Pass") + void testCnnForwardPass() { int nOut = 10; Layer l = getLayer(nOut, 0.0, false, -1, -1); - assertEquals(4 * nOut, l.numParams()); //Gamma, beta, global mean, global var + // Gamma, beta, global mean, global var + assertEquals(4 * nOut, l.numParams()); int hw = 15; - Nd4j.getRandom().setSeed(12345); - INDArray randInput = Nd4j.rand(new int[]{100, nOut, hw, hw}); + INDArray randInput = Nd4j.rand(new int[] { 100, nOut, hw, hw }); INDArray output = l.activate(randInput, true, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(4, output.rank()); - INDArray mean = output.mean(0, 2, 3); INDArray stdev = output.std(false, 0, 2, 3); - assertArrayEquals(new float[nOut], mean.data().asFloat(), 1e-6f); assertArrayEquals(Nd4j.ones(1, nOut).data().asFloat(), stdev.data().asFloat(), 1e-6f); - - //If we fix gamma/beta: expect different mean and variance... + // If we fix gamma/beta: expect different mean and variance... double gamma = 2.0; double beta = 3.0; l = getLayer(nOut, 0.0, true, gamma, beta); - assertEquals(2 * nOut, l.numParams()); //Should have only global mean/var parameters + // Should have only global mean/var parameters + assertEquals(2 * nOut, l.numParams()); output = l.activate(randInput, true, LayerWorkspaceMgr.noWorkspaces()); mean = output.mean(0, 2, 3); stdev = output.std(false, 0, 2, 3); - assertEquals(Nd4j.valueArrayOf(mean.shape(), beta), mean); assertEquals(Nd4j.valueArrayOf(stdev.shape(), gamma), stdev); } @Test - public void test2dVs4d() { - //Idea: 2d and 4d should be the same... + @DisplayName("Test 2 d Vs 4 d") + void test2dVs4d() { + // Idea: 2d and 4d should be the same... 
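// ---------------------------------------------------------------------------
// Editor's note (hedged sketch, not part of the original patch): testDnnForwardBackward above
// (and its CNN counterpart later in this class) builds its expected values from the standard
// batch-norm equations. For a minibatch of size m and feature j, with
//   xHat = (x - mean) / sqrt(var + eps)   and   y = gamma * xHat + beta,
// the gradients it constructs are:
//   dL/dgamma_j = sum_i epsilon_ij * xHat_ij
//   dL/dbeta_j  = sum_i epsilon_ij
//   dL/dxHat    = epsilon * gamma
//   dL/dvar_j   = sum_i dL/dxHat_ij * (x_ij - mean_j) * (-1/2) * (var_j + eps)^(-3/2)
//   dL/dmean_j  = -sum_i dL/dxHat_ij / sqrt(var_j + eps) + dL/dvar_j * sum_i (-2)(x_ij - mean_j) / m
//   dL/dx_ij    = dL/dxHat_ij / sqrt(var_j + eps) + dL/dvar_j * 2(x_ij - mean_j) / m + dL/dmean_j / m
// A minimal ND4J sketch of the two parameter gradients, mirroring the sum(true, 0) reductions
// already used in the test (class and method names are illustrative only):
import org.nd4j.linalg.api.ndarray.INDArray;

class BatchNormGradSketch {
    static INDArray[] paramGradients(INDArray epsilon, INDArray xHat) {
        INDArray dLdGamma = epsilon.mul(xHat).sum(true, 0); // reduce over the minibatch dimension
        INDArray dLdBeta = epsilon.sum(true, 0);
        return new INDArray[] { dLdGamma, dLdBeta };
    }
}
// ---------------------------------------------------------------------------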
Nd4j.getRandom().setSeed(12345); - int m = 2; int h = 3; int w = 3; int nOut = 2; - INDArray in = Nd4j.rand('c', m * h * w, nOut); - INDArray in4 = in.dup(); - in4 = Shape.newShapeNoCopy(in4, new int[]{m, h, w, nOut}, false); + in4 = Shape.newShapeNoCopy(in4, new int[] { m, h, w, nOut }, false); assertNotNull(in4); in4 = in4.permute(0, 3, 1, 2).dup(); INDArray arr = Nd4j.rand(1, m * h * w * nOut).reshape('f', h, w, m, nOut).permute(2, 3, 1, 0); in4 = arr.assign(in4); - Layer l1 = getLayer(nOut); Layer l2 = getLayer(nOut); - INDArray out2d = l1.activate(in.dup(), true, LayerWorkspaceMgr.noWorkspaces()); INDArray out4d = l2.activate(in4.dup(), true, LayerWorkspaceMgr.noWorkspaces()); - INDArray out4dAs2 = out4d.permute(0, 2, 3, 1).dup('c'); - out4dAs2 = Shape.newShapeNoCopy(out4dAs2, new int[]{m * h * w, nOut}, false); - + out4dAs2 = Shape.newShapeNoCopy(out4dAs2, new int[] { m * h * w, nOut }, false); assertEquals(out2d, out4dAs2); - - //Test backprop: + // Test backprop: INDArray epsilons2d = Nd4j.rand('c', m * h * w, nOut); INDArray epsilons4d = epsilons2d.dup(); - epsilons4d = Shape.newShapeNoCopy(epsilons4d, new int[]{m, h, w, nOut}, false); + epsilons4d = Shape.newShapeNoCopy(epsilons4d, new int[] { m, h, w, nOut }, false); assertNotNull(epsilons4d); epsilons4d = epsilons4d.permute(0, 3, 1, 2).dup(); - Pair b2d = l1.backpropGradient(epsilons2d, LayerWorkspaceMgr.noWorkspaces()); Pair b4d = l2.backpropGradient(epsilons4d, LayerWorkspaceMgr.noWorkspaces()); - INDArray e4dAs2d = b4d.getSecond().permute(0, 2, 3, 1).dup('c'); - e4dAs2d = Shape.newShapeNoCopy(e4dAs2d, new int[]{m * h * w, nOut}, false); - + e4dAs2d = Shape.newShapeNoCopy(e4dAs2d, new int[] { m * h * w, nOut }, false); assertEquals(b2d.getSecond(), e4dAs2d); } @@ -287,109 +257,71 @@ public class BatchNormalizationTest extends BaseDL4JTest { } @Test - public void testCnnForwardBackward() { + @DisplayName("Test Cnn Forward Backward") + void testCnnForwardBackward() { double eps = 1e-5; int nIn = 4; int hw = 3; int minibatch = 2; Nd4j.getRandom().setSeed(12345); - INDArray input = Nd4j.rand('c', new int[]{minibatch, nIn, hw, hw}); - - //TODO: other values for gamma/beta + INDArray input = Nd4j.rand('c', new int[] { minibatch, nIn, hw, hw }); + // TODO: other values for gamma/beta INDArray gamma = Nd4j.ones(1, nIn); INDArray beta = Nd4j.zeros(1, nIn); - Layer l = getLayer(nIn, eps, false, -1, -1); - INDArray mean = input.mean(0, 2, 3); INDArray var = input.var(false, 0, 2, 3); INDArray xHat = Nd4j.getExecutioner().exec(new BroadcastSubOp(input, mean, input.dup(), 1)); Nd4j.getExecutioner().exec(new BroadcastDivOp(xHat, Transforms.sqrt(var.add(eps), true), xHat, 1)); - INDArray outExpected = Nd4j.getExecutioner().exec(new BroadcastMulOp(xHat, gamma, xHat.dup(), 1)); Nd4j.getExecutioner().exec(new BroadcastAddOp(outExpected, beta, outExpected, 1)); - INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces()); - -// System.out.println(Arrays.toString(outExpected.data().asDouble())); -// System.out.println(Arrays.toString(out.data().asDouble())); - + // System.out.println(Arrays.toString(outExpected.data().asDouble())); + // System.out.println(Arrays.toString(out.data().asDouble())); assertEquals(outExpected, out); - - //------------------------------------------------------------- - //Check backprop - INDArray epsilon = Nd4j.rand('c', new int[]{minibatch, nIn, hw, hw}); //dL/dy - + // ------------------------------------------------------------- + // Check backprop + // dL/dy + INDArray epsilon = Nd4j.rand('c', new 
int[] { minibatch, nIn, hw, hw }); int effectiveMinibatch = minibatch * hw * hw; - INDArray dldgammaExp = epsilon.mul(xHat).sum(0, 2, 3); dldgammaExp = dldgammaExp.reshape(1, dldgammaExp.length()); INDArray dldbetaExp = epsilon.sum(0, 2, 3); dldbetaExp = dldbetaExp.reshape(1, dldbetaExp.length()); - - INDArray dldxhat = Nd4j.getExecutioner().exec(new BroadcastMulOp(epsilon, gamma, epsilon.dup(), 1)); //epsilon.mulRowVector(gamma); - + // epsilon.mulRowVector(gamma); + INDArray dldxhat = Nd4j.getExecutioner().exec(new BroadcastMulOp(epsilon, gamma, epsilon.dup(), 1)); INDArray inputSubMean = Nd4j.getExecutioner().exec(new BroadcastSubOp(input, mean, input.dup(), 1)); - INDArray dldvar = dldxhat.mul(inputSubMean).mul(-0.5); - dldvar = Nd4j.getExecutioner().exec( - new BroadcastMulOp(dldvar, Transforms.pow(var.add(eps), -3.0 / 2.0, true), dldvar.dup(), 1)); + dldvar = Nd4j.getExecutioner().exec(new BroadcastMulOp(dldvar, Transforms.pow(var.add(eps), -3.0 / 2.0, true), dldvar.dup(), 1)); dldvar = dldvar.sum(0, 2, 3); - - - INDArray dldmu = Nd4j - .getExecutioner().exec(new BroadcastMulOp(dldxhat, - Transforms.pow(var.add(eps), -1.0 / 2.0, true), dldxhat.dup(), 1)) - .neg().sum(0, 2, 3); + INDArray dldmu = Nd4j.getExecutioner().exec(new BroadcastMulOp(dldxhat, Transforms.pow(var.add(eps), -1.0 / 2.0, true), dldxhat.dup(), 1)).neg().sum(0, 2, 3); dldmu = dldmu.add(dldvar.mul(inputSubMean.mul(-2.0).sum(0, 2, 3).div(effectiveMinibatch))); - - INDArray dldinExp = Nd4j.getExecutioner().exec( - new BroadcastMulOp(dldxhat, Transforms.pow(var.add(eps), -1.0 / 2.0, true), dldxhat.dup(), 1)); - dldinExp = dldinExp.add(Nd4j.getExecutioner().exec( - new BroadcastMulOp(inputSubMean.mul(2.0 / effectiveMinibatch), dldvar, inputSubMean.dup(), 1))); - dldinExp = Nd4j.getExecutioner().exec( - new BroadcastAddOp(dldinExp, dldmu.mul(1.0 / effectiveMinibatch), dldinExp.dup(), 1)); - + INDArray dldinExp = Nd4j.getExecutioner().exec(new BroadcastMulOp(dldxhat, Transforms.pow(var.add(eps), -1.0 / 2.0, true), dldxhat.dup(), 1)); + dldinExp = dldinExp.add(Nd4j.getExecutioner().exec(new BroadcastMulOp(inputSubMean.mul(2.0 / effectiveMinibatch), dldvar, inputSubMean.dup(), 1))); + dldinExp = Nd4j.getExecutioner().exec(new BroadcastAddOp(dldinExp, dldmu.mul(1.0 / effectiveMinibatch), dldinExp.dup(), 1)); Pair p = l.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); - INDArray dldgamma = p.getFirst().getGradientFor("gamma"); INDArray dldbeta = p.getFirst().getGradientFor("beta"); - assertEquals(dldgammaExp, dldgamma); assertEquals(dldbetaExp, dldbeta); - - // System.out.println("EPSILONS"); - // System.out.println(Arrays.toString(dldinExp.data().asDouble())); - // System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble())); + // System.out.println("EPSILONS"); + // System.out.println(Arrays.toString(dldinExp.data().asDouble())); + // System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble())); assertEquals(dldinExp, p.getSecond()); } @Test - public void testDBNBNMultiLayer() throws Exception { + @DisplayName("Test DBNBN Multi Layer") + void testDBNBNMultiLayer() throws Exception { DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSet next = iter.next(); - // Run with separate activation layer - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).weightInit(WeightInit.XAVIER) - 
.activation(Activation.RELU).build()) - .layer(1, new BatchNormalization.Builder().nOut(10).build()).layer(2, - new ActivationLayer.Builder() - .activation(Activation.RELU).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10) - .build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new BatchNormalization.Builder().nOut(10).build()).layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - network.setInput(next.getFeatures()); INDArray activationsActual = network.output(next.getFeatures()); assertEquals(10, activationsActual.shape()[1], 1e-2); - network.fit(next); INDArray actualGammaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.GAMMA); INDArray actualBetaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.BETA); @@ -398,115 +330,63 @@ public class BatchNormalizationTest extends BaseDL4JTest { } @Test - public void testCNNBNActivationCombo() throws Exception { + @DisplayName("Test CNNBN Activation Combo") + void testCNNBNActivationCombo() throws Exception { DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSet next = iter.next(); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER) - .activation(Activation.IDENTITY).build()) - .layer(1, new BatchNormalization.Builder().build()) - .layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().build()).layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); network.fit(next); - assertNotEquals(null, network.getLayer(0).getParam("W")); assertNotEquals(null, network.getLayer(0).getParam("b")); } - @Test - public void checkSerialization() throws Exception { - //Serialize the batch norm network (after training), and make sure we get same activations out as before + @DisplayName("Check Serialization") + void checkSerialization() throws Exception { + // Serialize the batch norm network (after training), and make sure we get same activations out as 
before // i.e., make sure state is properly stored - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345) - .list() - .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER) - .activation(Activation.IDENTITY).build()) - .layer(1, new BatchNormalization.Builder().build()) - .layer(2, new ActivationLayer.Builder().activation(Activation.LEAKYRELU).build()) - .layer(3, new DenseLayer.Builder().nOut(10).activation(Activation.LEAKYRELU).build()) - .layer(4, new BatchNormalization.Builder().build()) - .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().build()).layer(2, new ActivationLayer.Builder().activation(Activation.LEAKYRELU).build()).layer(3, new DenseLayer.Builder().nOut(10).activation(Activation.LEAKYRELU).build()).layer(4, new BatchNormalization.Builder().build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSetIterator iter = new MnistDataSetIterator(16, true, 12345); for (int i = 0; i < 20; i++) { net.fit(iter.next()); } - INDArray in = iter.next().getFeatures(); - INDArray out = net.output(in, false); INDArray out2 = net.output(in, false); - assertEquals(out, out2); - MultiLayerNetwork net2 = TestUtils.testModelSerialization(net); - INDArray outDeser = net2.output(in, false); - assertEquals(out, outDeser); } @Test - public void testGradientAndUpdaters() throws Exception { - //Global mean/variance are part of the parameter vector. Expect 0 gradient, and no-op updater for these - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345).list() - .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER) - .activation(Activation.IDENTITY).build()) - .layer(1, new BatchNormalization.Builder().build()) - .layer(2, new ActivationLayer.Builder().activation(Activation.LEAKYRELU).build()) - .layer(3, new DenseLayer.Builder().nOut(10).activation(Activation.LEAKYRELU).build()) - .layer(4, new BatchNormalization.Builder().build()) - .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - + @DisplayName("Test Gradient And Updaters") + void testGradientAndUpdaters() throws Exception { + // Global mean/variance are part of the parameter vector. 
Expect 0 gradient, and no-op updater for these + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.RMSPROP).seed(12345).list().layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().build()).layer(2, new ActivationLayer.Builder().activation(Activation.LEAKYRELU).build()).layer(3, new DenseLayer.Builder().nOut(10).activation(Activation.LEAKYRELU).build()).layer(4, new BatchNormalization.Builder().build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSetIterator iter = new MnistDataSetIterator(16, true, 12345); - DataSet ds = iter.next(); net.setInput(ds.getFeatures()); net.setLabels(ds.getLabels()); - net.computeGradientAndScore(); - Gradient g = net.gradient(); Map map = g.gradientForVariable(); - org.deeplearning4j.nn.api.Updater u = net.getUpdater(); - MultiLayerUpdater mlu = (MultiLayerUpdater) u; List l = mlu.getUpdaterBlocks(); assertNotNull(l); - assertEquals(5, l.size()); //Conv+bn (RMSProp), No-op (bn), RMSProp (dense, bn), no-op (bn), RMSProp (out) - + // Conv+bn (RMSProp), No-op (bn), RMSProp (dense, bn), no-op (bn), RMSProp (out) + assertEquals(5, l.size()); for (UpdaterBlock ub : l) { - List list = ub.getLayersAndVariablesInBlock(); for (UpdaterBlock.ParamState v : list) { - if (BatchNormalizationParamInitializer.GLOBAL_MEAN.equals(v.getParamName()) - || BatchNormalizationParamInitializer.GLOBAL_VAR.equals(v.getParamName()) - || BatchNormalizationParamInitializer.GLOBAL_LOG_STD.equals(v.getParamName())) { + if (BatchNormalizationParamInitializer.GLOBAL_MEAN.equals(v.getParamName()) || BatchNormalizationParamInitializer.GLOBAL_VAR.equals(v.getParamName()) || BatchNormalizationParamInitializer.GLOBAL_LOG_STD.equals(v.getParamName())) { assertTrue(ub.getGradientUpdater() instanceof NoOpUpdater); } else { assertTrue(ub.getGradientUpdater() instanceof RmsPropUpdater); @@ -515,264 +395,171 @@ public class BatchNormalizationTest extends BaseDL4JTest { } } - @Test - public void checkMeanVarianceEstimate() throws Exception { + @DisplayName("Check Mean Variance Estimate") + void checkMeanVarianceEstimate() throws Exception { Nd4j.getRandom().setSeed(12345); - //Check that the internal global mean/variance estimate is approximately correct - - for(boolean useLogStd : new boolean[]{true, false}) { - - //First, Mnist data as 2d input (NOT taking into account convolution property) - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345) - .list().layer(0, - new BatchNormalization.Builder().nIn(10).nOut(10).eps(1e-5).decay(0.95) - .useLogStd(useLogStd).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER) - .activation(Activation.IDENTITY).nIn(10).nOut(10).build()) - .build(); + // Check that the internal global mean/variance estimate is approximately correct + for (boolean useLogStd : new boolean[] { true, false }) { + // First, Mnist data as 2d input (NOT taking into account convolution property) + MultiLayerConfiguration conf = new 
NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.RMSPROP).seed(12345).list().layer(0, new BatchNormalization.Builder().nIn(10).nOut(10).eps(1e-5).decay(0.95).useLogStd(useLogStd).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER).activation(Activation.IDENTITY).nIn(10).nOut(10).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - int minibatch = 32; List list = new ArrayList<>(); for (int i = 0; i < 200; i++) { list.add(new DataSet(Nd4j.rand(minibatch, 10), Nd4j.rand(minibatch, 10))); } - DataSetIterator iter = new ListDataSetIterator(list); - - INDArray expMean = Nd4j.valueArrayOf(new int[]{1, 10}, 0.5); - INDArray expVar = Nd4j.valueArrayOf(new int[]{1, 10}, 1 / 12.0); //Expected variance of U(0,1) distribution: 1/12 * (1-0)^2 = 0.0833 - - + INDArray expMean = Nd4j.valueArrayOf(new int[] { 1, 10 }, 0.5); + // Expected variance of U(0,1) distribution: 1/12 * (1-0)^2 = 0.0833 + INDArray expVar = Nd4j.valueArrayOf(new int[] { 1, 10 }, 1 / 12.0); for (int i = 0; i < 10; i++) { iter.reset(); net.fit(iter); } - INDArray estMean = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN); INDArray estVar; - if(useLogStd){ + if (useLogStd) { INDArray log10std = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_LOG_STD); estVar = Nd4j.valueArrayOf(log10std.shape(), 10.0).castTo(log10std.dataType()); - Transforms.pow(estVar, log10std, false); // stdev = 10^(log10(stdev)) + // stdev = 10^(log10(stdev)) + Transforms.pow(estVar, log10std, false); estVar.muli(estVar); } else { estVar = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_VAR); } - float[] fMeanExp = expMean.data().asFloat(); float[] fMeanAct = estMean.data().asFloat(); float[] fVarExp = expVar.data().asFloat(); float[] fVarAct = estVar.data().asFloat(); - - // System.out.println("Mean vs. estimated mean:"); - // System.out.println(Arrays.toString(fMeanExp)); - // System.out.println(Arrays.toString(fMeanAct)); - // - // System.out.println("Var vs. estimated var:"); - // System.out.println(Arrays.toString(fVarExp)); - // System.out.println(Arrays.toString(fVarAct)); - + // System.out.println("Mean vs. estimated mean:"); + // System.out.println(Arrays.toString(fMeanExp)); + // System.out.println(Arrays.toString(fMeanAct)); + // + // System.out.println("Var vs. 
estimated var:"); + // System.out.println(Arrays.toString(fVarExp)); + // System.out.println(Arrays.toString(fVarAct)); assertArrayEquals(fMeanExp, fMeanAct, 0.02f); assertArrayEquals(fVarExp, fVarAct, 0.02f); } } - @Test - public void checkMeanVarianceEstimateCNN() throws Exception { - - for(boolean useLogStd : new boolean[]{true, false}) { + @DisplayName("Check Mean Variance Estimate CNN") + void checkMeanVarianceEstimateCNN() throws Exception { + for (boolean useLogStd : new boolean[] { true, false }) { Nd4j.getRandom().setSeed(12345); - //Check that the internal global mean/variance estimate is approximately correct - - //First, Mnist data as 2d input (NOT taking into account convolution property) - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345).list() - .layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).useLogStd(useLogStd).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER) - .activation(Activation.IDENTITY).nOut(10).build()) - .setInputType(InputType.convolutional(5, 5, 3)).build(); + // Check that the internal global mean/variance estimate is approximately correct + // First, Mnist data as 2d input (NOT taking into account convolution property) + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.RMSPROP).seed(12345).list().layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).useLogStd(useLogStd).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER).activation(Activation.IDENTITY).nOut(10).build()).setInputType(InputType.convolutional(5, 5, 3)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - int minibatch = 32; List list = new ArrayList<>(); for (int i = 0; i < 100; i++) { - list.add(new DataSet(Nd4j.rand(new int[]{minibatch, 3, 5, 5}), Nd4j.rand(minibatch, 10))); + list.add(new DataSet(Nd4j.rand(new int[] { minibatch, 3, 5, 5 }), Nd4j.rand(minibatch, 10))); } - DataSetIterator iter = new ListDataSetIterator(list); - - INDArray expMean = Nd4j.valueArrayOf(new int[]{1, 3}, 0.5); - INDArray expVar = Nd4j.valueArrayOf(new int[]{1, 3}, 1 / 12.0); //Expected variance of U(0,1) distribution: 1/12 * (1-0)^2 = 0.0833 - - + INDArray expMean = Nd4j.valueArrayOf(new int[] { 1, 3 }, 0.5); + // Expected variance of U(0,1) distribution: 1/12 * (1-0)^2 = 0.0833 + INDArray expVar = Nd4j.valueArrayOf(new int[] { 1, 3 }, 1 / 12.0); for (int i = 0; i < 10; i++) { iter.reset(); net.fit(iter); } - INDArray estMean = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN); INDArray estVar; - if(useLogStd){ + if (useLogStd) { INDArray log10std = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_LOG_STD); estVar = Nd4j.valueArrayOf(log10std.shape(), 10.0).castTo(log10std.dataType()); - Transforms.pow(estVar, log10std, false); // stdev = 10^(log10(stdev)) + // stdev = 10^(log10(stdev)) + Transforms.pow(estVar, log10std, false); estVar.muli(estVar); } else { estVar = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_VAR); } - float[] fMeanExp = expMean.data().asFloat(); float[] fMeanAct = estMean.data().asFloat(); float[] fVarExp = expVar.data().asFloat(); float[] fVarAct = estVar.data().asFloat(); - - // System.out.println("Mean vs. 
estimated mean:"); - // System.out.println(Arrays.toString(fMeanExp)); - // System.out.println(Arrays.toString(fMeanAct)); - // - // System.out.println("Var vs. estimated var:"); - // System.out.println(Arrays.toString(fVarExp)); - // System.out.println(Arrays.toString(fVarAct)); - + // System.out.println("Mean vs. estimated mean:"); + // System.out.println(Arrays.toString(fMeanExp)); + // System.out.println(Arrays.toString(fMeanAct)); + // + // System.out.println("Var vs. estimated var:"); + // System.out.println(Arrays.toString(fVarExp)); + // System.out.println(Arrays.toString(fVarAct)); assertArrayEquals(fMeanExp, fMeanAct, 0.01f); assertArrayEquals(fVarExp, fVarAct, 0.01f); } } @Test - public void checkMeanVarianceEstimateCNNCompareModes() throws Exception { - + @DisplayName("Check Mean Variance Estimate CNN Compare Modes") + void checkMeanVarianceEstimateCNNCompareModes() throws Exception { Nd4j.getRandom().setSeed(12345); - //Check that the internal global mean/variance estimate is approximately correct - - //First, Mnist data as 2d input (NOT taking into account convolution property) - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345).list() - .layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).useLogStd(false).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER) - .activation(Activation.IDENTITY).nOut(10).build()) - .setInputType(InputType.convolutional(5, 5, 3)).build(); + // Check that the internal global mean/variance estimate is approximately correct + // First, Mnist data as 2d input (NOT taking into account convolution property) + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.RMSPROP).seed(12345).list().layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).useLogStd(false).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER).activation(Activation.IDENTITY).nOut(10).build()).setInputType(InputType.convolutional(5, 5, 3)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345).list() - .layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).useLogStd(true).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER) - .activation(Activation.IDENTITY).nOut(10).build()) - .setInputType(InputType.convolutional(5, 5, 3)).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.RMSPROP).seed(12345).list().layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).useLogStd(true).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER).activation(Activation.IDENTITY).nOut(10).build()).setInputType(InputType.convolutional(5, 5, 3)).build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - int minibatch = 32; for (int i = 0; i < 10; i++) { - DataSet ds = new DataSet(Nd4j.rand(new int[]{minibatch, 3, 5, 5}), 
Nd4j.rand(minibatch, 10)); + DataSet ds = new DataSet(Nd4j.rand(new int[] { minibatch, 3, 5, 5 }), Nd4j.rand(minibatch, 10)); net.fit(ds); net2.fit(ds); - INDArray globalVar = net.getParam("0_" + BatchNormalizationParamInitializer.GLOBAL_VAR); - INDArray log10std = net2.getParam("0_" + BatchNormalizationParamInitializer.GLOBAL_LOG_STD); INDArray globalVar2 = Nd4j.valueArrayOf(log10std.shape(), 10.0).castTo(log10std.dataType()); - Transforms.pow(globalVar2, log10std, false); // stdev = 10^(log10(stdev)) + // stdev = 10^(log10(stdev)) + Transforms.pow(globalVar2, log10std, false); globalVar2.muli(globalVar2); - assertEquals(globalVar, globalVar2); } } - @Test - public void testBatchNorm() throws Exception { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345) - .updater(new Adam(1e-3)) - .activation(Activation.TANH) - .list() - .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build()) - .layer(new BatchNormalization()) - .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build()) - .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); - + @DisplayName("Test Batch Norm") + void testBatchNorm() throws Exception { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(1e-3)).activation(Activation.TANH).list().layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build()).layer(new BatchNormalization()).layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build()).layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, true, 12345), 10); - net.fit(iter); - - MultiLayerNetwork net2 = new TransferLearning.Builder(net) - .fineTuneConfiguration(FineTuneConfiguration.builder() - .updater(new AdaDelta()) - .build()) - .removeOutputLayer() - .addLayer(new BatchNormalization.Builder().nOut(3380).build()) - .addLayer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(3380).nOut(10).build()) - .build(); - + MultiLayerNetwork net2 = new TransferLearning.Builder(net).fineTuneConfiguration(FineTuneConfiguration.builder().updater(new AdaDelta()).build()).removeOutputLayer().addLayer(new BatchNormalization.Builder().nOut(3380).build()).addLayer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(3380).nOut(10).build()).build(); net2.fit(iter); } @Test - public void testBatchNormRecurrentCnn1d() { - //Simple sanity check on CNN1D and RNN layers - - for (boolean rnn : new boolean[]{true, false}) { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345) - .weightInit(WeightInit.XAVIER) - .convolutionMode(ConvolutionMode.Same) - .list() - .layer(rnn ? 
new LSTM.Builder().nOut(3).build() : - new Convolution1DLayer.Builder().kernelSize(3).stride(1).nOut(3).build()) - .layer(new BatchNormalization()) - .layer(new RnnOutputLayer.Builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build()) - .setInputType(InputType.recurrent(3)) - .build(); - + @DisplayName("Test Batch Norm Recurrent Cnn 1 d") + void testBatchNormRecurrentCnn1d() { + // Simple sanity check on CNN1D and RNN layers + for (boolean rnn : new boolean[] { true, false }) { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).convolutionMode(ConvolutionMode.Same).list().layer(rnn ? new LSTM.Builder().nOut(3).build() : new Convolution1DLayer.Builder().kernelSize(3).stride(1).nOut(3).build()).layer(new BatchNormalization()).layer(new RnnOutputLayer.Builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build()).setInputType(InputType.recurrent(3)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - INDArray in = Nd4j.rand(new int[]{1, 3, 5}); - INDArray label = Nd4j.rand(new int[]{1, 3, 5}); - + INDArray in = Nd4j.rand(new int[] { 1, 3, 5 }); + INDArray label = Nd4j.rand(new int[] { 1, 3, 5 }); INDArray out = net.output(in); - assertArrayEquals(new long[]{1, 3, 5}, out.shape()); - + assertArrayEquals(new long[] { 1, 3, 5 }, out.shape()); net.fit(in, label); log.info("OK: {}", (rnn ? "rnn" : "cnn1d")); } } @Test - public void testInputValidation() { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new BatchNormalization.Builder().nIn(10).nOut(10).build()) - .build(); - + @DisplayName("Test Input Validation") + void testInputValidation() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new BatchNormalization.Builder().nIn(10).nOut(10).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in1 = Nd4j.create(1, 10); INDArray in2 = Nd4j.create(1, 5); - INDArray out1 = net.output(in1); try { INDArray out2 = net.output(in2); @@ -781,4 +568,4 @@ public class BatchNormalizationTest extends BaseDL4JTest { assertTrue(e.getMessage().contains("expected input")); } } -} \ No newline at end of file +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java index 41e0da315..7c7accec3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.normalization; import org.deeplearning4j.BaseDL4JTest; @@ -35,8 +34,8 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ 
-45,92 +44,47 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** - * */ -public class LocalResponseTest extends BaseDL4JTest { +@DisplayName("Local Response Test") +class LocalResponseTest extends BaseDL4JTest { - private INDArray x = Nd4j.create(new double[] {0.88128096, -0.96666986, -0.61832994, 0.26418415, 0.05694608, - 0.2950289, 0.99222249, 0.24541704, 0.4219842, 0.96430975, 0.19299535, -0.06658337, -0.27603117, - 0.24216647, 0.21834095, 0.03863283, -0.82313406, -0.37236378, -0.77667993, 0.66295379, -0.34406275, - -0.25924176, 0.26652309, -0.58964926, -0.46907067, 0.34666502, 0.81208313, -0.17042427, -0.22470538, - 0.8348338, 0.50494033, 0.45004508, 0.58735144, -0.87217808, -0.74788797, -0.04363599, 0.72276866, - 0.52476895, -0.52383977, 0.1311436, 0.2628099, 0.77274454, 0.86400729, -0.35246921, -0.03399619, - -0.502312, 0.42834607, 0.85534132, 0.90083021, 0.24571614, 0.63058525, -0.82919437, 0.57236177, - -0.0913529, -0.7102778, 0.81631756, -0.89004314, 0.43995622, -0.26112801, -0.76135367, 0.65180862, - -0.54667377, 0.94908774, 0.59298772, 0.36457643, 0.58892179, -0.52951556, 0.31559938, -0.55268252, - 0.8272332, 0.37911707, -0.96299696, -0.40717798, 0.43324658, 0.2589654, -0.15605508, 0.96334064, - -0.31666604, 0.19781154, 0.09908111, 0.64796048, -0.99037546, 0.67919868, 0.43810204}, - new int[] {2, 7, 3, 2}); + private INDArray x = Nd4j.create(new double[] { 0.88128096, -0.96666986, -0.61832994, 0.26418415, 0.05694608, 0.2950289, 0.99222249, 0.24541704, 0.4219842, 0.96430975, 0.19299535, -0.06658337, -0.27603117, 0.24216647, 0.21834095, 0.03863283, -0.82313406, -0.37236378, -0.77667993, 0.66295379, -0.34406275, -0.25924176, 0.26652309, -0.58964926, -0.46907067, 0.34666502, 0.81208313, -0.17042427, -0.22470538, 0.8348338, 0.50494033, 0.45004508, 0.58735144, -0.87217808, -0.74788797, -0.04363599, 0.72276866, 0.52476895, -0.52383977, 0.1311436, 0.2628099, 0.77274454, 0.86400729, -0.35246921, -0.03399619, -0.502312, 0.42834607, 0.85534132, 0.90083021, 0.24571614, 0.63058525, -0.82919437, 0.57236177, -0.0913529, -0.7102778, 0.81631756, -0.89004314, 0.43995622, -0.26112801, -0.76135367, 0.65180862, -0.54667377, 0.94908774, 0.59298772, 0.36457643, 0.58892179, -0.52951556, 0.31559938, -0.55268252, 0.8272332, 0.37911707, -0.96299696, -0.40717798, 0.43324658, 0.2589654, -0.15605508, 0.96334064, -0.31666604, 0.19781154, 0.09908111, 0.64796048, -0.99037546, 0.67919868, 0.43810204 }, new int[] { 2, 7, 3, 2 }); - private INDArray activationsExpected = Nd4j.create(new double[] {0.52397668, -0.57476264, -0.3676528, 0.15707894, - 0.03385943, 0.17542371, 0.58992499, 0.14591768, 0.25090647, 0.57335907, 0.11475233, -0.03958985, - -0.16411273, 0.14398433, 0.12981956, 0.02297027, -0.48942304, -0.22139823, -0.46177959, 0.39418164, - -0.20457059, -0.15413573, 0.15846729, -0.3505919, -0.27889356, 0.20611978, 0.48284137, -0.10133155, - -0.13360347, 0.49636194, 0.30022132, 0.26758799, 0.34922296, -0.51858318, -0.4446843, -0.02594452, - 0.42974478, 0.31202248, -0.31146204, 0.07797609, 0.15626372, 0.4594543, 0.51370209, -0.20957276, - -0.02021335, 
-0.29866382, 0.25469059, 0.50856382, 0.53558689, 0.14609739, 0.37491882, -0.49301448, - 0.34031925, -0.05431537, -0.42228988, 0.48536259, -0.52917528, 0.26157826, -0.15526266, -0.45265958, - 0.38753596, -0.32503816, 0.56427884, 0.35256693, 0.21676543, 0.35014921, -0.31483513, 0.18764766, - -0.32859638, 0.49183461, 0.22540972, -0.57255536, -0.24210122, 0.25760418, 0.15397197, -0.0927838, - 0.57277, -0.18827969, 0.1176173, 0.05891332, 0.38526815, -0.58884346, 0.40383074, 0.26048511}, - new int[] {2, 7, 3, 2}); + private INDArray activationsExpected = Nd4j.create(new double[] { 0.52397668, -0.57476264, -0.3676528, 0.15707894, 0.03385943, 0.17542371, 0.58992499, 0.14591768, 0.25090647, 0.57335907, 0.11475233, -0.03958985, -0.16411273, 0.14398433, 0.12981956, 0.02297027, -0.48942304, -0.22139823, -0.46177959, 0.39418164, -0.20457059, -0.15413573, 0.15846729, -0.3505919, -0.27889356, 0.20611978, 0.48284137, -0.10133155, -0.13360347, 0.49636194, 0.30022132, 0.26758799, 0.34922296, -0.51858318, -0.4446843, -0.02594452, 0.42974478, 0.31202248, -0.31146204, 0.07797609, 0.15626372, 0.4594543, 0.51370209, -0.20957276, -0.02021335, -0.29866382, 0.25469059, 0.50856382, 0.53558689, 0.14609739, 0.37491882, -0.49301448, 0.34031925, -0.05431537, -0.42228988, 0.48536259, -0.52917528, 0.26157826, -0.15526266, -0.45265958, 0.38753596, -0.32503816, 0.56427884, 0.35256693, 0.21676543, 0.35014921, -0.31483513, 0.18764766, -0.32859638, 0.49183461, 0.22540972, -0.57255536, -0.24210122, 0.25760418, 0.15397197, -0.0927838, 0.57277, -0.18827969, 0.1176173, 0.05891332, 0.38526815, -0.58884346, 0.40383074, 0.26048511 }, new int[] { 2, 7, 3, 2 }); - private INDArray epsilon = Nd4j.create(new double[] {-0.13515499, 0.96470547, -0.62253004, 0.80172491, -0.97510445, - -0.41198033, -0.4790071, 0.07551047, -0.01383764, -0.05797465, 0.21242172, 0.7145375, -0.17809176, - -0.11465316, -0.2066526, 0.21950938, 0.4627091, 0.30275798, 0.61443841, 0.75912178, -0.132248, - -0.82923287, 0.74962652, -0.88993639, 0.04406403, 0.32096064, -0.46400586, 0.1603231, 0.63007826, - 0.10626783, 0.08009516, 0.88297033, 0.11441587, 0.35862735, 0.40441504, -0.60132015, 0.87743825, - 0.09792926, 0.92742652, 0.6182847, -0.9602651, -0.19611064, 0.15762019, 0.00339905, -0.9238292, - 0.02451134, -0.44294646, -0.5450229, 0.87502575, -0.59481794, 0.65259099, -0.77772689, 0.53300053, - 0.11541174, 0.32667685, 0.99437004, -0.04084824, -0.45166185, 0.29513556, 0.53582036, 0.95541358, - -0.75714606, -0.63295805, -0.70315111, -0.6553846, -0.78824568, 0.84295344, -0.38352135, - -0.04541624, 0.17396702, 0.41530582, 0.11870354, 0.85787249, -0.94597596, 0.05792254, 0.04811822, - 0.04847952, -0.82953823, 0.8089835, 0.50185651, -0.88619858, -0.78598201, 0.27489874, 0.63673472}, - new int[] {2, 7, 3, 2}); + private INDArray epsilon = Nd4j.create(new double[] { -0.13515499, 0.96470547, -0.62253004, 0.80172491, -0.97510445, -0.41198033, -0.4790071, 0.07551047, -0.01383764, -0.05797465, 0.21242172, 0.7145375, -0.17809176, -0.11465316, -0.2066526, 0.21950938, 0.4627091, 0.30275798, 0.61443841, 0.75912178, -0.132248, -0.82923287, 0.74962652, -0.88993639, 0.04406403, 0.32096064, -0.46400586, 0.1603231, 0.63007826, 0.10626783, 0.08009516, 0.88297033, 0.11441587, 0.35862735, 0.40441504, -0.60132015, 0.87743825, 0.09792926, 0.92742652, 0.6182847, -0.9602651, -0.19611064, 0.15762019, 0.00339905, -0.9238292, 0.02451134, -0.44294646, -0.5450229, 0.87502575, -0.59481794, 0.65259099, -0.77772689, 0.53300053, 0.11541174, 0.32667685, 0.99437004, -0.04084824, -0.45166185, 0.29513556, 
0.53582036, 0.95541358, -0.75714606, -0.63295805, -0.70315111, -0.6553846, -0.78824568, 0.84295344, -0.38352135, -0.04541624, 0.17396702, 0.41530582, 0.11870354, 0.85787249, -0.94597596, 0.05792254, 0.04811822, 0.04847952, -0.82953823, 0.8089835, 0.50185651, -0.88619858, -0.78598201, 0.27489874, 0.63673472 }, new int[] { 2, 7, 3, 2 }); - private INDArray newEpsilonExpected = Nd4j.create(new double[] {-0.08033668, 0.57355404, -0.37014094, 0.47668865, - -0.57978398, -0.24495915, -0.28474802, 0.04490108, -0.00823483, -0.03448687, 0.12630466, 0.42485803, - -0.10589627, -0.06816553, -0.12287001, 0.13051508, 0.27510744, 0.18001786, 0.36528736, 0.45133191, - -0.07863599, -0.49303374, 0.44571424, -0.52912313, 0.02620371, 0.19082049, -0.27585581, 0.09532529, - 0.3746179, 0.06316902, 0.04761803, 0.52497554, 0.06804816, 0.21323238, 0.24044329, -0.35752413, - 0.52168733, 0.05821467, 0.55140609, 0.3676247, -0.57095432, -0.11660115, 0.09367896, 0.00202246, - -0.54928631, 0.01455687, -0.26336867, -0.3240425, 0.52023786, -0.35366109, 0.3879728, -0.46243483, - 0.31692421, 0.06862034, 0.19421607, 0.59124804, -0.0242459, -0.26852599, 0.17547797, 0.31857637, - 0.56804365, -0.45020312, -0.37634474, -0.41804832, -0.38966343, -0.4686695, 0.50119156, -0.22802454, - -0.02698562, 0.10343311, 0.24693431, 0.0706142, 0.5100745, -0.56245267, 0.03443092, 0.02860913, - 0.02883426, -0.49320197, 0.4810102, 0.29840365, -0.5269345, -0.46732581, 0.16344811, 0.37857518}, - new int[] {2, 7, 3, 2}); + private INDArray newEpsilonExpected = Nd4j.create(new double[] { -0.08033668, 0.57355404, -0.37014094, 0.47668865, -0.57978398, -0.24495915, -0.28474802, 0.04490108, -0.00823483, -0.03448687, 0.12630466, 0.42485803, -0.10589627, -0.06816553, -0.12287001, 0.13051508, 0.27510744, 0.18001786, 0.36528736, 0.45133191, -0.07863599, -0.49303374, 0.44571424, -0.52912313, 0.02620371, 0.19082049, -0.27585581, 0.09532529, 0.3746179, 0.06316902, 0.04761803, 0.52497554, 0.06804816, 0.21323238, 0.24044329, -0.35752413, 0.52168733, 0.05821467, 0.55140609, 0.3676247, -0.57095432, -0.11660115, 0.09367896, 0.00202246, -0.54928631, 0.01455687, -0.26336867, -0.3240425, 0.52023786, -0.35366109, 0.3879728, -0.46243483, 0.31692421, 0.06862034, 0.19421607, 0.59124804, -0.0242459, -0.26852599, 0.17547797, 0.31857637, 0.56804365, -0.45020312, -0.37634474, -0.41804832, -0.38966343, -0.4686695, 0.50119156, -0.22802454, -0.02698562, 0.10343311, 0.24693431, 0.0706142, 0.5100745, -0.56245267, 0.03443092, 0.02860913, 0.02883426, -0.49320197, 0.4810102, 0.29840365, -0.5269345, -0.46732581, 0.16344811, 0.37857518 }, new int[] { 2, 7, 3, 2 }); private INDArray activationsActual; + private Layer layer; - @Before - public void doBefore() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) - .layer(new LocalResponseNormalization.Builder().k(2).n(5).alpha(1e-4).beta(0.75).build()) - .build(); - + @BeforeEach + void doBefore() { + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123).layer(new LocalResponseNormalization.Builder().k(2).n(5).alpha(1e-4).beta(0.75).build()).build(); layer = new LocalResponseNormalization().instantiate(conf, null, 0, null, false, Nd4j.defaultFloatingPointType()); activationsActual = layer.activate(x, false, LayerWorkspaceMgr.noWorkspaces()); } @Test - public void testActivate() { + @DisplayName("Test Activate") + void testActivate() { // Precision is off 
from the expected results because expected results generated in numpy assertEquals(activationsExpected, activationsActual); assertArrayEquals(activationsExpected.shape(), activationsActual.shape()); } @Test - public void testBackpropGradient() { + @DisplayName("Test Backprop Gradient") + void testBackpropGradient() { Pair containedOutput = layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(newEpsilonExpected.getDouble(8), containedOutput.getSecond().getDouble(8), 1e-4); assertEquals(newEpsilonExpected.getDouble(20), containedOutput.getSecond().getDouble(20), 1e-4); assertEquals(null, containedOutput.getFirst().getGradientFor("W")); @@ -138,53 +92,35 @@ public class LocalResponseTest extends BaseDL4JTest { } @Test - public void testRegularization() { + @DisplayName("Test Regularization") + void testRegularization() { // Confirm a structure with regularization true will not throw an error - - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).l1(0.2) - .l2(0.1).seed(123) - .layer(new LocalResponseNormalization.Builder().k(2).n(5).alpha(1e-4).beta(0.75).build()) - .build(); + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).l1(0.2).l2(0.1).seed(123).layer(new LocalResponseNormalization.Builder().k(2).n(5).alpha(1e-4).beta(0.75).build()).build(); } @Test - public void testMultiCNNLayer() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list() - .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(1, new LocalResponseNormalization.Builder().build()).layer(2, - new DenseLayer.Builder() - .nOut(2).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(10) - .build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - + @DisplayName("Test Multi CNN Layer") + void testMultiCNNLayer() throws Exception { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list().layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new LocalResponseNormalization.Builder().build()).layer(2, new DenseLayer.Builder().nOut(2).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSet next = iter.next(); - network.fit(next); } - @Test - public void testLrnManual() { + @DisplayName("Test Lrn Manual") + void testLrnManual() { int wh = 5; int depth = 6; int minibatch = 3; - int n = 4; double k = 2.0; double alpha = 1e-4; double beta = 0.75; - - INDArray in = Nd4j.rand(new int[] {minibatch, depth, wh, wh}); + INDArray in = Nd4j.rand(new int[] { minibatch, depth, wh, wh }); INDArray outExp = Nd4j.zeros(minibatch, depth, wh, wh); - for (int m = 0; m < minibatch; m++) { for (int x = 0; x < wh; x++) { for (int y = 0; y < wh; y++) { @@ -202,16 +138,10 @@ public 
class LocalResponseTest extends BaseDL4JTest { } } } - LocalResponseNormalization lrn = new LocalResponseNormalization.Builder().build(); NeuralNetConfiguration nnc = new NeuralNetConfiguration.Builder().layer(lrn).build(); - org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization layer = - (org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization) lrn.instantiate(nnc, - null, 0, null, false, Nd4j.defaultFloatingPointType()); - + org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization layer = (org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization) lrn.instantiate(nnc, null, 0, null, false, Nd4j.defaultFloatingPointType()); INDArray outAct = layer.activate(in, true, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(outExp, outAct); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java index 4033112ae..4d0f9bc66 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.ocnn; import org.deeplearning4j.BaseDL4JTest; @@ -31,8 +30,8 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.util.ModelSerializer; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationReLU; import org.nd4j.linalg.activations.impl.ActivationSigmoid; @@ -48,118 +47,99 @@ import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.schedule.ScheduleType; import org.nd4j.linalg.schedule.StepSchedule; - import java.io.File; import java.util.UUID; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - - -public class OCNNOutputLayerTest extends BaseDL4JTest { +@DisplayName("Ocnn Output Layer Test") +class OCNNOutputLayerTest extends BaseDL4JTest { private static final boolean PRINT_RESULTS = true; + private static final boolean RETURN_ON_FIRST_FAILURE = false; + private static final double DEFAULT_EPS = 1e-6; + private static final double DEFAULT_MAX_REL_ERROR = 1e-3; + private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + + @TempDir + public Path testDir; + static { Nd4j.setDataType(DataType.DOUBLE); } - @Test - public void testLayer() { + @DisplayName("Test Layer") + void testLayer() { DataSetIterator dataSetIterator = getNormalizedIterator(); boolean doLearningFirst = true; MultiLayerNetwork network = getGradientCheckNetwork(2); - - DataSet ds = dataSetIterator.next(); INDArray arr = ds.getFeatures(); network.setInput(arr); - 
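// Illustrative aside, not part of the patch: a minimal, hedged sketch of the temp-directory
// idiom adopted in the field declarations above (JUnit 4 @Rule TemporaryFolder -> JUnit 5
// @TempDir on a Path field). The class, file name and contents are hypothetical and only
// demonstrate the pattern; they do not belong to any file touched by this diff.
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import static org.junit.jupiter.api.Assertions.assertTrue;

@DisplayName("Temp Dir Idiom Sketch")
class TempDirIdiomSketch {

    // JUnit 5 injects a fresh temporary directory per test into a Path (or File) field/parameter.
    @TempDir
    Path testDir;

    @Test
    @DisplayName("Writes Into The Injected Temp Dir")
    void writesIntoTempDir() throws Exception {
        // Replacement for the old testDir.newFolder() / new File(folder, name) pattern:
        File out = new File(testDir.toFile(), "temp.csv");
        FileUtils.writeStringToFile(out, "a,b,c", StandardCharsets.UTF_8);
        assertTrue(out.exists());
    }
}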
if (doLearningFirst) { - //Run a number of iterations of learning + // Run a number of iterations of learning network.setInput(arr); network.setListeners(new ScoreIterationListener(1)); network.computeGradientAndScore(); double scoreBefore = network.score(); - for (int j = 0; j < 10; j++) - network.fit(ds); + for (int j = 0; j < 10; j++) network.fit(ds); network.computeGradientAndScore(); double scoreAfter = network.score(); - //Can't test in 'characteristic mode of operation' if not learning - String msg = "testLayer() - score did not (sufficiently) decrease during learning - activationFn=" - + "relu" + ", lossFn=" + "ocnn" + ", " + "sigmoid" - + ", doLearningFirst=" + doLearningFirst + " (before=" + scoreBefore - + ", scoreAfter=" + scoreAfter + ")"; - // assertTrue(msg, scoreAfter < scoreBefore); + // Can't test in 'characteristic mode of operation' if not learning + String msg = "testLayer() - score did not (sufficiently) decrease during learning - activationFn=" + "relu" + ", lossFn=" + "ocnn" + ", " + "sigmoid" + ", doLearningFirst=" + doLearningFirst + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")"; + // assertTrue(msg, scoreAfter < scoreBefore); } - if (PRINT_RESULTS) { - System.out.println("testLayer() - activationFn=" + "relu" + ", lossFn=" - + "ocnn" + "sigmoid" + ", doLearningFirst=" - + doLearningFirst); - for (int j = 0; j < network.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + network.getLayer(j).numParams()); + System.out.println("testLayer() - activationFn=" + "relu" + ", lossFn=" + "ocnn" + "sigmoid" + ", doLearningFirst=" + doLearningFirst); + for (int j = 0; j < network.getnLayers(); j++) System.out.println("Layer " + j + " # params: " + network.getLayer(j).numParams()); } - - boolean gradOK = GradientCheckUtil.checkGradients(network, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, ds.getFeatures(), ds.getLabels()); - - String msg = "testLayer() - activationFn=" + "relu" + ", lossFn=" + "ocnn" - + ",=" + "sigmoid" + ", doLearningFirst=" + doLearningFirst; - assertTrue(msg, gradOK); - - - + boolean gradOK = GradientCheckUtil.checkGradients(network, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, ds.getFeatures(), ds.getLabels()); + String msg = "testLayer() - activationFn=" + "relu" + ", lossFn=" + "ocnn" + ",=" + "sigmoid" + ", doLearningFirst=" + doLearningFirst; + assertTrue(gradOK,msg); } - @Test - public void testLabelProbabilities() throws Exception { + @DisplayName("Test Label Probabilities") + void testLabelProbabilities() throws Exception { Nd4j.getRandom().setSeed(42); DataSetIterator dataSetIterator = getNormalizedIterator(); MultiLayerNetwork network = getSingleLayer(); DataSet next = dataSetIterator.next(); - DataSet filtered = next.filterBy(new int[]{0, 1}); + DataSet filtered = next.filterBy(new int[] { 0, 1 }); for (int i = 0; i < 10; i++) { network.setEpochCount(i); network.getLayerWiseConfigurations().setEpochCount(i); network.fit(filtered); } - - DataSet anomalies = next.filterBy(new int[] {2}); + DataSet anomalies = next.filterBy(new int[] { 2 }); INDArray output = network.output(anomalies.getFeatures()); - INDArray normalOutput = network.output(anomalies.getFeatures(),false); - assertEquals(output.lt(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue(), - normalOutput.eq(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue(),1e-1); - -// System.out.println("Labels " + 
anomalies.getLabels()); -// System.out.println("Anomaly output " + normalOutput); -// System.out.println(output); - + INDArray normalOutput = network.output(anomalies.getFeatures(), false); + assertEquals(output.lt(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue(), normalOutput.eq(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue(), 1e-1); + // System.out.println("Labels " + anomalies.getLabels()); + // System.out.println("Anomaly output " + normalOutput); + // System.out.println(output); INDArray normalProbs = network.output(filtered.getFeatures()); - INDArray outputForNormalSamples = network.output(filtered.getFeatures(),false); + INDArray outputForNormalSamples = network.output(filtered.getFeatures(), false); System.out.println("Normal probabilities " + normalProbs); System.out.println("Normal raw output " + outputForNormalSamples); - - File tmpFile = new File(testDir.getRoot(),"tmp-file-" + UUID.randomUUID().toString()); - ModelSerializer.writeModel(network,tmpFile,true); + File tmpFile = new File(testDir.toFile(), "tmp-file-" + UUID.randomUUID().toString()); + ModelSerializer.writeModel(network, tmpFile, true); tmpFile.deleteOnExit(); - MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(tmpFile); - assertEquals(network.params(),multiLayerNetwork.params()); - assertEquals(network.numParams(),multiLayerNetwork.numParams()); - + assertEquals(network.params(), multiLayerNetwork.params()); + assertEquals(network.numParams(), multiLayerNetwork.numParams()); } - public DataSetIterator getNormalizedIterator() { - DataSetIterator dataSetIterator = new IrisDataSetIterator(150,150); + DataSetIterator dataSetIterator = new IrisDataSetIterator(150, 150); NormalizerStandardize normalizerStandardize = new NormalizerStandardize(); normalizerStandardize.fit(dataSetIterator); dataSetIterator.reset(); @@ -169,42 +149,15 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { private MultiLayerNetwork getSingleLayer() { int numHidden = 2; - - MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() - .seed(12345) - .weightInit(WeightInit.XAVIER) - .miniBatch(true) - .updater(new Adam(0.1)) -// .updater(Nesterovs.builder() -// .momentum(0.1) -// .learningRateSchedule(new StepSchedule( -// ScheduleType.EPOCH, -// 1e-2, -// 0.1, -// 20)).build()) - .list(new DenseLayer.Builder().activation(new ActivationReLU()) - .nIn(4).nOut(2).build(), - new org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer.Builder() - .nIn(2).activation(new ActivationSigmoid()).initialRValue(0.1) - .nu(0.1) - .hiddenLayerSize(numHidden).build()) - .build(); + MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).miniBatch(true).updater(new Adam(0.1)).list(new DenseLayer.Builder().activation(new ActivationReLU()).nIn(4).nOut(2).build(), new org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer.Builder().nIn(2).activation(new ActivationSigmoid()).initialRValue(0.1).nu(0.1).hiddenLayerSize(numHidden).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(configuration); network.init(); network.setListeners(new ScoreIterationListener(1)); return network; } - public MultiLayerNetwork getGradientCheckNetwork(int numHidden) { - MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .seed(42).updater(new NoOp()).miniBatch(false) - .list(new DenseLayer.Builder().activation(new ActivationIdentity()).nIn(4).nOut(4).build(), - new 
org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer.Builder().nIn(4) - .nu(0.002).activation(new ActivationSigmoid()) - .hiddenLayerSize(numHidden).build()) - .build(); + MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).seed(42).updater(new NoOp()).miniBatch(false).list(new DenseLayer.Builder().activation(new ActivationIdentity()).nIn(4).nOut(4).build(), new org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer.Builder().nIn(4).nu(0.002).activation(new ActivationSigmoid()).hiddenLayerSize(numHidden).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(configuration); network.init(); return network; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java index aa16f53ff..d8a95c452 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.recurrent; import lombok.extern.slf4j.Slf4j; @@ -45,7 +44,7 @@ import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.TimeSeriesUtils; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -60,111 +59,78 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; - import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; - import static org.deeplearning4j.nn.conf.RNNFormat.NCW; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j @RunWith(Parameterized.class) -public class BidirectionalTest extends BaseDL4JTest { +@DisplayName("Bidirectional Test") +class BidirectionalTest extends BaseDL4JTest { private RNNFormat rnnDataFormat; - public BidirectionalTest(RNNFormat rnnDataFormat){ + public BidirectionalTest(RNNFormat rnnDataFormat) { this.rnnDataFormat = rnnDataFormat; } + @Parameterized.Parameters - public static Object[] params(){ + public static Object[] params() { return RNNFormat.values(); } + @Test - public void compareImplementations(){ - for(WorkspaceMode wsm : WorkspaceMode.values()) { + @DisplayName("Compare Implementations") + void compareImplementations() { + for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); - - //Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params - //Note that GravesBidirectionalLSTM implements ADD mode only - - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .trainingWorkspaceMode(wsm) - .inferenceWorkspaceMode(wsm) - .updater(new Adam()) - .list() - .layer(new 
Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) - .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat) - .nIn(10).nOut(10).build()) - .build(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .trainingWorkspaceMode(wsm) - .inferenceWorkspaceMode(wsm) - .updater(new Adam()) - .list() - .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) - .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) - .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat) - .nIn(10).nOut(10).build()) - .build(); - + // Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params + // Note that GravesBidirectionalLSTM implements ADD mode only + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).weightInit(WeightInit.XAVIER).trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm).updater(new Adam()).list().layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())).layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())).layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat).nIn(10).nOut(10).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).weightInit(WeightInit.XAVIER).trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm).updater(new Adam()).list().layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()).layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()).layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat).nIn(10).nOut(10).build()).build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); - MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - assertEquals(net1.numParams(), net2.numParams()); for (int i = 0; i < 3; i++) { - int n1 = (int)net1.getLayer(i).numParams(); - int n2 = (int)net2.getLayer(i).numParams(); + int n1 = (int) net1.getLayer(i).numParams(); + int n2 = (int) net2.getLayer(i).numParams(); assertEquals(n1, n2); } - - net2.setParams(net1.params()); //Assuming exact same layout here... - + // Assuming exact same layout here... 
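// Illustrative aside, not part of the patch: a minimal, hedged sketch of the JUnit 5 assertion
// argument order these hunks switch to. org.junit.Assert (JUnit 4) takes the optional failure
// message as the first argument; org.junit.jupiter.api.Assertions (JUnit 5) takes it last, which
// is why calls such as assertTrue(msg, gradOK) and assertEquals(m.toString(), outExp, out1) are
// rewritten as assertTrue(gradOK, msg) and assertEquals(outExp, out1, m.toString()) elsewhere in
// this diff. The class and values below are hypothetical and only demonstrate the pattern.
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Assertion Argument Order Sketch")
class AssertionArgumentOrderSketch {

    @Test
    @DisplayName("Message Comes Last In JUnit 5")
    void messageComesLast() {
        boolean gradOK = true;
        String msg = "hypothetical failure message";
        // JUnit 4 form was assertTrue(msg, gradOK); JUnit 5 puts the condition first:
        assertTrue(gradOK, msg);
        // JUnit 4 form was assertEquals(msg, 1.0, 1.0, 1e-6); JUnit 5 appends the message:
        assertEquals(1.0, 1.0, 1e-6, msg);
    }
}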
+ net2.setParams(net1.params()); INDArray in; - if (rnnDataFormat == NCW){ - in = Nd4j.rand(new int[]{3, 10, 5}); - }else{ - in = Nd4j.rand(new int[]{3, 5, 10}); + if (rnnDataFormat == NCW) { + in = Nd4j.rand(new int[] { 3, 10, 5 }); + } else { + in = Nd4j.rand(new int[] { 3, 5, 10 }); } - INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); - assertEquals(out1, out2); - INDArray labels; - if (rnnDataFormat == NCW){ - labels = Nd4j.rand(new int[]{3, 10, 5}); - }else{ - labels = Nd4j.rand(new int[]{3, 5, 10}); + if (rnnDataFormat == NCW) { + labels = Nd4j.rand(new int[] { 3, 10, 5 }); + } else { + labels = Nd4j.rand(new int[] { 3, 5, 10 }); } net1.setInput(in); net1.setLabels(labels); - net2.setInput(in); net2.setLabels(labels); - net1.computeGradientAndScore(); net2.computeGradientAndScore(); - - //Ensure scores are equal: + // Ensure scores are equal: assertEquals(net1.score(), net2.score(), 1e-6); - - //Ensure gradients are equal: + // Ensure gradients are equal: Gradient g1 = net1.gradient(); Gradient g2 = net2.gradient(); assertEquals(g1.gradient(), g2.gradient()); - - //Ensure updates are equal: + // Ensure updates are equal: MultiLayerUpdater u1 = (MultiLayerUpdater) net1.getUpdater(); MultiLayerUpdater u2 = (MultiLayerUpdater) net2.getUpdater(); assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray()); @@ -172,11 +138,9 @@ public class BidirectionalTest extends BaseDL4JTest { u2.update(net2, g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); assertEquals(g1.gradient(), g2.gradient()); assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray()); - - //Ensure params are equal, after fitting + // Ensure params are equal, after fitting net1.fit(in, labels); net2.fit(in, labels); - INDArray p1 = net1.params(); INDArray p2 = net2.params(); assertEquals(p1, p2); @@ -184,86 +148,45 @@ public class BidirectionalTest extends BaseDL4JTest { } @Test - public void compareImplementationsCompGraph(){ -// for(WorkspaceMode wsm : WorkspaceMode.values()) { - for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) { + @DisplayName("Compare Implementations Comp Graph") + void compareImplementationsCompGraph() { + // for(WorkspaceMode wsm : WorkspaceMode.values()) { + for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) { log.info("*** Starting workspace mode: " + wsm); - - //Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params - //Note that GravesBidirectionalLSTM implements ADD mode only - - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .updater(new Adam()) - .trainingWorkspaceMode(wsm) - .inferenceWorkspaceMode(wsm) - .graphBuilder() - .addInputs("in") - .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in") - .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0") - .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) - .nIn(10).nOut(10).build(), "1") - .setOutputs("2") - .build(); - - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .updater(new Adam()) - .trainingWorkspaceMode(wsm) - .inferenceWorkspaceMode(wsm) - .graphBuilder() - .addInputs("in") - .layer("0", new 
GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build(), "in") - .layer("1", new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build(), "0") - .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) - .nIn(10).nOut(10).build(), "1") - .setOutputs("2") - .build(); - + // Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params + // Note that GravesBidirectionalLSTM implements ADD mode only + ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).weightInit(WeightInit.XAVIER).updater(new Adam()).trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm).graphBuilder().addInputs("in").layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in").layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0").layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build(), "1").setOutputs("2").build(); + ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).weightInit(WeightInit.XAVIER).updater(new Adam()).trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm).graphBuilder().addInputs("in").layer("0", new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build(), "in").layer("1", new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build(), "0").layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build(), "1").setOutputs("2").build(); ComputationGraph net1 = new ComputationGraph(conf1); net1.init(); - ComputationGraph net2 = new ComputationGraph(conf2); net2.init(); - assertEquals(net1.numParams(), net2.numParams()); for (int i = 0; i < 3; i++) { - int n1 = (int)net1.getLayer(i).numParams(); - int n2 = (int)net2.getLayer(i).numParams(); + int n1 = (int) net1.getLayer(i).numParams(); + int n2 = (int) net2.getLayer(i).numParams(); assertEquals(n1, n2); } - - net2.setParams(net1.params()); //Assuming exact same layout here... - - INDArray in = Nd4j.rand(new int[]{3, 10, 5}); - + // Assuming exact same layout here... 
+ net2.setParams(net1.params()); + INDArray in = Nd4j.rand(new int[] { 3, 10, 5 }); INDArray out1 = net1.outputSingle(in); INDArray out2 = net2.outputSingle(in); - assertEquals(out1, out2); - - INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); - - net1.setInput(0,in); + INDArray labels = Nd4j.rand(new int[] { 3, 10, 5 }); + net1.setInput(0, in); net1.setLabels(labels); - - net2.setInput(0,in); + net2.setInput(0, in); net2.setLabels(labels); - net1.computeGradientAndScore(); net2.computeGradientAndScore(); - - //Ensure scores are equal: + // Ensure scores are equal: assertEquals(net1.score(), net2.score(), 1e-6); - - //Ensure gradients are equal: + // Ensure gradients are equal: Gradient g1 = net1.gradient(); Gradient g2 = net2.gradient(); assertEquals(g1.gradient(), g2.gradient()); - - //Ensure updates are equal: + // Ensure updates are equal: ComputationGraphUpdater u1 = (ComputationGraphUpdater) net1.getUpdater(); ComputationGraphUpdater u2 = (ComputationGraphUpdater) net2.getUpdater(); assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray()); @@ -271,203 +194,117 @@ public class BidirectionalTest extends BaseDL4JTest { u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); assertEquals(g1.gradient(), g2.gradient()); assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray()); - - //Ensure params are equal, after fitting + // Ensure params are equal, after fitting net1.fit(new DataSet(in, labels)); net2.fit(new DataSet(in, labels)); - INDArray p1 = net1.params(); INDArray p2 = net2.params(); assertEquals(p1, p2); } } - @Test - public void testSerialization() throws Exception { - - for(WorkspaceMode wsm : WorkspaceMode.values()) { + @DisplayName("Test Serialization") + void testSerialization() throws Exception { + for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); - Nd4j.getRandom().setSeed(12345); - - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .trainingWorkspaceMode(wsm) - .inferenceWorkspaceMode(wsm) - .updater(new Adam()) - .list() - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) - .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) - .nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) - .build(); - + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).weightInit(WeightInit.XAVIER).trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm).updater(new Adam()).list().layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())).layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())).layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build()).build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); - INDArray in; INDArray labels; - - long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 5} : new long[]{3, 5, 10}; - + long[] inshape = rnnDataFormat == NCW ? 
new long[] { 3, 10, 5 } : new long[] { 3, 5, 10 }; in = Nd4j.rand(inshape); labels = Nd4j.rand(inshape); - net1.fit(in, labels); - byte[] bytes; try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { ModelSerializer.writeModel(net1, baos, true); bytes = baos.toByteArray(); } - - MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true); - - in = Nd4j.rand(inshape); labels = Nd4j.rand(inshape); - INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); - assertEquals(out1, out2); - net1.setInput(in); net2.setInput(in); net1.setLabels(labels); net2.setLabels(labels); - net1.computeGradientAndScore(); net2.computeGradientAndScore(); - assertEquals(net1.score(), net2.score(), 1e-6); assertEquals(net1.gradient().gradient(), net2.gradient().gradient()); } } - @Test - public void testSerializationCompGraph() throws Exception { - - for(WorkspaceMode wsm : WorkspaceMode.values()) { + @DisplayName("Test Serialization Comp Graph") + void testSerializationCompGraph() throws Exception { + for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); - Nd4j.getRandom().setSeed(12345); - - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .trainingWorkspaceMode(wsm) - .inferenceWorkspaceMode(wsm) - .updater(new Adam()) - .graphBuilder() - .addInputs("in") - .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in") - .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0") - .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat) - .nIn(10).nOut(10).build(), "1") - .setOutputs("2") - .build(); - + ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).weightInit(WeightInit.XAVIER).trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm).updater(new Adam()).graphBuilder().addInputs("in").layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in").layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0").layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat).nIn(10).nOut(10).build(), "1").setOutputs("2").build(); ComputationGraph net1 = new ComputationGraph(conf1); net1.init(); - long[] inshape = (rnnDataFormat == NCW)? new long[]{3, 10, 5}: new long[]{3, 5, 10}; + long[] inshape = (rnnDataFormat == NCW) ? 
new long[] { 3, 10, 5 } : new long[] { 3, 5, 10 }; INDArray in = Nd4j.rand(inshape); INDArray labels = Nd4j.rand(inshape); - net1.fit(new DataSet(in, labels)); - byte[] bytes; try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { ModelSerializer.writeModel(net1, baos, true); bytes = baos.toByteArray(); } - - ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true); - - in = Nd4j.rand(inshape); labels = Nd4j.rand(inshape); - INDArray out1 = net1.outputSingle(in); INDArray out2 = net2.outputSingle(in); - assertEquals(out1, out2); - net1.setInput(0, in); net2.setInput(0, in); net1.setLabels(labels); net2.setLabels(labels); - net1.computeGradientAndScore(); net2.computeGradientAndScore(); - assertEquals(net1.score(), net2.score(), 1e-6); assertEquals(net1.gradient().gradient(), net2.gradient().gradient()); } } @Test - public void testSimpleBidirectional() { - + @DisplayName("Test Simple Bidirectional") + void testSimpleBidirectional() { for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); Nd4j.getRandom().setSeed(12345); - - Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD, - Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL}; - - long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10}; + Bidirectional.Mode[] modes = new Bidirectional.Mode[] { Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD, Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL }; + long[] inshape = rnnDataFormat == NCW ? new long[] { 3, 10, 6 } : new long[] { 3, 6, 10 }; INDArray in = Nd4j.rand(inshape); - for (Bidirectional.Mode m : modes) { - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .trainingWorkspaceMode(wsm) - .inferenceWorkspaceMode(wsm) - .updater(new Adam()) - .list() - .layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) - .build(); - + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).weightInit(WeightInit.XAVIER).trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm).updater(new Adam()).list().layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())).build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .updater(new Adam()) - .list() - .layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) - .build(); - + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).weightInit(WeightInit.XAVIER).updater(new Adam()).list().layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()).build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone()); net2.init(); MultiLayerNetwork net3 = new MultiLayerNetwork(conf2.clone()); net3.init(); - net2.setParam("0_W", net1.getParam("0_fW")); net2.setParam("0_RW", net1.getParam("0_fRW")); net2.setParam("0_b", net1.getParam("0_fb")); - net3.setParam("0_W", net1.getParam("0_bW")); net3.setParam("0_RW", net1.getParam("0_bRW")); net3.setParam("0_b", net1.getParam("0_bb")); - INDArray inReverse = 
TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat); INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat); - INDArray outExp; - switch (m) { + switch(m) { case ADD: outExp = out2.add(out3); break; @@ -478,139 +315,90 @@ public class BidirectionalTest extends BaseDL4JTest { outExp = out2.add(out3).muli(0.5); break; case CONCAT: - outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3); + outExp = Nd4j.concat((rnnDataFormat == NCW) ? 1 : 2, out2, out3); break; default: throw new RuntimeException(); } - - assertEquals(m.toString(), outExp, out1); - - - //Check gradients: + assertEquals(outExp, out1,m.toString()); + // Check gradients: if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) { - INDArray eps = Nd4j.rand(inshape); - INDArray eps1; if (m == Bidirectional.Mode.CONCAT) { - eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps); + eps1 = Nd4j.concat((rnnDataFormat == NCW) ? 1 : 2, eps, eps); } else { eps1 = eps; } - net1.setInput(in); net2.setInput(in); net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat)); net1.feedForward(true, false); net2.feedForward(true, false); net3.feedForward(true, false); - Pair p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces()); Pair p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); Pair p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat), LayerWorkspaceMgr.noWorkspaces()); Gradient g1 = p1.getFirst(); Gradient g2 = p2.getFirst(); Gradient g3 = p3.getFirst(); - - for (boolean updates : new boolean[]{false, true}) { + for (boolean updates : new boolean[] { false, true }) { if (updates) { net1.getUpdater().update(net1, g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); net2.getUpdater().update(net2, g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); net3.getUpdater().update(net3, g3, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); } - assertEquals(g2.gradientForVariable().get("0_W"), g1.gradientForVariable().get("0_fW")); assertEquals(g2.gradientForVariable().get("0_RW"), g1.gradientForVariable().get("0_fRW")); assertEquals(g2.gradientForVariable().get("0_b"), g1.gradientForVariable().get("0_fb")); - assertEquals(g3.gradientForVariable().get("0_W"), g1.gradientForVariable().get("0_bW")); assertEquals(g3.gradientForVariable().get("0_RW"), g1.gradientForVariable().get("0_bRW")); assertEquals(g3.gradientForVariable().get("0_b"), g1.gradientForVariable().get("0_bb")); } - } } } } - @Test - public void testSimpleBidirectionalCompGraph() { - + @DisplayName("Test Simple Bidirectional Comp Graph") + void testSimpleBidirectionalCompGraph() { for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); Nd4j.getRandom().setSeed(12345); - - Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD, - Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL}; - - - long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10}; + Bidirectional.Mode[] modes = new Bidirectional.Mode[] { Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD, Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL }; + long[] inshape = rnnDataFormat == NCW ? 
new long[] { 3, 10, 6 } : new long[] { 3, 6, 10 }; INDArray in = Nd4j.rand(inshape); - - for (Bidirectional.Mode m : modes) { - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .trainingWorkspaceMode(wsm) - .inferenceWorkspaceMode(wsm) - .updater(new Adam()) - .graphBuilder() - .addInputs("in") - .layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in") - .setOutputs("0") - .build(); - + ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).weightInit(WeightInit.XAVIER).trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm).updater(new Adam()).graphBuilder().addInputs("in").layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in").setOutputs("0").build(); ComputationGraph net1 = new ComputationGraph(conf1); net1.init(); - - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .updater(new Adam()) - .graphBuilder() - .addInputs("in") - .layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in") - .setOutputs("0") - .build(); - + ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).weightInit(WeightInit.XAVIER).updater(new Adam()).graphBuilder().addInputs("in").layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in").setOutputs("0").build(); ComputationGraph net2 = new ComputationGraph(conf2.clone()); net2.init(); ComputationGraph net3 = new ComputationGraph(conf2.clone()); net3.init(); - net2.setParam("0_W", net1.getParam("0_fW")); net2.setParam("0_RW", net1.getParam("0_fRW")); net2.setParam("0_b", net1.getParam("0_fb")); - net3.setParam("0_W", net1.getParam("0_bW")); net3.setParam("0_RW", net1.getParam("0_bRW")); net3.setParam("0_b", net1.getParam("0_bb")); - - INDArray out1 = net1.outputSingle(in); INDArray out2 = net2.outputSingle(in); INDArray out3; INDArray inReverse; - if (rnnDataFormat == RNNFormat.NWC){ + if (rnnDataFormat == RNNFormat.NWC) { inReverse = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1); out3 = net3.outputSingle(inReverse); out3 = TimeSeriesUtils.reverseTimeSeries(out3.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1); - - } - else{ + } else { inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); out3 = net3.outputSingle(inReverse); out3 = TimeSeriesUtils.reverseTimeSeries(out3, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); - } - INDArray outExp; - switch (m) { + switch(m) { case ADD: outExp = out2.add(out3); break; @@ -623,50 +411,37 @@ public class BidirectionalTest extends BaseDL4JTest { case CONCAT: System.out.println(out2.shapeInfoToString()); System.out.println(out3.shapeInfoToString()); - outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3); + outExp = Nd4j.concat((rnnDataFormat == NCW) ? 
1 : 2, out2, out3); break; default: throw new RuntimeException(); } - - assertEquals(m.toString(), outExp, out1); - - - //Check gradients: + assertEquals(outExp, out1,m.toString()); + // Check gradients: if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) { - INDArray eps = Nd4j.rand(inshape); - INDArray eps1; if (m == Bidirectional.Mode.CONCAT) { - eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps); + eps1 = Nd4j.concat((rnnDataFormat == NCW) ? 1 : 2, eps, eps); } else { eps1 = eps; } - - INDArray epsReversed = (rnnDataFormat == NCW)? - TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT): - TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT) - .permute(0, 2, 1); + INDArray epsReversed = (rnnDataFormat == NCW) ? TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT) : TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1); net1.outputSingle(true, false, in); net2.outputSingle(true, false, in); net3.outputSingle(true, false, inReverse); - Gradient g1 = net1.backpropGradient(eps1); Gradient g2 = net2.backpropGradient(eps); Gradient g3 = net3.backpropGradient(epsReversed); - - for (boolean updates : new boolean[]{false, true}) { + for (boolean updates : new boolean[] { false, true }) { if (updates) { net1.getUpdater().update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); net2.getUpdater().update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); net3.getUpdater().update(g3, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); } - assertEquals(g2.gradientForVariable().get("0_W"), g1.gradientForVariable().get("0_fW")); assertEquals(g2.gradientForVariable().get("0_RW"), g1.gradientForVariable().get("0_fRW")); assertEquals(g2.gradientForVariable().get("0_b"), g1.gradientForVariable().get("0_fb")); - assertEquals(g3.gradientForVariable().get("0_W"), g1.gradientForVariable().get("0_bW")); assertEquals(g3.gradientForVariable().get("0_RW"), g1.gradientForVariable().get("0_bRW")); assertEquals(g3.gradientForVariable().get("0_b"), g1.gradientForVariable().get("0_bb")); @@ -676,47 +451,17 @@ public class BidirectionalTest extends BaseDL4JTest { } } - @Test - public void testIssue5472(){ - //https://github.com/deeplearning4j/deeplearning4j/issues/5472 - + @DisplayName("Test Issue 5472") + void testIssue5472() { + // https://github.com/deeplearning4j/deeplearning4j/issues/5472 int in = 2; int out = 2; - ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder() - .updater(new Adam(0.01)) - .activation(Activation.RELU) - .graphBuilder() - .addInputs("IN") - .setInputTypes(InputType.recurrent(in)) - .addLayer("AUTOENCODER", - new VariationalAutoencoder.Builder() - .encoderLayerSizes(64) - .decoderLayerSizes(64) - .nOut(7) - .pzxActivationFunction(Activation.IDENTITY) - .reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID.getActivationFunction())).build(), - "IN") - .addLayer("RNN", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nOut(128).build()), "AUTOENCODER") - .addLayer("OUT", new RnnOutputLayer.Builder() - .nOut(out) - .activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "RNN") - .setOutputs("OUT") - - ; - + ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder().updater(new 
Adam(0.01)).activation(Activation.RELU).graphBuilder().addInputs("IN").setInputTypes(InputType.recurrent(in)).addLayer("AUTOENCODER", new VariationalAutoencoder.Builder().encoderLayerSizes(64).decoderLayerSizes(64).nOut(7).pzxActivationFunction(Activation.IDENTITY).reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID.getActivationFunction())).build(), "IN").addLayer("RNN", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nOut(128).build()), "AUTOENCODER").addLayer("OUT", new RnnOutputLayer.Builder().nOut(out).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "RNN").setOutputs("OUT"); ComputationGraph net = new ComputationGraph(builder.build()); net.init(); - - MultiDataSetIterator iterator = new SingletonMultiDataSetIterator(new MultiDataSet(Nd4j.create(10,in,5), Nd4j.create(10,out,5))); - - EarlyStoppingConfiguration.Builder b = new EarlyStoppingConfiguration.Builder<>() - .epochTerminationConditions(new MaxEpochsTerminationCondition(10)) - .scoreCalculator(new DataSetLossCalculator(iterator, true)) - .evaluateEveryNEpochs(1) - .modelSaver(new InMemoryModelSaver<>()); - + MultiDataSetIterator iterator = new SingletonMultiDataSetIterator(new MultiDataSet(Nd4j.create(10, in, 5), Nd4j.create(10, out, 5))); + EarlyStoppingConfiguration.Builder b = new EarlyStoppingConfiguration.Builder<>().epochTerminationConditions(new MaxEpochsTerminationCondition(10)).scoreCalculator(new DataSetLossCalculator(iterator, true)).evaluateEveryNEpochs(1).modelSaver(new InMemoryModelSaver<>()); EarlyStoppingGraphTrainer earlyStoppingGraphTrainer = new EarlyStoppingGraphTrainer(b.build(), net, iterator, null); earlyStoppingGraphTrainer.fit(); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index e61623f99..41b91b65a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.recurrent; import junit.framework.TestCase; @@ -35,7 +34,7 @@ import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; @@ -46,197 +45,146 @@ import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.primitives.Pair; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @RunWith(Parameterized.class) -public class GravesBidirectionalLSTMTest extends BaseDL4JTest { +@DisplayName("Graves Bidirectional LSTM Test") +class GravesBidirectionalLSTMTest extends BaseDL4JTest { + private 
double score = 0.0; + private RNNFormat rnnDataFormat; - public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){ + public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat) { this.rnnDataFormat = rnnDataFormat; } + @Parameterized.Parameters - public static Object[] params(){ + public static Object[] params() { return RNNFormat.values(); } + @Test - public void testBidirectionalLSTMGravesForwardBasic() { - //Very basic test of forward prop. of LSTM layer with a time series. - //Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. + @DisplayName("Test Bidirectional LSTM Graves Forward Basic") + void testBidirectionalLSTMGravesForwardBasic() { + // Very basic test of forward prop. of LSTM layer with a time series. + // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. int nIn = 13; int nHiddenUnits = 17; - - final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build()) - .build(); - + final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build(); val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - final GravesBidirectionalLSTM layer = - (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); - - //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; - //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; - if (rnnDataFormat == RNNFormat.NCW){ + final GravesBidirectionalLSTM layer = (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + // Data: has shape [miniBatchSize,nIn,timeSeriesLength]; + // Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; + if (rnnDataFormat == RNNFormat.NCW) { final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1); final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1}); - + assertArrayEquals(activations1.shape(), new long[] { 1, nHiddenUnits, 1 }); final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1); final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1}); - + assertArrayEquals(activations2.shape(), new long[] { 10, nHiddenUnits, 1 }); final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12); final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12}); - + assertArrayEquals(activations3.shape(), new long[] { 1, nHiddenUnits, 12 }); final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15); final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15}); - } - else{ + assertArrayEquals(activations4.shape(), new long[] { 10, nHiddenUnits, 15 }); + } 
else { final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, 1, nIn); final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations1.shape(), new long[] {1, 1, nHiddenUnits}); - + assertArrayEquals(activations1.shape(), new long[] { 1, 1, nHiddenUnits }); final INDArray dataMultiExampleLength1 = Nd4j.ones(10, 1, nIn); final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations2.shape(), new long[] {10, 1, nHiddenUnits}); - + assertArrayEquals(activations2.shape(), new long[] { 10, 1, nHiddenUnits }); final INDArray dataSingleExampleLength12 = Nd4j.ones(1, 12, nIn); final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations3.shape(), new long[] {1, 12, nHiddenUnits}); - + assertArrayEquals(activations3.shape(), new long[] { 1, 12, nHiddenUnits }); final INDArray dataMultiExampleLength15 = Nd4j.ones(10, 15, nIn); final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations4.shape(), new long[] {10, 15, nHiddenUnits}); + assertArrayEquals(activations4.shape(), new long[] { 10, 15, nHiddenUnits }); } - } @Test - public void testBidirectionalLSTMGravesBackwardBasic() { - //Very basic test of backprop for mini-batch + time series - //Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. - + @DisplayName("Test Bidirectional LSTM Graves Backward Basic") + void testBidirectionalLSTMGravesBackwardBasic() { + // Very basic test of backprop for mini-batch + time series + // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. testGravesBackwardBasicHelper(13, 3, 17, 10, 7); - testGravesBackwardBasicHelper(13, 3, 17, 1, 7); //Edge case: miniBatchSize = 1 - testGravesBackwardBasicHelper(13, 3, 17, 10, 1); //Edge case: timeSeriesLength = 1 - testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 + // Edge case: miniBatchSize = 1 + testGravesBackwardBasicHelper(13, 3, 17, 1, 7); + // Edge case: timeSeriesLength = 1 + testGravesBackwardBasicHelper(13, 3, 17, 10, 1); + // Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 + testGravesBackwardBasicHelper(13, 3, 17, 1, 1); } - private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, - int timeSeriesLength) { - - INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength): - Nd4j.ones(miniBatchSize, timeSeriesLength, nIn); - - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat) - .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) - .build(); - + private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) { + INDArray inputData = (rnnDataFormat == RNNFormat.NCW) ? 
Nd4j.ones(miniBatchSize, nIn, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, nIn); + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build(); long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesBidirectionalLSTM lstm = - (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + GravesBidirectionalLSTM lstm = (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getLayer().initializer().numParams(conf))); - //Set input, do a forward pass: + // Set input, do a forward pass: lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); assertNotNull(lstm.input()); - - INDArray epsilon =(rnnDataFormat == RNNFormat.NCW)? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength): - Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits); - + INDArray epsilon = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits); Pair out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); Gradient outGradient = out.getFirst(); INDArray nextEpsilon = out.getSecond(); - INDArray biasGradientF = outGradient.getGradientFor(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS); - INDArray inWeightGradientF = - outGradient.getGradientFor(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS); - INDArray recurrentWeightGradientF = outGradient - .getGradientFor(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS); + INDArray inWeightGradientF = outGradient.getGradientFor(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS); + INDArray recurrentWeightGradientF = outGradient.getGradientFor(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS); assertNotNull(biasGradientF); assertNotNull(inWeightGradientF); assertNotNull(recurrentWeightGradientF); - INDArray biasGradientB = outGradient.getGradientFor(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS); - INDArray inWeightGradientB = - outGradient.getGradientFor(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS); - INDArray recurrentWeightGradientB = outGradient - .getGradientFor(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS); + INDArray inWeightGradientB = outGradient.getGradientFor(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS); + INDArray recurrentWeightGradientB = outGradient.getGradientFor(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS); assertNotNull(biasGradientB); assertNotNull(inWeightGradientB); assertNotNull(recurrentWeightGradientB); - - assertArrayEquals(biasGradientF.shape(), new long[] {1, 4 * lstmNHiddenUnits}); - assertArrayEquals(inWeightGradientF.shape(), new long[] {nIn, 4 * lstmNHiddenUnits}); - assertArrayEquals(recurrentWeightGradientF.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3}); - - assertArrayEquals(biasGradientB.shape(), new long[] {1, 4 * lstmNHiddenUnits}); - assertArrayEquals(inWeightGradientB.shape(), new long[] {nIn, 4 * lstmNHiddenUnits}); - 
assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3}); - + assertArrayEquals(biasGradientF.shape(), new long[] { 1, 4 * lstmNHiddenUnits }); + assertArrayEquals(inWeightGradientF.shape(), new long[] { nIn, 4 * lstmNHiddenUnits }); + assertArrayEquals(recurrentWeightGradientF.shape(), new long[] { lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3 }); + assertArrayEquals(biasGradientB.shape(), new long[] { 1, 4 * lstmNHiddenUnits }); + assertArrayEquals(inWeightGradientB.shape(), new long[] { nIn, 4 * lstmNHiddenUnits }); + assertArrayEquals(recurrentWeightGradientB.shape(), new long[] { lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3 }); assertNotNull(nextEpsilon); if (rnnDataFormat == RNNFormat.NCW) { - assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, nIn, timeSeriesLength}); - }else{ - assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, timeSeriesLength, nIn }); + assertArrayEquals(nextEpsilon.shape(), new long[] { miniBatchSize, nIn, timeSeriesLength }); + } else { + assertArrayEquals(nextEpsilon.shape(), new long[] { miniBatchSize, timeSeriesLength, nIn }); } - - //Check update: + // Check update: for (String s : outGradient.gradientForVariable().keySet()) { lstm.update(outGradient.getGradientFor(s), s); } } @Test - public void testGravesBidirectionalLSTMForwardPassHelper() throws Exception { - //GravesBidirectionalLSTM.activateHelper() has different behaviour (due to optimizations) when forBackprop==true vs false - //But should otherwise provide identical activations + @DisplayName("Test Graves Bidirectional LSTM Forward Pass Helper") + void testGravesBidirectionalLSTMForwardPassHelper() throws Exception { + // GravesBidirectionalLSTM.activateHelper() has different behaviour (due to optimizations) when forBackprop==true vs false + // But should otherwise provide identical activations Nd4j.getRandom().setSeed(12345); - final int nIn = 10; final int layerSize = 15; final int miniBatchSize = 4; final int timeSeriesLength = 7; - - final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(layerSize) - .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) - .build(); - + final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build(); long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - final GravesBidirectionalLSTM lstm = - (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); - final INDArray input = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); + final GravesBidirectionalLSTM lstm = (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + final INDArray input = Nd4j.rand(new int[] { miniBatchSize, nIn, timeSeriesLength }); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); - - - final INDArray fwdPassFalse = LSTMHelpers.activateHelper(lstm, lstm.conf(), new ActivationSigmoid(), - lstm.input(), - lstm.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), - lstm.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), - 
lstm.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), false, null, null, - false, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, null, true, - null, CacheMode.NONE, LayerWorkspaceMgr.noWorkspaces(), true).fwdPassOutput; - - final INDArray[] fwdPassTrue = LSTMHelpers.activateHelper(lstm, lstm.conf(), new ActivationSigmoid(), - lstm.input(), - lstm.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), - lstm.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), - lstm.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), false, null, null, - true, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, null, true, null, - CacheMode.NONE, LayerWorkspaceMgr.noWorkspaces(), true).fwdPassOutputAsArrays; - - //I have no idea what the heck this does --Ben + final INDArray fwdPassFalse = LSTMHelpers.activateHelper(lstm, lstm.conf(), new ActivationSigmoid(), lstm.input(), lstm.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), lstm.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), lstm.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), false, null, null, false, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, null, true, null, CacheMode.NONE, LayerWorkspaceMgr.noWorkspaces(), true).fwdPassOutput; + final INDArray[] fwdPassTrue = LSTMHelpers.activateHelper(lstm, lstm.conf(), new ActivationSigmoid(), lstm.input(), lstm.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), lstm.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), lstm.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), false, null, null, true, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, null, true, null, CacheMode.NONE, LayerWorkspaceMgr.noWorkspaces(), true).fwdPassOutputAsArrays; + // I have no idea what the heck this does --Ben for (int i = 0; i < timeSeriesLength; i++) { final INDArray sliceFalse = fwdPassFalse.tensorAlongDimension(i, 1, 0); final INDArray sliceTrue = fwdPassTrue[i]; @@ -247,315 +195,162 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { static private void reverseColumnsInPlace(final INDArray x) { final long N = x.size(1); final INDArray x2 = x.dup(); - for (int t = 0; t < N; t++) { final long b = N - t - 1; - //clone? + // clone? 
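+        // Copy column b of the duplicate x2 back into column t of x; after the loop the columns (time steps) of x are in reverse order, e.g. a row [1, 2, 3] becomes [3, 2, 1].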
x.putColumn(t, x2.getColumn(b)); } } @Test - public void testGetSetParmas() { + @DisplayName("Test Get Set Parmas") + void testGetSetParmas() { final int nIn = 2; final int layerSize = 3; final int miniBatchSize = 2; final int timeSeriesLength = 10; - Nd4j.getRandom().setSeed(12345); - - final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(layerSize).dataFormat(rnnDataFormat) - .dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build()) - .build(); - - + final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build()).build(); long numParams = confBidirectional.getLayer().initializer().numParams(confBidirectional); INDArray params = Nd4j.create(1, numParams); - final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getLayer() - .instantiate(confBidirectional, null, 0, params, true, params.dataType()); - - - final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}): - Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn}); - + final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getLayer().instantiate(confBidirectional, null, 0, params, true, params.dataType()); + final INDArray sig = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.rand(new int[] { miniBatchSize, nIn, timeSeriesLength }) : Nd4j.rand(new int[] { miniBatchSize, timeSeriesLength, nIn }); final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()); - params = bidirectionalLSTM.params(); - bidirectionalLSTM.setParams(params); - final INDArray act2 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(act2.data().asDouble(), act1.data().asDouble(), 1e-8); - - } @Test - public void testSimpleForwardsAndBackwardsActivation() { - + @DisplayName("Test Simple Forwards And Backwards Activation") + void testSimpleForwardsAndBackwardsActivation() { final int nIn = 2; final int layerSize = 3; final int miniBatchSize = 1; final int timeSeriesLength = 5; - Nd4j.getRandom().setSeed(12345); - - final NeuralNetConfiguration confBidirectional = - new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() - .nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat) - .dist(new UniformDistribution(-0.1, 0.1)) - .activation(Activation.TANH).updater(new NoOp()).build()) - .build(); - - final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat) - .weightInit(WeightInit.ZERO).activation(Activation.TANH).build()) - .build(); - + final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).updater(new NoOp()).build()).build(); + final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder().layer(new 
org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).weightInit(WeightInit.ZERO).activation(Activation.TANH).build()).build(); long numParams = confForwards.getLayer().initializer().numParams(confForwards); INDArray params = Nd4j.create(1, numParams); long numParamsBD = confBidirectional.getLayer().initializer().numParams(confBidirectional); INDArray paramsBD = Nd4j.create(1, numParamsBD); - final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getLayer() - .instantiate(confBidirectional, null, 0, paramsBD, true, params.dataType()); - final GravesLSTM forwardsLSTM = - (GravesLSTM) confForwards.getLayer().instantiate(confForwards, null, 0, params, true, params.dataType()); - - bidirectionalLSTM.setBackpropGradientsViewArray( - Nd4j.create(1, confBidirectional.getLayer().initializer().numParams(confBidirectional))); - forwardsLSTM.setBackpropGradientsViewArray( - Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards))); - - - final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}): - Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn}); + final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getLayer().instantiate(confBidirectional, null, 0, paramsBD, true, params.dataType()); + final GravesLSTM forwardsLSTM = (GravesLSTM) confForwards.getLayer().instantiate(confForwards, null, 0, params, true, params.dataType()); + bidirectionalLSTM.setBackpropGradientsViewArray(Nd4j.create(1, confBidirectional.getLayer().initializer().numParams(confBidirectional))); + forwardsLSTM.setBackpropGradientsViewArray(Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards))); + final INDArray sig = (rnnDataFormat == RNNFormat.NCW) ? 
Nd4j.rand(new int[] { miniBatchSize, nIn, timeSeriesLength }) : Nd4j.rand(new int[] { miniBatchSize, timeSeriesLength, nIn }); final INDArray sigb = sig.dup(); - if (rnnDataFormat == RNNFormat.NCW) { reverseColumnsInPlace(sigb.slice(0)); - } - else{ + } else { reverseColumnsInPlace(sigb.slice(0).permute(1, 0)); } - - final INDArray recurrentWeightsF = bidirectionalLSTM - .getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS); - final INDArray inputWeightsF = - bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS); - final INDArray biasWeightsF = - bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS); - + final INDArray recurrentWeightsF = bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS); + final INDArray inputWeightsF = bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS); + final INDArray biasWeightsF = bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS); final INDArray recurrentWeightsF2 = forwardsLSTM.getParam(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY); final INDArray inputWeightsF2 = forwardsLSTM.getParam(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY); final INDArray biasWeightsF2 = forwardsLSTM.getParam(GravesLSTMParamInitializer.BIAS_KEY); - - //assert that the forwards part of the bidirectional layer is equal to that of the regular LSTM + // assert that the forwards part of the bidirectional layer is equal to that of the regular LSTM assertArrayEquals(recurrentWeightsF2.shape(), recurrentWeightsF.shape()); assertArrayEquals(inputWeightsF2.shape(), inputWeightsF.shape()); assertArrayEquals(biasWeightsF2.shape(), biasWeightsF.shape()); - forwardsLSTM.setParam(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, recurrentWeightsF); forwardsLSTM.setParam(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, inputWeightsF); forwardsLSTM.setParam(GravesLSTMParamInitializer.BIAS_KEY, biasWeightsF); - - //copy forwards weights to make the forwards activations do the same thing - - final INDArray recurrentWeightsB = bidirectionalLSTM - .getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS); - final INDArray inputWeightsB = - bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS); - final INDArray biasWeightsB = - bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS); - - //assert that the forwards and backwards are the same shapes + // copy forwards weights to make the forwards activations do the same thing + final INDArray recurrentWeightsB = bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS); + final INDArray inputWeightsB = bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS); + final INDArray biasWeightsB = bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS); + // assert that the forwards and backwards are the same shapes assertArrayEquals(recurrentWeightsF.shape(), recurrentWeightsB.shape()); assertArrayEquals(inputWeightsF.shape(), inputWeightsB.shape()); assertArrayEquals(biasWeightsF.shape(), biasWeightsB.shape()); - - //zero out backwards layer - bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS, - Nd4j.zeros(recurrentWeightsB.shape())); - 
bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, - Nd4j.zeros(inputWeightsB.shape())); - bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, - Nd4j.zeros(biasWeightsB.shape())); - - + // zero out backwards layer + bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS, Nd4j.zeros(recurrentWeightsB.shape())); + bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, Nd4j.zeros(inputWeightsB.shape())); + bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, Nd4j.zeros(biasWeightsB.shape())); forwardsLSTM.setInput(sig, LayerWorkspaceMgr.noWorkspaces()); - - //compare activations + // compare activations final INDArray activation1 = forwardsLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()).slice(0); final INDArray activation2 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()).slice(0); - assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f); - - final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}): - Nd4j.rand(new int[] {1, timeSeriesLength, layerSize}); + final INDArray randSig = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.rand(new int[] { 1, layerSize, timeSeriesLength }) : Nd4j.rand(new int[] { 1, timeSeriesLength, layerSize }); INDArray randSigBackwards = randSig.dup(); - if (rnnDataFormat == RNNFormat.NCW){ + if (rnnDataFormat == RNNFormat.NCW) { reverseColumnsInPlace(randSigBackwards.slice(0)); - }else{ + } else { reverseColumnsInPlace(randSigBackwards.slice(0).permute(1, 0)); } - final Pair backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces()); final Pair backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces()); - - //compare gradients - assertArrayEquals( - backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY).dup() - .data().asFloat(), - backprop2.getFirst() - .getGradientFor(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS) - .dup().data().asFloat(), - 1e-5f); - - assertArrayEquals( - backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY).dup().data() - .asFloat(), - backprop2.getFirst() - .getGradientFor(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS) - .dup().data().asFloat(), - 1e-5f); - - assertArrayEquals( - backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.BIAS_KEY).dup().data().asFloat(), - backprop2.getFirst().getGradientFor(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS) - .dup().data().asFloat(), - 1e-5f); - - //copy forwards to backwards - bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS, - bidirectionalLSTM.getParam( - GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS)); - - bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, - bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS)); - - bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, - bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS)); - - //zero out forwards layer - bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS, - Nd4j.zeros(recurrentWeightsB.shape())); - 
bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, - Nd4j.zeros(inputWeightsB.shape())); - bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS, - Nd4j.zeros(biasWeightsB.shape())); - - //run on reversed signal + // compare gradients + assertArrayEquals(backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY).dup().data().asFloat(), backprop2.getFirst().getGradientFor(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS).dup().data().asFloat(), 1e-5f); + assertArrayEquals(backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY).dup().data().asFloat(), backprop2.getFirst().getGradientFor(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS).dup().data().asFloat(), 1e-5f); + assertArrayEquals(backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.BIAS_KEY).dup().data().asFloat(), backprop2.getFirst().getGradientFor(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS).dup().data().asFloat(), 1e-5f); + // copy forwards to backwards + bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS, bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS)); + bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS)); + bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, bidirectionalLSTM.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS)); + // zero out forwards layer + bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS, Nd4j.zeros(recurrentWeightsB.shape())); + bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, Nd4j.zeros(inputWeightsB.shape())); + bidirectionalLSTM.setParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS, Nd4j.zeros(biasWeightsB.shape())); + // run on reversed signal final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0); - final INDArray activation3Reverse = activation3.dup(); - if (rnnDataFormat == RNNFormat.NCW){ + if (rnnDataFormat == RNNFormat.NCW) { reverseColumnsInPlace(activation3Reverse); - } - else{ + } else { reverseColumnsInPlace(activation3Reverse.permute(1, 0)); } - assertArrayEquals(activation3Reverse.shape(), activation1.shape()); assertEquals(activation3Reverse, activation1); - - - //test backprop now - final INDArray refBackGradientReccurrent = - backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY); - - final INDArray refBackGradientInput = - backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY); - + // test backprop now + final INDArray refBackGradientReccurrent = backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY); + final INDArray refBackGradientInput = backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY); final INDArray refBackGradientBias = backprop1.getFirst().getGradientFor(GravesLSTMParamInitializer.BIAS_KEY); - - //reverse weights only with backwards signal should yield same result as forwards weights with forwards signal + // reverse weights only with backwards signal should yield same result as forwards weights with forwards signal final Pair backprop3 = 
bidirectionalLSTM.backpropGradient(randSigBackwards, LayerWorkspaceMgr.noWorkspaces()); - - final INDArray backGradientRecurrent = backprop3.getFirst() - .getGradientFor(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS); - final INDArray backGradientInput = backprop3.getFirst() - .getGradientFor(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS); - final INDArray backGradientBias = - backprop3.getFirst().getGradientFor(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS); - + final INDArray backGradientRecurrent = backprop3.getFirst().getGradientFor(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS); + final INDArray backGradientInput = backprop3.getFirst().getGradientFor(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS); + final INDArray backGradientBias = backprop3.getFirst().getGradientFor(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS); assertArrayEquals(refBackGradientBias.dup().data().asDouble(), backGradientBias.dup().data().asDouble(), 1e-6); - - assertArrayEquals(refBackGradientInput.dup().data().asDouble(), backGradientInput.dup().data().asDouble(), - 1e-6); - - assertArrayEquals(refBackGradientReccurrent.dup().data().asDouble(), - backGradientRecurrent.dup().data().asDouble(), 1e-6); - + assertArrayEquals(refBackGradientInput.dup().data().asDouble(), backGradientInput.dup().data().asDouble(), 1e-6); + assertArrayEquals(refBackGradientReccurrent.dup().data().asDouble(), backGradientRecurrent.dup().data().asDouble(), 1e-6); final INDArray refEpsilon = backprop1.getSecond().dup(); final INDArray backEpsilon = backprop3.getSecond().dup(); - if (rnnDataFormat == RNNFormat.NCW) { reverseColumnsInPlace(refEpsilon.slice(0)); - } - else{ + } else { reverseColumnsInPlace(refEpsilon.slice(0).permute(1, 0)); } assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6); - } @Test - public void testSerialization() { - - final MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new AdaGrad(0.1)) - .l2(0.001) - .seed(12345).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() - .activation(Activation.TANH).nIn(2).nOut(2) - .dist(new UniformDistribution(-0.05, 0.05)).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() - .activation(Activation.TANH).nIn(2).nOut(2) - .dist(new UniformDistribution(-0.05, 0.05)).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder() - .activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT) - .nIn(2).nOut(2).build()) - .build(); - - + @DisplayName("Test Serialization") + void testSerialization() { + final MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new AdaGrad(0.1)).l2(0.001).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().activation(Activation.TANH).nIn(2).nOut(2).dist(new UniformDistribution(-0.05, 0.05)).build()).layer(1, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().activation(Activation.TANH).nIn(2).nOut(2).dist(new UniformDistribution(-0.05, 0.05)).build()).layer(2, new 
org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(2).build()).build(); final String json1 = conf1.toJson(); - final MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json1); - final String json2 = conf1.toJson(); - - - TestCase.assertEquals(json1, json2); + assertEquals(json1, json2); } @Test - public void testGateActivationFnsSanityCheck() { - for (String gateAfn : new String[] {"sigmoid", "hardsigmoid"}) { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .seed(12345).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() - .gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat) - .build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat) - .activation(Activation.TANH).build()) - .build(); - + @DisplayName("Test Gate Activation Fns Sanity Check") + void testGateActivationFnsSanityCheck() { + for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) net.getLayer(0).conf() - .getLayer()).getGateActivationFn().toString()); - - INDArray in = Nd4j.rand(new int[] {3, 2, 5}); - INDArray labels = Nd4j.rand(new int[] {3, 2, 5}); - if (rnnDataFormat == RNNFormat.NWC){ + assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) net.getLayer(0).conf().getLayer()).getGateActivationFn().toString()); + INDArray in = Nd4j.rand(new int[] { 3, 2, 5 }); + INDArray labels = Nd4j.rand(new int[] { 3, 2, 5 }); + if (rnnDataFormat == RNNFormat.NWC) { in = in.permute(0, 2, 1); labels = labels.permute(0, 2, 1); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java index 63f343c3c..1aef56056 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.recurrent; import lombok.val; @@ -31,7 +30,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Test; +import 
org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -40,152 +39,118 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.primitives.Pair; - import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - - -public class GravesLSTMTest extends BaseDL4JTest { +@DisplayName("Graves LSTM Test") +class GravesLSTMTest extends BaseDL4JTest { @Test - public void testLSTMGravesForwardBasic() { - //Very basic test of forward prop. of LSTM layer with a time series. - //Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. - + @DisplayName("Test LSTM Graves Forward Basic") + void testLSTMGravesForwardBasic() { + // Very basic test of forward prop. of LSTM layer with a time series. + // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. int nIn = 13; int nHiddenUnits = 17; - - NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn) - .nOut(nHiddenUnits).activation(Activation.TANH).build()) - .build(); - + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(nHiddenUnits).activation(Activation.TANH).build()).build(); val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); GravesLSTM layer = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); - - //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; - //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; - + // Data: has shape [miniBatchSize,nIn,timeSeriesLength]; + // Output/activations has shape [miniBatchSize,nHiddenUnits,timeSeriesLength]; INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1); INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1}); - + assertArrayEquals(activations1.shape(), new long[] { 1, nHiddenUnits, 1 }); INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1); INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1}); - + assertArrayEquals(activations2.shape(), new long[] { 10, nHiddenUnits, 1 }); INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12); INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12}); - + assertArrayEquals(activations3.shape(), new long[] { 1, nHiddenUnits, 12 }); INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15); INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15}); + assertArrayEquals(activations4.shape(), new long[] { 10, nHiddenUnits, 15 }); } @Test - public void 
testLSTMGravesBackwardBasic() { - //Very basic test of backprop for mini-batch + time series - //Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. - + @DisplayName("Test LSTM Graves Backward Basic") + void testLSTMGravesBackwardBasic() { + // Very basic test of backprop for mini-batch + time series + // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. testGravesBackwardBasicHelper(13, 3, 17, 10, 7); - testGravesBackwardBasicHelper(13, 3, 17, 1, 7); //Edge case: miniBatchSize = 1 - testGravesBackwardBasicHelper(13, 3, 17, 10, 1); //Edge case: timeSeriesLength = 1 - testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 + // Edge case: miniBatchSize = 1 + testGravesBackwardBasicHelper(13, 3, 17, 1, 7); + // Edge case: timeSeriesLength = 1 + testGravesBackwardBasicHelper(13, 3, 17, 10, 1); + // Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 + testGravesBackwardBasicHelper(13, 3, 17, 1, 1); } - private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, - int timeSeriesLength) { - + private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) { INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength); - - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn) - .nOut(lstmNHiddenUnits) - .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) - .build(); - + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build(); val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); GravesLSTM lstm = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getLayer().initializer().numParams(conf))); - //Set input, do a forward pass: + // Set input, do a forward pass: lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); assertNotNull(lstm.input()); - INDArray epsilon = Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength); - Pair out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); Gradient outGradient = out.getFirst(); INDArray nextEpsilon = out.getSecond(); - INDArray biasGradient = outGradient.getGradientFor(GravesLSTMParamInitializer.BIAS_KEY); INDArray inWeightGradient = outGradient.getGradientFor(GravesLSTMParamInitializer.INPUT_WEIGHT_KEY); INDArray recurrentWeightGradient = outGradient.getGradientFor(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY); assertNotNull(biasGradient); assertNotNull(inWeightGradient); assertNotNull(recurrentWeightGradient); - - assertArrayEquals(biasGradient.shape(), new long[] {1, 4 * lstmNHiddenUnits}); - assertArrayEquals(inWeightGradient.shape(), new long[] {nIn, 4 * lstmNHiddenUnits}); - assertArrayEquals(recurrentWeightGradient.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3}); - + assertArrayEquals(biasGradient.shape(), new long[] { 1, 4 * lstmNHiddenUnits }); + assertArrayEquals(inWeightGradient.shape(), new long[] { nIn, 4 * lstmNHiddenUnits }); + 
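+        // The factor of 4 covers the four LSTM gate weight blocks (input, forget, output, cell candidate); the 3 extra columns expected in the recurrent weight gradient just below presumably hold the GravesLSTM peephole weights.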
assertArrayEquals(recurrentWeightGradient.shape(), new long[] { lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3 }); assertNotNull(nextEpsilon); - assertArrayEquals(nextEpsilon.shape(), new long[] {miniBatchSize, nIn, timeSeriesLength}); - - //Check update: + assertArrayEquals(nextEpsilon.shape(), new long[] { miniBatchSize, nIn, timeSeriesLength }); + // Check update: for (String s : outGradient.gradientForVariable().keySet()) { lstm.update(outGradient.getGradientFor(s), s); } } @Test - public void testGravesLSTMForwardPassHelper() throws Exception { - //GravesLSTM.activateHelper() has different behaviour (due to optimizations) when forBackprop==true vs false - //But should otherwise provide identical activations + @DisplayName("Test Graves LSTM Forward Pass Helper") + void testGravesLSTMForwardPassHelper() throws Exception { + // GravesLSTM.activateHelper() has different behaviour (due to optimizations) when forBackprop==true vs false + // But should otherwise provide identical activations Nd4j.getRandom().setSeed(12345); - int nIn = 10; int layerSize = 15; int miniBatchSize = 4; int timeSeriesLength = 7; - - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize) - .dist(new UniformDistribution(0, 1)) - .activation(Activation.TANH).build()) - .build(); - + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build(); val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); GravesLSTM lstm = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); - INDArray input = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); + INDArray input = Nd4j.rand(new int[] { miniBatchSize, nIn, timeSeriesLength }); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); - - Method actHelper = GravesLSTM.class.getDeclaredMethod("activateHelper", boolean.class, INDArray.class, - INDArray.class, boolean.class, LayerWorkspaceMgr.class); + Method actHelper = GravesLSTM.class.getDeclaredMethod("activateHelper", boolean.class, INDArray.class, INDArray.class, boolean.class, LayerWorkspaceMgr.class); actHelper.setAccessible(true); - - //Call activateHelper with both forBackprop == true, and forBackprop == false and compare + // Call activateHelper with both forBackprop == true, and forBackprop == false and compare Class innerClass = DL4JClassLoading.loadClassByName("org.deeplearning4j.nn.layers.recurrent.FwdPassReturn"); - - Object oFalse = actHelper.invoke(lstm, false, null, null, false, LayerWorkspaceMgr.noWorkspacesImmutable()); //GravesLSTM.FwdPassReturn object; want fwdPassOutput INDArray - Object oTrue = actHelper.invoke(lstm, false, null, null, true, LayerWorkspaceMgr.noWorkspacesImmutable()); //want fwdPassOutputAsArrays object - + // GravesLSTM.FwdPassReturn object; want fwdPassOutput INDArray + Object oFalse = actHelper.invoke(lstm, false, null, null, false, LayerWorkspaceMgr.noWorkspacesImmutable()); + // want fwdPassOutputAsArrays object + Object oTrue = actHelper.invoke(lstm, false, null, null, true, LayerWorkspaceMgr.noWorkspacesImmutable()); Field fwdPassOutput = innerClass.getDeclaredField("fwdPassOutput"); fwdPassOutput.setAccessible(true); - Field fwdPassOutputAsArrays = innerClass.getDeclaredField("fwdPassOutputAsArrays"); 
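+        // fwdPassOutput is the single stacked activations INDArray produced when forBackprop == false; fwdPassOutputAsArrays is the per-time-step INDArray[] produced when forBackprop == true.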
fwdPassOutputAsArrays.setAccessible(true); - INDArray fwdPassFalse = (INDArray) fwdPassOutput.get(oFalse); INDArray[] fwdPassTrue = (INDArray[]) fwdPassOutputAsArrays.get(oTrue); - for (int i = 0; i < timeSeriesLength; i++) { INDArray sliceFalse = fwdPassFalse.tensorAlongDimension(i, 1, 0); INDArray sliceTrue = fwdPassTrue[i]; @@ -194,54 +159,35 @@ public class GravesLSTMTest extends BaseDL4JTest { } @Test - public void testSingleExample() { + @DisplayName("Test Single Example") + void testSingleExample() { Nd4j.getRandom().setSeed(12345); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.1)).seed(12345).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().activation(Activation.TANH) - .nIn(2).nOut(2).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(1) - .activation(Activation.TANH).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().activation(Activation.TANH).nIn(2).nOut(2).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(1).activation(Activation.TANH).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - INDArray in1 = Nd4j.rand(new int[] {1, 2, 4}); - INDArray in2 = Nd4j.rand(new int[] {1, 2, 5}); - in2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, in1); - + INDArray in1 = Nd4j.rand(new int[] { 1, 2, 4 }); + INDArray in2 = Nd4j.rand(new int[] { 1, 2, 5 }); + in2.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4) }, in1); assertEquals(in1, in2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4))); - - INDArray labels1 = Nd4j.rand(new int[] {1, 1, 4}); + INDArray labels1 = Nd4j.rand(new int[] { 1, 1, 4 }); INDArray labels2 = Nd4j.create(1, 1, 5); - labels2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, labels1); + labels2.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4) }, labels1); assertEquals(labels1, labels2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4))); - INDArray out1 = net.output(in1); INDArray out2 = net.output(in2); - -// System.out.println(Arrays.toString(net.output(in1).data().asFloat())); -// System.out.println(Arrays.toString(net.output(in2).data().asFloat())); - + // System.out.println(Arrays.toString(net.output(in1).data().asFloat())); + // System.out.println(Arrays.toString(net.output(in2).data().asFloat())); List activations1 = net.feedForward(in1); List activations2 = net.feedForward(in2); - -// for (int i = 0; i < 3; i++) { -// System.out.println("-----\n" + i); -// System.out.println(Arrays.toString(activations1.get(i).dup().data().asDouble())); -// System.out.println(Arrays.toString(activations2.get(i).dup().data().asDouble())); -// -// System.out.println(activations1.get(i)); -// System.out.println(activations2.get(i)); -// } - - - - //Expect first 4 time steps to be indentical... 
+ // for (int i = 0; i < 3; i++) { + // System.out.println("-----\n" + i); + // System.out.println(Arrays.toString(activations1.get(i).dup().data().asDouble())); + // System.out.println(Arrays.toString(activations2.get(i).dup().data().asDouble())); + // + // System.out.println(activations1.get(i)); + // System.out.println(activations2.get(i)); + // } + // Expect first 4 time steps to be identical... for (int i = 0; i < 4; i++) { double d1 = out1.getDouble(i); double d2 = out2.getDouble(i); @@ -249,31 +195,16 @@ public class GravesLSTMTest extends BaseDL4JTest { } } - @Test - public void testGateActivationFnsSanityCheck() { - for (String gateAfn : new String[] {"sigmoid", "hardsigmoid"}) { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .seed(12345).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder() - .gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2) - .build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2) - .activation(Activation.TANH).build()) - .build(); - + @DisplayName("Test Gate Activation Fns Sanity Check") + void testGateActivationFnsSanityCheck() { + for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).activation(Activation.TANH).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesLSTM) net.getLayer(0).conf().getLayer()) - .getGateActivationFn().toString()); - - INDArray in = Nd4j.rand(new int[] {3, 2, 5}); - INDArray labels = Nd4j.rand(new int[] {3, 2, 5}); - + assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesLSTM) net.getLayer(0).conf().getLayer()).getGateActivationFn().toString()); + INDArray in = Nd4j.rand(new int[] { 3, 2, 5 }); + INDArray labels = Nd4j.rand(new int[] { 3, 2, 5 }); net.fit(in, labels); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java index cf273a450..dad304dac 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.BaseDL4JTest; @@ -30,95 +29,78 @@ import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.junit.runner.RunWith; import 
org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; - import java.util.Arrays; import java.util.Collections; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @RunWith(Parameterized.class) -public class MaskZeroLayerTest extends BaseDL4JTest { +@DisplayName("Mask Zero Layer Test") +class MaskZeroLayerTest extends BaseDL4JTest { + private RNNFormat rnnDataFormat; - public MaskZeroLayerTest(RNNFormat rnnDataFormat){ + public MaskZeroLayerTest(RNNFormat rnnDataFormat) { this.rnnDataFormat = rnnDataFormat; } + @Parameterized.Parameters - public static Object[] params(){ + public static Object[] params() { return RNNFormat.values(); } + @Test - public void activate() { - - //GIVEN two examples where some of the timesteps are zero. - INDArray ex1 = Nd4j.create(new double[][]{ - new double[]{0, 3, 5}, - new double[]{0, 0, 2} - }); - INDArray ex2 = Nd4j.create(new double[][]{ - new double[]{0, 0, 2}, - new double[]{0, 0, 2} - }); - + @DisplayName("Activate") + void activate() { + // GIVEN two examples where some of the timesteps are zero. + INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } }); + INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } }); // A LSTM which adds one for every non-zero timestep - org.deeplearning4j.nn.conf.layers.LSTM underlying = new org.deeplearning4j.nn.conf.layers.LSTM.Builder() - .activation(Activation.IDENTITY) - .gateActivationFunction(Activation.IDENTITY) - .nIn(2) - .nOut(1).dataFormat(rnnDataFormat) - .build(); + org.deeplearning4j.nn.conf.layers.LSTM underlying = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.IDENTITY).gateActivationFunction(Activation.IDENTITY).nIn(2).nOut(1).dataFormat(rnnDataFormat).build(); NeuralNetConfiguration conf = new NeuralNetConfiguration(); conf.setLayer(underlying); - INDArray params = Nd4j.zeros(new int[]{1, 16}); - - //Set the biases to 1. + INDArray params = Nd4j.zeros(new int[] { 1, 16 }); + // Set the biases to 1. 
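+        // With nIn = 2 and nOut = 1 the 16 parameters are presumably laid out as 8 input weights, 4 recurrent weights and 4 gate biases, so indices 12..15 are the biases; with identity activations and all-zero weights, every unmasked time step then adds 1 to the cell state, which is what the assertions below count.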
for (int i = 12; i < 16; i++) { params.putScalar(i, 1.0); } Layer lstm = underlying.instantiate(conf, Collections.emptyList(), 0, params, false, params.dataType()); double maskingValue = 0.0; - MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue); - INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3}); - if (rnnDataFormat == RNNFormat.NWC){ + INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[] { 2, 2, 3 }); + if (rnnDataFormat == RNNFormat.NWC) { input = input.permute(0, 2, 1); } - //WHEN + // WHEN INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces()); - if (rnnDataFormat == RNNFormat.NWC){ - out = out.permute(0, 2,1); + if (rnnDataFormat == RNNFormat.NWC) { + out = out.permute(0, 2, 1); } - //THEN output should only be incremented for the non-zero timesteps + // THEN output should only be incremented for the non-zero timesteps INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()); INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()); - - assertEquals(0.0, firstExampleOutput.getDouble(0), 1e-6); + assertEquals(0.0, firstExampleOutput.getDouble(0), 1e-6); assertEquals(1.0, firstExampleOutput.getDouble(1), 1e-6); assertEquals(2.0, firstExampleOutput.getDouble(2), 1e-6); - assertEquals(0.0, secondExampleOutput.getDouble(0), 1e-6); - assertEquals(0.0, secondExampleOutput.getDouble(1), 1e-6); + assertEquals(0.0, secondExampleOutput.getDouble(1), 1e-6); assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6); - } @Test - public void testSerialization(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder() - .setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()) - .build(); + @DisplayName("Test Serialization") + void testSerialization() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java index a0d294f7d..da01cb60c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.misc; import org.deeplearning4j.BaseDL4JTest; @@ -28,83 +27,63 @@ import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +@Disabled +@DisplayName("Large Net Test") +class LargeNetTest extends BaseDL4JTest { -@Ignore //Ignored due to very large memory requirements -public class LargeNetTest extends BaseDL4JTest { - - @Ignore + @Disabled @Test - public void testLargeMultiLayerNetwork(){ + @DisplayName("Test Large Multi Layer Network") + void testLargeMultiLayerNetwork() { Nd4j.setDataType(DataType.FLOAT); - - //More than 2.1 billion parameters - //10M classes plus 300 vector size -> 3 billion elements - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new EmbeddingLayer.Builder().nIn(10_000_000).nOut(300).build()) - .layer(new OutputLayer.Builder().nIn(300).nOut(10).activation(Activation.SOFTMAX).build()) - .build(); - + // More than 2.1 billion parameters + // 10M classes plus 300 vector size -> 3 billion elements + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new EmbeddingLayer.Builder().nIn(10_000_000).nOut(300).build()).layer(new OutputLayer.Builder().nIn(300).nOut(10).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray params = net.params(); long paramsLength = params.length(); long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10; assertEquals(expParamsLength, paramsLength); - - long[] expW = new long[]{10_000_000, 300}; + long[] expW = new long[] { 10_000_000, 300 }; assertArrayEquals(expW, net.getParam("0_W").shape()); - - long[] expW1 = new long[]{300, 10}; + long[] expW1 = new long[] { 300, 10 }; assertArrayEquals(expW1, net.getParam("1_W").shape()); - - long[] expB1 = new long[]{1, 10}; + long[] expB1 = new long[] { 1, 10 }; assertArrayEquals(expB1, net.getParam("1_b").shape()); } - @Ignore + @Disabled @Test - public void testLargeCompGraph(){ + @DisplayName("Test Large Comp Graph") + void testLargeCompGraph() { Nd4j.setDataType(DataType.FLOAT); - - //More than 2.1 billion parameters - //10M classes plus 300 vector size -> 3 billion elements - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - .graphBuilder() - .addInputs("in") - .layer("0", new EmbeddingLayer.Builder().nIn(10_000_000).nOut(300).build(), "in") - .layer("1", new OutputLayer.Builder().nIn(300).nOut(10).activation(Activation.SOFTMAX).build(), "0") - .setOutputs("1") - .build(); - + // More than 2.1 billion parameters + // 10M classes plus 300 vector size -> 3 billion elements + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").layer("0", new EmbeddingLayer.Builder().nIn(10_000_000).nOut(300).build(), "in").layer("1", new OutputLayer.Builder().nIn(300).nOut(10).activation(Activation.SOFTMAX).build(), "0").setOutputs("1").build(); ComputationGraph net = new ComputationGraph(conf); net.init(); - INDArray params = net.params(); long paramsLength = params.length(); long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10; assertEquals(expParamsLength, paramsLength); - - long[] expW = new long[]{10_000_000, 300}; + long[] expW = new long[] { 10_000_000, 300 }; assertArrayEquals(expW, net.getParam("0_W").shape()); - - long[] 
expW1 = new long[]{300, 10}; + long[] expW1 = new long[] { 300, 10 }; assertArrayEquals(expW1, net.getParam("1_W").shape()); - - long[] expB1 = new long[]{1, 10}; + long[] expB1 = new long[] { 1, 10 }; assertArrayEquals(expB1, net.getParam("1_b").shape()); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java index a03a19ea7..bd1a1d540 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.multilayer; import org.deeplearning4j.BaseDL4JTest; @@ -31,7 +30,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; @@ -45,118 +44,108 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.ops.transforms.Transforms; - import java.util.Arrays; - -import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.Assert.fail; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -public class BackPropMLPTest extends BaseDL4JTest { +@DisplayName("Back Prop MLP Test") +class BackPropMLPTest extends BaseDL4JTest { @Test - public void testMLPTrivial() { - //Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1. - MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] {1}, Activation.SIGMOID)); + @DisplayName("Test MLP Trivial") + void testMLPTrivial() { + // Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1. 
+ MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] { 1 }, Activation.SIGMOID)); network.setListeners(new ScoreIterationListener(1)); network.init(); - DataSetIterator iter = new IrisDataSetIterator(1, 10); - - while (iter.hasNext()) - network.fit(iter.next()); + while (iter.hasNext()) network.fit(iter.next()); } @Test - public void testMLP() { - //Simple mini-batch test with multiple hidden layers - MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 4, 3}, Activation.SIGMOID); -// System.out.println(conf); + @DisplayName("Test MLP") + void testMLP() { + // Simple mini-batch test with multiple hidden layers + MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] { 5, 4, 3 }, Activation.SIGMOID); + // System.out.println(conf); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); DataSetIterator iter = new IrisDataSetIterator(10, 100); - while (iter.hasNext()) { network.fit(iter.next()); } } @Test - public void testMLP2() { - //Simple mini-batch test with multiple hidden layers - MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 15, 3}, Activation.TANH); -// System.out.println(conf); + @DisplayName("Test MLP 2") + void testMLP2() { + // Simple mini-batch test with multiple hidden layers + MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] { 5, 15, 3 }, Activation.TANH); + // System.out.println(conf); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - DataSetIterator iter = new IrisDataSetIterator(12, 120); - while (iter.hasNext()) { network.fit(iter.next()); } } @Test - public void testSingleExampleWeightUpdates() { - //Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1. - //Manually calculate weight updates (entirely outside of DL4J and ND4J) + @DisplayName("Test Single Example Weight Updates") + void testSingleExampleWeightUpdates() { + // Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1. 
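+        // getIrisMLPSimpleConfig (defined elsewhere in this class) appears to build a 4-input Iris MLP with a 3-class softmax output trained by plain SGD at learning rate 0.1, which is what the manual calculation below assumes.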
+ // Manually calculate weight updates (entirely outside of DL4J and ND4J) // and compare expected and actual weights after backprop - DataSetIterator iris = new IrisDataSetIterator(1, 10); - - MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] {1}, Activation.SIGMOID)); + MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] { 1 }, Activation.SIGMOID)); network.init(); - Layer[] layers = network.getLayers(); - final boolean printCalculations = false; - while (iris.hasNext()) { DataSet data = iris.next(); INDArray x = data.getFeatures(); INDArray y = data.getLabels(); float[] xFloat = asFloat(x); float[] yFloat = asFloat(y); - - //Do forward pass: - INDArray l1Weights = layers[0].getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); //Hidden layer - INDArray l2Weights = layers[1].getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); //Output layer + // Do forward pass: + // Hidden layer + INDArray l1Weights = layers[0].getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); + // Output layer + INDArray l2Weights = layers[1].getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); INDArray l1Bias = layers[0].getParam(DefaultParamInitializer.BIAS_KEY).dup(); INDArray l2Bias = layers[1].getParam(DefaultParamInitializer.BIAS_KEY).dup(); float[] l1WeightsFloat = asFloat(l1Weights); float[] l2WeightsFloat = asFloat(l2Weights); float l1BiasFloat = l1Bias.getFloat(0); float[] l2BiasFloatArray = asFloat(l2Bias); - - float hiddenUnitPreSigmoid = dotProduct(l1WeightsFloat, xFloat) + l1BiasFloat; //z=w*x+b - float hiddenUnitPostSigmoid = sigmoid(hiddenUnitPreSigmoid); //a=sigma(z) - + // z=w*x+b + float hiddenUnitPreSigmoid = dotProduct(l1WeightsFloat, xFloat) + l1BiasFloat; + // a=sigma(z) + float hiddenUnitPostSigmoid = sigmoid(hiddenUnitPreSigmoid); float[] outputPreSoftmax = new float[3]; - //Normally a matrix multiplication here, but only one hidden unit in this trivial example + // Normally a matrix multiplication here, but only one hidden unit in this trivial example for (int i = 0; i < 3; i++) { outputPreSoftmax[i] = hiddenUnitPostSigmoid * l2WeightsFloat[i] + l2BiasFloatArray[i]; } float[] outputPostSoftmax = softmax(outputPreSoftmax); - - //Do backward pass: - float[] deltaOut = vectorDifference(outputPostSoftmax, yFloat); //out-labels - //deltaHidden = sigmaPrime(hiddenUnitZ) * sum_k (w_jk * \delta_k); here, only one j + // Do backward pass: + // out-labels + float[] deltaOut = vectorDifference(outputPostSoftmax, yFloat); + // deltaHidden = sigmaPrime(hiddenUnitZ) * sum_k (w_jk * \delta_k); here, only one j float deltaHidden = 0.0f; - for (int i = 0; i < 3; i++) - deltaHidden += l2WeightsFloat[i] * deltaOut[i]; + for (int i = 0; i < 3; i++) deltaHidden += l2WeightsFloat[i] * deltaOut[i]; deltaHidden *= derivOfSigmoid(hiddenUnitPreSigmoid); - - //Calculate weight/bias updates: - //dL/dW = delta * (activation of prev. layer) - //dL/db = delta + // Calculate weight/bias updates: + // dL/dW = delta * (activation of prev. 
layer) + // dL/db = delta float[] dLdwOut = new float[3]; - for (int i = 0; i < dLdwOut.length; i++) - dLdwOut[i] = deltaOut[i] * hiddenUnitPostSigmoid; + for (int i = 0; i < dLdwOut.length; i++) dLdwOut[i] = deltaOut[i] * hiddenUnitPostSigmoid; float[] dLdwHidden = new float[4]; - for (int i = 0; i < dLdwHidden.length; i++) - dLdwHidden[i] = deltaHidden * xFloat[i]; + for (int i = 0; i < dLdwHidden.length; i++) dLdwHidden[i] = deltaHidden * xFloat[i]; float[] dLdbOut = deltaOut; float dLdbHidden = deltaHidden; - if (printCalculations) { System.out.println("deltaOut = " + Arrays.toString(deltaOut)); System.out.println("deltaHidden = " + deltaHidden); @@ -165,30 +154,21 @@ public class BackPropMLPTest extends BaseDL4JTest { System.out.println("dLdwHidden = " + Arrays.toString(dLdwHidden)); System.out.println("dLdbHidden = " + dLdbHidden); } - - - //Calculate new parameters: - //w_i = w_i - (learningRate)/(batchSize) * sum_j (dL_j/dw_i) - //b_i = b_i - (learningRate)/(batchSize) * sum_j (dL_j/db_i) - //Which for batch size of one (here) is simply: - //w_i = w_i - learningRate * dL/dW - //b_i = b_i - learningRate * dL/db + // Calculate new parameters: + // w_i = w_i - (learningRate)/(batchSize) * sum_j (dL_j/dw_i) + // b_i = b_i - (learningRate)/(batchSize) * sum_j (dL_j/db_i) + // Which for batch size of one (here) is simply: + // w_i = w_i - learningRate * dL/dW + // b_i = b_i - learningRate * dL/db float[] expectedL1WeightsAfter = new float[4]; float[] expectedL2WeightsAfter = new float[3]; float expectedL1BiasAfter = l1BiasFloat - 0.1f * dLdbHidden; float[] expectedL2BiasAfter = new float[3]; - - for (int i = 0; i < 4; i++) - expectedL1WeightsAfter[i] = l1WeightsFloat[i] - 0.1f * dLdwHidden[i]; - for (int i = 0; i < 3; i++) - expectedL2WeightsAfter[i] = l2WeightsFloat[i] - 0.1f * dLdwOut[i]; - for (int i = 0; i < 3; i++) - expectedL2BiasAfter[i] = l2BiasFloatArray[i] - 0.1f * dLdbOut[i]; - - - //Finally, do back-prop on network, and compare parameters vs. expected parameters + for (int i = 0; i < 4; i++) expectedL1WeightsAfter[i] = l1WeightsFloat[i] - 0.1f * dLdwHidden[i]; + for (int i = 0; i < 3; i++) expectedL2WeightsAfter[i] = l2WeightsFloat[i] - 0.1f * dLdwOut[i]; + for (int i = 0; i < 3; i++) expectedL2BiasAfter[i] = l2BiasFloatArray[i] - 0.1f * dLdbOut[i]; + // Finally, do back-prop on network, and compare parameters vs. 
expected parameters network.fit(data); - /* INDArray l1WeightsAfter = layers[0].getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); //Hidden layer INDArray l2WeightsAfter = layers[1].getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); //Output layer INDArray l1BiasAfter = layers[0].getParam(DefaultParamInitializer.BIAS_KEY).dup(); @@ -216,22 +196,21 @@ public class BackPropMLPTest extends BaseDL4JTest { assertEquals(l1BiasFloatAfter,expectedL1BiasAfter,eps); assertArrayEquals(l2BiasFloatAfter,expectedL2BiasAfter,eps); */ -// System.out.println("\n\n--------------"); + // System.out.println("\n\n--------------"); } } - @Test - public void testMLPGradientCalculation() { - testIrisMiniBatchGradients(1, new int[] {1}, Activation.SIGMOID); - testIrisMiniBatchGradients(1, new int[] {5}, Activation.SIGMOID); - testIrisMiniBatchGradients(12, new int[] {15, 25, 10}, Activation.SIGMOID); - testIrisMiniBatchGradients(50, new int[] {10, 50, 200, 50, 10}, Activation.TANH); - testIrisMiniBatchGradients(150, new int[] {30, 50, 20}, Activation.TANH); + @DisplayName("Test MLP Gradient Calculation") + void testMLPGradientCalculation() { + testIrisMiniBatchGradients(1, new int[] { 1 }, Activation.SIGMOID); + testIrisMiniBatchGradients(1, new int[] { 5 }, Activation.SIGMOID); + testIrisMiniBatchGradients(12, new int[] { 15, 25, 10 }, Activation.SIGMOID); + testIrisMiniBatchGradients(50, new int[] { 10, 50, 200, 50, 10 }, Activation.TANH); + testIrisMiniBatchGradients(150, new int[] { 30, 50, 20 }, Activation.TANH); } - private static void testIrisMiniBatchGradients(int miniBatchSize, int[] hiddenLayerSizes, - Activation activationFunction) { + private static void testIrisMiniBatchGradients(int miniBatchSize, int[] hiddenLayerSizes, Activation activationFunction) { int totalExamples = 10 * miniBatchSize; if (totalExamples > 150) { totalExamples = miniBatchSize * (150 / miniBatchSize); @@ -240,26 +219,21 @@ public class BackPropMLPTest extends BaseDL4JTest { fail(); } DataSetIterator iris = new IrisDataSetIterator(miniBatchSize, totalExamples); - MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(hiddenLayerSizes, Activation.SIGMOID)); network.init(); - Layer[] layers = network.getLayers(); int nLayers = layers.length; - while (iris.hasNext()) { DataSet data = iris.next(); INDArray x = data.getFeatures(); INDArray y = data.getLabels(); - - //Do forward pass: + // Do forward pass: INDArray[] layerWeights = new INDArray[nLayers]; INDArray[] layerBiases = new INDArray[nLayers]; for (int i = 0; i < nLayers; i++) { layerWeights[i] = layers[i].getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); layerBiases[i] = layers[i].getParam(DefaultParamInitializer.BIAS_KEY).dup(); } - INDArray[] layerZs = new INDArray[nLayers]; INDArray[] layerActivations = new INDArray[nLayers]; for (int i = 0; i < nLayers; i++) { @@ -267,40 +241,37 @@ public class BackPropMLPTest extends BaseDL4JTest { layerZs[i] = layerInput.castTo(layerWeights[i].dataType()).mmul(layerWeights[i]).addiRowVector(layerBiases[i]); layerActivations[i] = (i == nLayers - 1 ? 
doSoftmax(layerZs[i].dup()) : doSigmoid(layerZs[i].dup())); } - - //Do backward pass: + // Do backward pass: INDArray[] deltas = new INDArray[nLayers]; - deltas[nLayers - 1] = layerActivations[nLayers - 1].sub(y.castTo(layerActivations[nLayers-1].dataType())); //Out - labels; shape=[miniBatchSize,nOut]; - assertArrayEquals(deltas[nLayers - 1].shape(), new long[] {miniBatchSize, 3}); + // Out - labels; shape=[miniBatchSize,nOut]; + deltas[nLayers - 1] = layerActivations[nLayers - 1].sub(y.castTo(layerActivations[nLayers - 1].dataType())); + assertArrayEquals(deltas[nLayers - 1].shape(), new long[] { miniBatchSize, 3 }); for (int i = nLayers - 2; i >= 0; i--) { INDArray sigmaPrimeOfZ; sigmaPrimeOfZ = doSigmoidDerivative(layerZs[i]); INDArray epsilon = layerWeights[i + 1].mmul(deltas[i + 1].transpose()).transpose(); deltas[i] = epsilon.mul(sigmaPrimeOfZ); - assertArrayEquals(deltas[i].shape(), new long[] {miniBatchSize, hiddenLayerSizes[i]}); + assertArrayEquals(deltas[i].shape(), new long[] { miniBatchSize, hiddenLayerSizes[i] }); } - INDArray[] dLdw = new INDArray[nLayers]; INDArray[] dLdb = new INDArray[nLayers]; for (int i = 0; i < nLayers; i++) { INDArray prevActivations = (i == 0 ? x : layerActivations[i - 1]); - //Raw gradients, so not yet divided by mini-batch size (division is done in BaseUpdater) - dLdw[i] = deltas[i].transpose().castTo(prevActivations.dataType()).mmul(prevActivations).transpose(); //Shape: [nIn, nOut] - dLdb[i] = deltas[i].sum(true, 0); //Shape: [1,nOut] - + // Raw gradients, so not yet divided by mini-batch size (division is done in BaseUpdater) + // Shape: [nIn, nOut] + dLdw[i] = deltas[i].transpose().castTo(prevActivations.dataType()).mmul(prevActivations).transpose(); + // Shape: [1,nOut] + dLdb[i] = deltas[i].sum(true, 0); int nIn = (i == 0 ? 4 : hiddenLayerSizes[i - 1]); int nOut = (i < nLayers - 1 ? hiddenLayerSizes[i] : 3); - assertArrayEquals(dLdw[i].shape(), new long[] {nIn, nOut}); - assertArrayEquals(dLdb[i].shape(), new long[] {1, nOut}); + assertArrayEquals(dLdw[i].shape(), new long[] { nIn, nOut }); + assertArrayEquals(dLdb[i].shape(), new long[] { 1, nOut }); } - - - //Calculate and get gradient, compare to expected + // Calculate and get gradient, compare to expected network.setInput(x); network.setLabels(y); network.computeGradientAndScore(); Gradient gradient = network.gradientAndScore().getFirst(); - float eps = 1e-4f; for (int i = 0; i < hiddenLayerSizes.length; i++) { String wKey = i + "_" + DefaultParamInitializer.WEIGHT_KEY; @@ -317,29 +288,18 @@ public class BackPropMLPTest extends BaseDL4JTest { } } - - /** Very simple back-prop config set up for Iris. + /** + * Very simple back-prop config set up for Iris. * Learning Rate = 0.1 * No regularization, no Adagrad, no momentum etc. One iteration. */ - private static MultiLayerConfiguration getIrisMLPSimpleConfig(int[] hiddenLayerSizes, - Activation activationFunction) { - NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .seed(12345L).list(); - + private static MultiLayerConfiguration getIrisMLPSimpleConfig(int[] hiddenLayerSizes, Activation activationFunction) { + NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).seed(12345L).list(); for (int i = 0; i < hiddenLayerSizes.length; i++) { int nIn = (i == 0 ? 
4 : hiddenLayerSizes[i - 1]); - lb.layer(i, new DenseLayer.Builder().nIn(nIn).nOut(hiddenLayerSizes[i]).weightInit(WeightInit.XAVIER) - .activation(activationFunction).build()); + lb.layer(i, new DenseLayer.Builder().nIn(nIn).nOut(hiddenLayerSizes[i]).weightInit(WeightInit.XAVIER).activation(activationFunction).build()); } - - lb.layer(hiddenLayerSizes.length, - new OutputLayer.Builder(LossFunction.MCXENT).nIn(hiddenLayerSizes[hiddenLayerSizes.length - 1]) - .nOut(3).weightInit(WeightInit.XAVIER) - .activation(activationFunction.equals(Activation.IDENTITY) ? Activation.IDENTITY - : Activation.SOFTMAX) - .build()); - + lb.layer(hiddenLayerSizes.length, new OutputLayer.Builder(LossFunction.MCXENT).nIn(hiddenLayerSizes[hiddenLayerSizes.length - 1]).nOut(3).weightInit(WeightInit.XAVIER).activation(activationFunction.equals(Activation.IDENTITY) ? Activation.IDENTITY : Activation.SOFTMAX).build()); return lb.build(); } @@ -357,8 +317,7 @@ public class BackPropMLPTest extends BaseDL4JTest { public static float dotProduct(float[] x, float[] y) { float sum = 0.0f; - for (int i = 0; i < x.length; i++) - sum += x[i] * y[i]; + for (int i = 0; i < x.length; i++) sum += x[i] * y[i]; return sum; } @@ -375,7 +334,7 @@ public class BackPropMLPTest extends BaseDL4JTest { } public static float derivOfSigmoid(float in) { - // float v = (float)( Math.exp(in) / Math.pow(1+Math.exp(in),2.0) ); + // float v = (float)( Math.exp(in) / Math.pow(1+Math.exp(in),2.0) ); float v = in * (1 - in); return v; } @@ -419,5 +378,4 @@ public class BackPropMLPTest extends BaseDL4JTest { public static INDArray doSigmoidDerivative(INDArray input) { return Nd4j.getExecutioner().exec(new SigmoidDerivative(input.dup())); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index 8a9a7a787..6c3ad1855 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.multilayer; import lombok.Data; @@ -54,6 +53,8 @@ import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.util.ModelSerializer; import org.junit.*; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -75,52 +76,47 @@ import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.primitives.Pair; - import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.*; +import static org.junit.jupiter.api.Assertions.*; -import static org.junit.Assert.*; +import org.junit.jupiter.api.extension.ExtendWith; +import static org.junit.jupiter.api.Assertions.assertThrows; @Slf4j +@DisplayName("Multi Layer Test") public class MultiLayerTest extends BaseDL4JTest { private static OpExecutioner.ProfilingMode origMode; - @BeforeClass - public static void beforeClass(){ + @BeforeAll + static void
beforeClass() { origMode = Nd4j.getExecutioner().getProfilingMode(); } - @Before - public void before(){ + @BeforeEach + void before() { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); } - @AfterClass - public static void afterClass(){ + @AfterAll + static void afterClass() { Nd4j.getExecutioner().setProfilingMode(origMode); } @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } @Test - public void testSetParams() { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() - .list().layer(0, - new DenseLayer.Builder().nIn(4).nOut(3) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .build(); - + @DisplayName("Test Set Params") + void testSetParams() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()).build(); MultiLayerNetwork network3 = new MultiLayerNetwork(conf); network3.init(); - INDArray params = network3.params(); INDArray weights = network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); INDArray bias = network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY).dup(); @@ -132,69 +128,42 @@ public class MultiLayerTest extends BaseDL4JTest { } @Test - public void testBatchNorm() { + @DisplayName("Test Batch Norm") + void testBatchNorm() { Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new BatchNormalization.Builder().nOut(2).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).nIn(2).nOut(3).build()) - .build(); - - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(2, new BatchNormalization.Builder().nOut(2).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); network.setListeners(new ScoreIterationListener(1)); - DataSetIterator iter = new IrisDataSetIterator(150, 150); - DataSet next = iter.next(); next.normalizeZeroMeanZeroUnitVariance(); SplitTestAndTrain trainTest = next.splitTestAndTrain(110); network.setLabels(trainTest.getTrain().getLabels()); network.init(); - for( int i=0; i<5; i++ ) { + for (int i = 0; i < 5; i++) { network.fit(trainTest.getTrain()); } - } @Test - public void testBackProp() { + @DisplayName("Test Back Prop") + void testBackProp() { Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - 
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).nIn(2).nOut(3).build()) - .build(); - - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); network.setListeners(new ScoreIterationListener(1)); - DataSetIterator iter = new IrisDataSetIterator(150, 150); - DataSet next = iter.next(); next.normalizeZeroMeanZeroUnitVariance(); SplitTestAndTrain trainTest = next.splitTestAndTrain(110); network.setInput(trainTest.getTrain().getFeatures()); network.setLabels(trainTest.getTrain().getLabels()); network.init(); - for( int i=0; i<5; i++ ) { + for (int i = 0; i < 5; i++) { network.fit(trainTest.getTrain()); } - DataSet test = trainTest.getTest(); Evaluation eval = new Evaluation(); INDArray output = network.output(test.getFeatures()); @@ -202,30 +171,25 @@ public class MultiLayerTest extends BaseDL4JTest { log.info("Score " + eval.stats()); } - - @Test - public void testGradientWithAsList() { + @DisplayName("Test Gradient With As List") + void testGradientWithAsList() { MultiLayerNetwork net1 = new MultiLayerNetwork(getConf()); MultiLayerNetwork net2 = new MultiLayerNetwork(getConf()); net1.init(); net2.init(); - DataSet x1 = new IrisDataSetIterator(1, 150).next(); DataSet all = new IrisDataSetIterator(150, 150).next(); DataSet x2 = all.asList().get(0); - - //x1 and x2 contain identical data + // x1 and x2 contain identical data assertArrayEquals(asFloat(x1.getFeatures()), asFloat(x2.getFeatures()), 0.0f); assertArrayEquals(asFloat(x1.getLabels()), asFloat(x2.getLabels()), 0.0f); assertEquals(x1, x2); - - //Set inputs/outputs so gradient can be calculated: + // Set inputs/outputs so gradient can be calculated: net1.feedForward(x1.getFeatures()); net2.feedForward(x2.getFeatures()); ((BaseOutputLayer) net1.getLayer(1)).setLabels(x1.getLabels()); ((BaseOutputLayer) net2.getLayer(1)).setLabels(x2.getLabels()); - net1.gradient(); net2.gradient(); } @@ -234,7 +198,8 @@ public class MultiLayerTest extends BaseDL4JTest { * This test intended only to test activateSelectedLayers method, it does not involves fully-working AutoEncoder. 
*/ @Test - public void testSelectedActivations() { + @DisplayName("Test Selected Activations") + void testSelectedActivations() { // Train DeepAutoEncoder on very limited trainset final int numRows = 28; final int numColumns = 28; @@ -242,37 +207,18 @@ public class MultiLayerTest extends BaseDL4JTest { int numSamples = 3; int iterations = 1; int listenerFreq = iterations / 5; - log.info("Load data...."); - float[][] trainingData = new float[numSamples][numColumns * numRows]; Arrays.fill(trainingData[0], 0.95f); Arrays.fill(trainingData[1], 0.5f); Arrays.fill(trainingData[2], 0.05f); - - - log.info("Build model...."); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed) - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() - .layer(0, new DenseLayer.Builder().nIn(numRows * numColumns).nOut(1000).build()) - .layer(1, new DenseLayer.Builder().nIn(1000).nOut(500).build()) - .layer(2, new DenseLayer.Builder().nIn(500).nOut(250).build()) - .layer(3, new DenseLayer.Builder().nIn(250).nOut(100).build()) - .layer(4, new DenseLayer.Builder().nIn(100).nOut(30).build()) //encoding stops - .layer(5, new DenseLayer.Builder().nIn(30).nOut(100).build()) //decoding starts - .layer(6, new DenseLayer.Builder().nIn(100).nOut(250).build()) - .layer(7, new DenseLayer.Builder().nIn(250).nOut(500).build()) - .layer(8, new DenseLayer.Builder().nIn(500).nOut(1000).build()) - .layer(9, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000) - .nOut(numRows * numColumns).activation(Activation.SOFTMAX).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list().layer(0, new DenseLayer.Builder().nIn(numRows * numColumns).nOut(1000).build()).layer(1, new DenseLayer.Builder().nIn(1000).nOut(500).build()).layer(2, new DenseLayer.Builder().nIn(500).nOut(250).build()).layer(3, new DenseLayer.Builder().nIn(250).nOut(100).build()).layer(4, // encoding stops + new DenseLayer.Builder().nIn(100).nOut(30).build()).layer(5, // decoding starts + new DenseLayer.Builder().nIn(30).nOut(100).build()).layer(6, new DenseLayer.Builder().nIn(100).nOut(250).build()).layer(7, new DenseLayer.Builder().nIn(250).nOut(500).build()).layer(8, new DenseLayer.Builder().nIn(500).nOut(1000).build()).layer(9, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000).nOut(numRows * numColumns).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - model.addListeners(new ScoreIterationListener(listenerFreq)); - log.info("Train model...."); int cnt = 0; while (cnt < numSamples) { @@ -281,95 +227,47 @@ public class MultiLayerTest extends BaseDL4JTest { cnt++; } // Make two separate selective calls - log.info("Testing full cycle..."); - - List comparableResult = model.feedForward(Nd4j.create(trainingData[0], new long[]{1, trainingData[0].length})); - - INDArray encodeResult = model.activateSelectedLayers(0, 4, Nd4j.create(trainingData[0], new long[]{1, trainingData[0].length})); - + List comparableResult = model.feedForward(Nd4j.create(trainingData[0], new long[] { 1, trainingData[0].length })); + INDArray encodeResult = model.activateSelectedLayers(0, 4, Nd4j.create(trainingData[0], new long[] { 1, trainingData[0].length })); log.info("Compare feedForward results with selectedActivation"); - assertEquals(comparableResult.get(5), encodeResult); - INDArray decodeResults = 
model.activateSelectedLayers(5, 9, encodeResult); - - log.info("Decode results: " + decodeResults.columns() + " " + decodeResults); log.info("Comparable results: " + comparableResult.get(10).columns() + " " + comparableResult.get(10)); - assertEquals(comparableResult.get(10), decodeResults); } private static MultiLayerConfiguration getConf() { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L) - .list().layer(0, - new DenseLayer.Builder().nIn(4).nOut(3) - - .dist(new NormalDistribution(0,1)) - .build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - - .dist(new NormalDistribution(0, 1)).build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).dist(new NormalDistribution(0, 1)).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).dist(new NormalDistribution(0, 1)).build()).build(); return conf; } public static float[] asFloat(INDArray arr) { long len = arr.length(); - float[] f = new float[(int) len]; - for (int i = 0; i < len; i++) - f[i] = arr.getFloat(i); + for (int i = 0; i < len; i++) f[i] = arr.getFloat(i); return f; } @Test - public void testFeedForwardToLayer() { - + @DisplayName("Test Feed Forward To Layer") + void testFeedForwardToLayer() { int nIn = 30; int nOut = 25; - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) - .updater(new Sgd(1e-3)) - .list().layer( - 0, new DenseLayer.Builder().nIn(nIn).nOut(600) - - .dist(new NormalDistribution(0,1e-5)) - .build()) - .layer(1, new DenseLayer.Builder() - .nIn(600).nOut(250) - .dist(new NormalDistribution(0, 1e-5)) - .build()) - .layer(2, new DenseLayer.Builder() - .nIn(250).nOut(100) - .dist(new NormalDistribution(0, 1e-5)) - .build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).nIn(100).nOut(25) - .activation(Activation.SOFTMAX) - .weightInit(new NormalDistribution(0, 1e-5)).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new Sgd(1e-3)).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(600).dist(new NormalDistribution(0, 1e-5)).build()).layer(1, new DenseLayer.Builder().nIn(600).nOut(250).dist(new NormalDistribution(0, 1e-5)).build()).layer(2, new DenseLayer.Builder().nIn(250).nOut(100).dist(new NormalDistribution(0, 1e-5)).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(100).nOut(25).activation(Activation.SOFTMAX).weightInit(new NormalDistribution(0, 1e-5)).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - - INDArray input = Nd4j.rand(5, nIn); - List activations = network.feedForward(input); - assertEquals(5, activations.size()); //4 layers + input - + // 4 layers + input + assertEquals(5, activations.size()); List activationsAll = network.feedForwardToLayer(3, input); assertEquals(activations, activationsAll); - for (int i = 3; i >= 0; i--) { List activationsPartial = network.feedForwardToLayer(i, input); - assertEquals(i + 2, activationsPartial.size()); //i+2: for layer 3: input + activations of {0,1,2,3} -> 5 total = 
3+2 + // i+2: for layer 3: input + activations of {0,1,2,3} -> 5 total = 3+2 + assertEquals(i + 2, activationsPartial.size()); for (int j = 0; j <= i; j++) { INDArray exp = activationsAll.get(j); INDArray act = activationsPartial.get(j); @@ -378,52 +276,36 @@ } } - @Test - public void testBackpropGradient() { - //Testing: MultiLayerNetwork.backpropGradient() - //i.e., specifically without an output layer - + @DisplayName("Test Backprop Gradient") + void testBackpropGradient() { + // Testing: MultiLayerNetwork.backpropGradient() + // i.e., specifically without an output layer int nIn = 10; int nOut = 40; int miniBatch = 5; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .updater(new Sgd(0.1)).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(2, new DenseLayer.Builder().nIn(30).nOut(nOut).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new DenseLayer.Builder().nIn(20).nOut(30).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(2, new DenseLayer.Builder().nIn(30).nOut(nOut).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - Nd4j.getRandom().setSeed(12345); INDArray eps = Nd4j.rand(miniBatch, nOut); INDArray input = Nd4j.rand(miniBatch, nIn); - net.setInput(input); - net.feedForward(true, false); //Need to feed forward before backprop - + // Need to feed forward before backprop + net.feedForward(true, false); Pair<Gradient, INDArray> pair = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); INDArray epsOut = pair.getSecond(); assertNotNull(epsOut); - assertArrayEquals(new long[] {miniBatch, nIn}, epsOut.shape()); - + assertArrayEquals(new long[] { miniBatch, nIn }, epsOut.shape()); Gradient g = pair.getFirst(); Map<String, INDArray> gradMap = g.gradientForVariable(); - assertEquals(6, gradMap.size()); //3 layers, weight + bias gradients for each - - String[] expKeys = {"0_" + DefaultParamInitializer.WEIGHT_KEY, "0_" + DefaultParamInitializer.BIAS_KEY, - "1_" + DefaultParamInitializer.WEIGHT_KEY, "2_" + DefaultParamInitializer.BIAS_KEY, - "2_" + DefaultParamInitializer.WEIGHT_KEY, "2_" + DefaultParamInitializer.BIAS_KEY}; + // 3 layers, weight + bias gradients for each + assertEquals(6, gradMap.size()); + String[] expKeys = { "0_" + DefaultParamInitializer.WEIGHT_KEY, "0_" + DefaultParamInitializer.BIAS_KEY, "1_" + DefaultParamInitializer.WEIGHT_KEY, "1_" + DefaultParamInitializer.BIAS_KEY, "2_" + DefaultParamInitializer.WEIGHT_KEY, "2_" + DefaultParamInitializer.BIAS_KEY }; Set<String> keys = gradMap.keySet(); for (String s : expKeys) { assertTrue(keys.contains(s)); } - /* System.out.println(pair); @@ -442,154 +324,100 @@ } @Test - public void testLayerNames() { + @DisplayName("Test Layer Names") + void testLayerNames() { int nIn = 10; int nOut = 40; - List<String> layerNameList = new ArrayList<>(); layerNameList.add("dnn1"); layerNameList.add("dnn2"); layerNameList.add("dnn3"); - - MultiLayerConfiguration conf = new
NeuralNetConfiguration.Builder() - .updater(new Sgd(0.1)).list() - .layer(0, new DenseLayer.Builder().name("dnn1").nIn(nIn).nOut(20).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new DenseLayer.Builder().name("dnn2").nIn(20).nOut(30).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(2, new DenseLayer.Builder().name("dnn3").nIn(30).nOut(nOut) - .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).list().layer(0, new DenseLayer.Builder().name("dnn1").nIn(nIn).nOut(20).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new DenseLayer.Builder().name("dnn2").nIn(20).nOut(30).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(2, new DenseLayer.Builder().name("dnn3").nIn(30).nOut(nOut).activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(layerNameList.get(0), net.getLayer(0).conf().getLayer().getLayerName()); assertEquals(layerNameList, net.getLayerNames()); BaseLayer b = (BaseLayer) net.getLayer(layerNameList.get(2)).conf().getLayer(); - assertEquals("softmax", b.getActivationFn().toString()); + assertEquals("softmax", b.getActivationFn().toString()); } - @Test - public void testScoreExamples() { + @DisplayName("Test Score Examples") + void testScoreExamples() { Nd4j.getRandom().setSeed(12345); int nIn = 5; int nOut = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) - .build(); - - MultiLayerConfiguration confNoReg = new NeuralNetConfiguration.Builder().seed(12345) - .updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) - .build(); - - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()).layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()).build(); + MultiLayerConfiguration confNoReg = new NeuralNetConfiguration.Builder().seed(12345).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()).layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - MultiLayerNetwork netNoReg = new MultiLayerNetwork(confNoReg); netNoReg.init(); netNoReg.setParameters(net.params().dup()); - - //Score single example, and compare to scoreExamples: + // Score single
example, and compare to scoreExamples: INDArray input = Nd4j.rand(3, nIn); INDArray output = Nd4j.rand(3, nOut); DataSet ds = new DataSet(input, output); - INDArray scoresWithRegularization = net.scoreExamples(ds, true); INDArray scoresNoRegularization = net.scoreExamples(ds, false); - - assertArrayEquals(new long[] {3, 1}, scoresWithRegularization.shape()); - assertArrayEquals(new long[] {3, 1}, scoresNoRegularization.shape()); - + assertArrayEquals(new long[] { 3, 1 }, scoresWithRegularization.shape()); + assertArrayEquals(new long[] { 3, 1 }, scoresNoRegularization.shape()); for (int i = 0; i < 3; i++) { - DataSet singleEx = new DataSet(input.getRow(i,true), output.getRow(i,true)); + DataSet singleEx = new DataSet(input.getRow(i, true), output.getRow(i, true)); double score = net.score(singleEx); double scoreNoReg = netNoReg.score(singleEx); - double scoreUsingScoreExamples = scoresWithRegularization.getDouble(i); double scoreUsingScoreExamplesNoReg = scoresNoRegularization.getDouble(i); assertEquals(score, scoreUsingScoreExamples, 1e-4); assertEquals(scoreNoReg, scoreUsingScoreExamplesNoReg, 1e-4); - assertTrue(scoreUsingScoreExamples > scoreUsingScoreExamplesNoReg); //Regularization term increases score - - // System.out.println(score + "\t" + scoreUsingScoreExamples + "\t|\t" + scoreNoReg + "\t" + scoreUsingScoreExamplesNoReg); + // Regularization term increases score + assertTrue(scoreUsingScoreExamples > scoreUsingScoreExamplesNoReg); + // System.out.println(score + "\t" + scoreUsingScoreExamples + "\t|\t" + scoreNoReg + "\t" + scoreUsingScoreExamplesNoReg); } } @Test - public void testDataSetScore() { - + @DisplayName("Test Data Set Score") + void testDataSetScore() { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .weightInit(WeightInit.XAVIER).seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.SIGMOID).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).seed(12345L).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.SIGMOID).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - INDArray in = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new long[]{1, 4}); - INDArray out = Nd4j.create(new double[] {1, 0, 0}, new long[]{1,3}); - + INDArray in = Nd4j.create(new double[] { 1.0, 2.0, 3.0, 4.0 }, new long[] { 1, 4 }); + INDArray out = Nd4j.create(new double[] { 1, 0, 0 }, new long[] { 1, 3 }); double score = net.score(new DataSet(in, out)); } @Test - public void testDataSetScoreCNN() { - + @DisplayName("Test Data Set Score CNN") + void testDataSetScoreCNN() { int miniBatch = 3; int depth = 2; int width = 3; int height = 3; int nOut = 2; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345L).list().layer(0, new ConvolutionLayer.Builder(2, 2).nOut(1).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(2).build()) - .setInputType(InputType.convolutionalFlat(height, width, depth)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new 
ConvolutionLayer.Builder(2, 2).nOut(1).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(2).build()).setInputType(InputType.convolutionalFlat(height, width, depth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - Nd4j.getRandom().setSeed(12345); Random r = new Random(12345); INDArray input = Nd4j.rand(miniBatch, depth * width * height); INDArray labels = Nd4j.create(miniBatch, nOut); for (int i = 0; i < miniBatch; i++) { - labels.putScalar(new int[] {i, r.nextInt(nOut)}, 1.0); + labels.putScalar(new int[] { i, r.nextInt(nOut) }, 1.0); } - double score = net.score(new DataSet(input, labels)); } @Test - public void testPredict() throws Exception { - + @DisplayName("Test Predict") + void testPredict() throws Exception { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .weightInit(WeightInit.XAVIER).seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(50).nOut(10).build()) - .setInputType(InputType.convolutional(28, 28, 1)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).seed(12345L).list().layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(50).nOut(10).build()).setInputType(InputType.convolutional(28, 28, 1)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSetIterator ds = new MnistDataSetIterator(10, 10); net.fit(ds); - DataSetIterator testDs = new MnistDataSetIterator(1, 1); DataSet testData = testDs.next(); testData.setLabelNames(Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")); @@ -600,138 +428,105 @@ public class MultiLayerTest extends BaseDL4JTest { } @Test - @Ignore - public void testCid() throws Exception { + @Disabled + @DisplayName("Test Cid") + void testCid() throws Exception { System.out.println(EnvironmentUtils.buildCId()); - Environment environment = EnvironmentUtils.buildEnvironment(); environment.setSerialVersionID(EnvironmentUtils.buildCId()); - - Task task = TaskUtils.buildTask(Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new long[]{1,6})); - + Task task = TaskUtils.buildTask(Nd4j.create(new double[] { 1, 2, 3, 4, 5, 6 }, new long[] { 1, 6 })); Heartbeat.getInstance().reportEvent(Event.STANDALONE, environment, task); - Thread.sleep(25000); } @Test - public void testOutput() throws Exception { + @DisplayName("Test Output") + void testOutput() throws Exception { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .weightInit(WeightInit.XAVIER).seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(50).nOut(10).build()) - .setInputType(InputType.convolutional(28, 28, 1)).build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).seed(12345L).list().layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build()).layer(1, new 
OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(50).nOut(10).build()).setInputType(InputType.convolutional(28, 28, 1)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSetIterator fullData = new MnistDataSetIterator(1, 2); net.fit(fullData); - - fullData.reset(); DataSet expectedSet = fullData.next(2); INDArray expectedOut = net.output(expectedSet.getFeatures(), false); - fullData.reset(); - INDArray actualOut = net.output(fullData); - assertEquals(expectedOut, actualOut); } @Test - public void testGradientUpdate() throws Exception { + @DisplayName("Test Gradient Update") + void testGradientUpdate() throws Exception { DataSetIterator iter = new IrisDataSetIterator(1, 1); - Gradient expectedGradient = new DefaultGradient(); expectedGradient.setGradientFor("0_W", Nd4j.ones(4, 5)); expectedGradient.setGradientFor("0_b", Nd4j.ones(1, 5)); expectedGradient.setGradientFor("1_W", Nd4j.ones(5, 3)); expectedGradient.setGradientFor("1_b", Nd4j.ones(1, 3)); - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Sgd(1.0)) - .activation(Activation.RELU).weightInit(WeightInit.XAVIER) - .list().layer(0, new DenseLayer.Builder().name("dnn1").nIn(4).nOut(5).build()) - .layer(1, new OutputLayer.Builder().name("output").nIn(5).nOut(3) - .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER) - .build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(1.0)).activation(Activation.RELU).weightInit(WeightInit.XAVIER).list().layer(0, new DenseLayer.Builder().name("dnn1").nIn(4).nOut(5).build()).layer(1, new OutputLayer.Builder().name("output").nIn(5).nOut(3).activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); net.fit(iter.next()); // TODO validate actual layer gradientView - issue getting var out of BaseLayer w/o adding MLN getter that gets confused with local gradient vars Gradient actualGradient = net.gradient; assertNotEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W")); - net.update(expectedGradient); actualGradient = net.gradient; assertEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W")); - // Update params with set net.setParam("0_W", Nd4j.ones(4, 5)); net.setParam("0_b", Nd4j.ones(1, 5)); net.setParam("1_W", Nd4j.ones(5, 3)); net.setParam("1_b", Nd4j.ones(1, 3)); INDArray actualParams = net.params(); - // Confirm params assertEquals(expectedGradient.gradient(), actualParams); - net.update(expectedGradient); actualParams = net.params(); assertEquals(Nd4j.ones(1, 43).addi(1), actualParams); } - - @Test(expected = DL4JException.class) - public void testCnnInvalidData() { - - int miniBatch = 3; - int depth = 2; - int width = 5; - int height = 5; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).nIn(2) - .nOut(2).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(2).build()) - .setInputType(InputType.convolutional(height, width, depth)) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - INDArray inputWrongDepth = Nd4j.rand(new int[] {miniBatch, 5, height, width}); //Order: examples, channels, height, width - net.feedForward(inputWrongDepth); - + @Test + @DisplayName("Test Cnn Invalid 
Data") + void testCnnInvalidData() { + assertThrows(DL4JException.class, () -> { + int miniBatch = 3; + int depth = 2; + int width = 5; + int height = 5; + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).nIn(2).nOut(2).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(2).build()).setInputType(InputType.convolutional(height, width, depth)).build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + // Order: examples, channels, height, width + INDArray inputWrongDepth = Nd4j.rand(new int[] { miniBatch, 5, height, width }); + net.feedForward(inputWrongDepth); + }); } @Test - public void testApplyingPreTrainConfigAndParams() { + @DisplayName("Test Applying Pre Train Config And Params") + void testApplyingPreTrainConfigAndParams() { int nIn = 10; int nOut = 10; - // Test pretrain true MultiLayerNetwork aePre = getAeModel(true, nIn, nOut); - int actualNP = (int)aePre.numParams(); + int actualNP = (int) aePre.numParams(); assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); INDArray params = aePre.params(); - assertEquals(params.length(), actualNP); // check num params + // check num params + assertEquals(params.length(), actualNP); Map paramTable = aePre.paramTable(); - assertTrue(paramTable.containsKey("0_vb")); // check vb exists for pretrain layer + // check vb exists for pretrain layer + assertTrue(paramTable.containsKey("0_vb")); aePre.setParam("0_vb", Nd4j.ones(10)); params = aePre.getParam("0_vb"); - assertEquals(Nd4j.ones(1,10), params); // check set params for vb - - + // check set params for vb + assertEquals(Nd4j.ones(1, 10), params); // Test pretrain false, expect same for true because its not changed when applying update MultiLayerNetwork aeNoPre = getAeModel(false, nIn, nOut); - actualNP = (int)aeNoPre.numParams(); + actualNP = (int) aeNoPre.numParams(); assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); params = aeNoPre.params(); assertEquals(params.length(), actualNP); @@ -740,41 +535,20 @@ public class MultiLayerTest extends BaseDL4JTest { } public MultiLayerNetwork getAeModel(boolean preTrain, int nIn, int nOut) { - MultiLayerConfiguration vae = new NeuralNetConfiguration.Builder() - .seed(42).updater(new NoOp()) - .weightInit(WeightInit.UNIFORM) - .list(new AutoEncoder.Builder() - .activation(Activation.IDENTITY).nOut(nIn).build(), - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.COSINE_PROXIMITY) - .activation(Activation.IDENTITY).nOut(nOut) - .build()) - .setInputType(InputType.feedForward(nOut)).build(); + MultiLayerConfiguration vae = new NeuralNetConfiguration.Builder().seed(42).updater(new NoOp()).weightInit(WeightInit.UNIFORM).list(new AutoEncoder.Builder().activation(Activation.IDENTITY).nOut(nIn).build(), new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.COSINE_PROXIMITY).activation(Activation.IDENTITY).nOut(nOut).build()).setInputType(InputType.feedForward(nOut)).build(); MultiLayerNetwork network = new MultiLayerNetwork(vae); network.init(); return network; } - @Test - public void testIterationCountAndPersistence() throws IOException { + @DisplayName("Test Iteration Count And Persistence") + void testIterationCountAndPersistence() throws IOException { Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - 
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build(); - - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - DataSetIterator iter = new IrisDataSetIterator(50, 150); - assertEquals(0, network.getLayerWiseConfigurations().getIterationCount()); network.fit(iter); assertEquals(3, network.getLayerWiseConfigurations().getIterationCount()); @@ -784,93 +558,58 @@ public class MultiLayerTest extends BaseDL4JTest { iter.reset(); network.fit(iter.next()); assertEquals(7, network.getLayerWiseConfigurations().getIterationCount()); - ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(network, baos, true); byte[] asBytes = baos.toByteArray(); - ByteArrayInputStream bais = new ByteArrayInputStream(asBytes); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(7, net.getLayerWiseConfigurations().getIterationCount()); } - @Test - public void testBiasL1L2() { - - + @DisplayName("Test Bias L 1 L 2") + void testBiasL1L2() { Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .weightInit(WeightInit.XAVIER).activation(Activation.TANH).seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(10).nOut(10) - .build()) - .build(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .l1Bias(0.1).l2Bias(0.2).weightInit(WeightInit.XAVIER).activation(Activation.TANH) - .seed(123).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(10).nOut(10) - .build()) - .build(); - + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).weightInit(WeightInit.XAVIER).activation(Activation.TANH).seed(123).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(10).nOut(10).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l1Bias(0.1).l2Bias(0.2).weightInit(WeightInit.XAVIER).activation(Activation.TANH).seed(123).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new 
org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(10).nOut(10).build()).build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); - MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - BaseLayer bl0 = (BaseLayer) net2.getLayer(0).conf().getLayer(); assertEquals(0.1, TestUtils.getL1(bl0.getRegularizationBias()), 1e-6); assertEquals(0.2, TestUtils.getL2(bl0.getRegularizationBias()), 1e-6); - INDArray features = Nd4j.rand(10, 10); INDArray labels = Nd4j.rand(10, 10); - net2.setParams(net1.params().dup()); - net1.setInput(features); net1.setLabels(labels); net2.setInput(features); net2.setLabels(labels); - net1.computeGradientAndScore(); net2.computeGradientAndScore(); - double r = net1.calcRegularizationScore(true); assertEquals(0.0, r, 0.0); - r = net2.calcRegularizationScore(true); assertEquals(0.0, r, 0.0); - - double s1 = net1.score(); double s2 = net2.score(); - assertEquals(s1, s2, 1e-6); //Biases initialized to 0 -> should initially have same score - + // Biases initialized to 0 -> should initially have same score + assertEquals(s1, s2, 1e-6); for (int i = 0; i < 10; i++) { net1.fit(features, labels); } - net2.setParams(net1.params().dup()); net1.computeGradientAndScore(); net2.computeGradientAndScore(); - r = net1.calcRegularizationScore(true); assertEquals(0.0, r, 0.0); - r = net2.calcRegularizationScore(true); assertTrue(r > 0.0); - s1 = net1.score(); s2 = net2.score(); - - assertNotEquals(s1, s2, 1e-6); //Scores should differ due to bias l1/l2 - + // Scores should differ due to bias l1/l2 + assertNotEquals(s1, s2, 1e-6); for (int i = 0; i < 2; i++) { assertEquals(0.0, net1.getLayer(i).calcRegularizationScore(true), 0.0); assertTrue(net2.getLayer(i).calcRegularizationScore(true) > 0.0); @@ -881,545 +620,311 @@ public class MultiLayerTest extends BaseDL4JTest { Summary should pick up preprocessors set manually on inputs as well */ @Test - public void testSummary() { + @DisplayName("Test Summary") + void testSummary() { int V_WIDTH = 130; int V_HEIGHT = 130; int V_NFRAMES = 150; - MultiLayerConfiguration confForArchitecture = - new NeuralNetConfiguration.Builder().seed(12345).l2(0.001) //l2 regularization on all layers - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .list() - .layer(0, new ConvolutionLayer.Builder(10, 10).nIn(3) //3 channels: RGB - .nOut(30).stride(4, 4).activation(Activation.RELU).weightInit( - WeightInit.RELU) - .updater(Updater.ADAGRAD).build()) //Output: (130-10+0)/4+1 = 31 -> 31*31*30 - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) - .kernelSize(3, 3).stride(2, 2).build()) //(31-3+0)/2+1 = 15 - .layer(2, new ConvolutionLayer.Builder(3, 3).nIn(30).nOut(10).stride(2, 2) - .activation(Activation.RELU).weightInit(WeightInit.RELU) - .updater(Updater.ADAGRAD).build()) //Output: (15-3+0)/2+1 = 7 -> 7*7*10 = 490 - .layer(3, new DenseLayer.Builder().activation(Activation.RELU).nIn(490).nOut(50) - .weightInit(WeightInit.RELU).updater(Updater.ADAGRAD) - .gradientNormalization( - GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .layer(4, new GravesLSTM.Builder().activation(Activation.SOFTSIGN).nIn(50) - .nOut(50).weightInit(WeightInit.XAVIER).updater(Updater.ADAGRAD) - .gradientNormalization( - GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10) - .build()) - .layer(5, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) 
- .activation(Activation.SOFTMAX).nIn(50).nOut(4) //4 possible shapes: circle, square, arc, line - .updater(Updater.ADAGRAD).weightInit(WeightInit.XAVIER) - .gradientNormalization( - GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .inputPreProcessor(0, new RnnToCnnPreProcessor(V_HEIGHT, V_WIDTH, 3)) - .inputPreProcessor(3, new CnnToFeedForwardPreProcessor(7, 7, 10)) - .inputPreProcessor(4, new FeedForwardToRnnPreProcessor()) - .backpropType(BackpropType.TruncatedBPTT) - .tBPTTForwardLength(V_NFRAMES / 5).tBPTTBackwardLength(V_NFRAMES / 5).build(); + MultiLayerConfiguration confForArchitecture = // l2 regularization on all layers + new NeuralNetConfiguration.Builder().seed(12345).l2(0.001).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list().layer(0, // 3 channels: RGB + new ConvolutionLayer.Builder(10, 10).nIn(3).nOut(30).stride(4, 4).activation(Activation.RELU).weightInit(WeightInit.RELU).updater(Updater.ADAGRAD).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(3, 3).stride(2, 2).build()).layer(2, new ConvolutionLayer.Builder(3, 3).nIn(30).nOut(10).stride(2, 2).activation(Activation.RELU).weightInit(WeightInit.RELU).updater(Updater.ADAGRAD).build()).layer(3, new DenseLayer.Builder().activation(Activation.RELU).nIn(490).nOut(50).weightInit(WeightInit.RELU).updater(Updater.ADAGRAD).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).layer(4, new GravesLSTM.Builder().activation(Activation.SOFTSIGN).nIn(50).nOut(50).weightInit(WeightInit.XAVIER).updater(Updater.ADAGRAD).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).layer(5, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(50).nOut(// 4 possible shapes: circle, square, arc, line + 4).updater(Updater.ADAGRAD).weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).inputPreProcessor(0, new RnnToCnnPreProcessor(V_HEIGHT, V_WIDTH, 3)).inputPreProcessor(3, new CnnToFeedForwardPreProcessor(7, 7, 10)).inputPreProcessor(4, new FeedForwardToRnnPreProcessor()).backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(V_NFRAMES / 5).tBPTTBackwardLength(V_NFRAMES / 5).build(); MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(confForArchitecture); modelExpectedArch.init(); MultiLayerNetwork modelMow = new TransferLearning.Builder(modelExpectedArch).setFeatureExtractor(2).build(); -// System.out.println(modelExpectedArch.summary()); -// System.out.println(modelMow.summary()); -// System.out.println(modelMow.summary(InputType.recurrent(V_HEIGHT*V_WIDTH*3))); + // System.out.println(modelExpectedArch.summary()); + // System.out.println(modelMow.summary()); + // System.out.println(modelMow.summary(InputType.recurrent(V_HEIGHT*V_WIDTH*3))); } - @Test(expected = DL4JException.class) - public void testErrorNoOutputLayer() { - - MultiLayerConfiguration c = new NeuralNetConfiguration.Builder().list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).build(); - - MultiLayerNetwork net = new MultiLayerNetwork(c); - net.init(); - - INDArray f = Nd4j.create(1, 10); - INDArray l = Nd4j.create(1, 10); - - net.setInput(f); - net.setLabels(l); - - net.computeGradientAndScore(); - } - - @Test - public void testSetParamTable() { + @DisplayName("Test Error No 
Output Layer") + void testErrorNoOutputLayer() { + assertThrows(DL4JException.class, () -> { + MultiLayerConfiguration c = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).build(); + MultiLayerNetwork net = new MultiLayerNetwork(c); + net.init(); + INDArray f = Nd4j.create(1, 10); + INDArray l = Nd4j.create(1, 10); + net.setInput(f); + net.setLabels(l); + net.computeGradientAndScore(); + }); + } + @Test + @DisplayName("Test Set Param Table") + void testSetParamTable() { Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new LSTM.Builder().nIn(2).nOut(2).build()) - .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) - .build()) - .build(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(987).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new LSTM.Builder().nIn(2).nOut(2).build()) - .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) - .build()) - .build(); - + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(2, new LSTM.Builder().nIn(2).nOut(2).build()).layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(987).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(2, new LSTM.Builder().nIn(2).nOut(2).build()).layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3).build()).build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); - MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - assertNotEquals(net1.params(), net2.params()); assertNotEquals(net1.paramTable(), net2.paramTable()); - net1.setParamTable(net2.paramTable()); assertEquals(net1.params(), net2.params()); assertEquals(net1.paramTable(), net2.paramTable()); } - @Test - public void testCompareLayerMethods(){ - //Simple test: compare .layer(int, Layer) and .layer(Layer) are identical - - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new 
DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new LSTM.Builder().nIn(2).nOut(2).build()) - .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) - .build()) - .build(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(123).list() - .layer(new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(new LSTM.Builder().nIn(2).nOut(2).build()) - .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) - .build()) - .build(); - + @DisplayName("Test Compare Layer Methods") + void testCompareLayerMethods() { + // Simple test: compare .layer(int, Layer) and .layer(Layer) are identical + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(2, new LSTM.Builder().nIn(2).nOut(2).build()).layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3).build()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(123).list().layer(new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER).activation(Activation.TANH).build()).layer(new LSTM.Builder().nIn(2).nOut(2).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3).build()).build(); assertEquals(conf1, conf2); } - @Test - public void testEpochCounter() throws Exception { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build()) - .build(); - + @DisplayName("Test Epoch Counter") + void testEpochCounter() throws Exception { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0, net.getLayerWiseConfigurations().getEpochCount()); - - DataSetIterator iter = new IrisDataSetIterator(150, 150); - - for( int i=0; i<4; i++ ){ + for (int i = 0; i < 4; i++) { assertEquals(i, net.getLayerWiseConfigurations().getEpochCount()); net.fit(iter); - assertEquals(i+1, net.getLayerWiseConfigurations().getEpochCount()); + assertEquals(i + 1, net.getLayerWiseConfigurations().getEpochCount()); } - assertEquals(4, net.getLayerWiseConfigurations().getEpochCount()); - MultiLayerNetwork restored = TestUtils.testModelSerialization(net); assertEquals(4, restored.getLayerWiseConfigurations().getEpochCount()); } @Test - public void testInputClearance() throws Exception { - //Activations should be cleared - if not, it's possible for out of (workspace) scope arrays to be around + @DisplayName("Test Input Clearance") + 
void testInputClearance() throws Exception { + // Activations should be cleared - if not, it's possible for out of (workspace) scope arrays to be around // which can cause a crash - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .convolutionMode(ConvolutionMode.Same) - .list() - .layer(new ConvolutionLayer.Builder().kernelSize(2,2).stride(1,1).nIn(1).nOut(1).build()) - .layer(new SubsamplingLayer.Builder().kernelSize(2,2).stride(1,1).build()) - .layer(new DenseLayer.Builder().nOut(10).build()) - .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28,28,1)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().convolutionMode(ConvolutionMode.Same).list().layer(new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(1).nOut(1).build()).layer(new SubsamplingLayer.Builder().kernelSize(2, 2).stride(1, 1).build()).layer(new DenseLayer.Builder().nOut(10).build()).layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - INDArray content = Nd4j.create(1,1,28,28); - - //Check output: + INDArray content = Nd4j.create(1, 1, 28, 28); + // Check output: net.output(content); - for(org.deeplearning4j.nn.api.Layer l : net.getLayers()){ + for (org.deeplearning4j.nn.api.Layer l : net.getLayers()) { assertNull(l.input()); } - - //Check feedForward: + // Check feedForward: net.feedForward(content, false); - for(org.deeplearning4j.nn.api.Layer l : net.getLayers()){ + for (org.deeplearning4j.nn.api.Layer l : net.getLayers()) { assertNull(l.input()); } } - @Test - public void testExternalErrors() { - //Simple test: same network, but in one case: one less layer (the OutputLayer), where the epsilons are passed in externally + @DisplayName("Test External Errors") + void testExternalErrors() { + // Simple test: same network, but in one case: one less layer (the OutputLayer), where the epsilons are passed in externally // instead. 
Should get identical results - - for(WorkspaceMode ws : WorkspaceMode.values()) { + for (WorkspaceMode ws : WorkspaceMode.values()) { log.info("Workspace mode: " + ws); - Nd4j.getRandom().setSeed(12345); INDArray inData = Nd4j.rand(3, 10); INDArray outData = Nd4j.rand(3, 10); - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration standard = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .seed(12345).list() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10) - .nOut(10).build()) - .build(); + MultiLayerConfiguration standard = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).seed(12345).list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()).build(); MultiLayerNetwork s = new MultiLayerNetwork(standard); s.init(); - - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration external = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .seed(12345).list() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) - .build(); - + MultiLayerConfiguration external = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).seed(12345).list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).build(); MultiLayerNetwork e = new MultiLayerNetwork(external); e.init(); - s.setInput(inData); s.setLabels(outData); s.computeGradientAndScore(); Gradient sGrad = s.gradient(); - s.setInput(inData); - s.feedForward(true, false); //FF without clearing inputs as we need them later - + // FF without clearing inputs as we need them later + s.feedForward(true, false); e.setInput(inData); - e.feedForward(true, false); //FF without clearing inputs as we need them later - + // FF without clearing inputs as we need them later + e.feedForward(true, false); org.deeplearning4j.nn.layers.OutputLayer ol = (org.deeplearning4j.nn.layers.OutputLayer) s.getLayer(1); Pair olPairStd = ol.backpropGradient(null, LayerWorkspaceMgr.noWorkspaces()); - INDArray olEpsilon = olPairStd.getSecond().detach(); - e.setInput(inData); e.feedForward(true, false); Pair extErrorGrad = e.backpropGradient(olEpsilon, LayerWorkspaceMgr.noWorkspaces()); - int nParamsDense = 10 * 10 + 10; - assertEquals(sGrad.gradient().get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nParamsDense)), - extErrorGrad.getFirst().gradient()); - + assertEquals(sGrad.gradient().get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(0, nParamsDense)), extErrorGrad.getFirst().gradient()); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); } } @Test - public void testExternalErrors2(){ + @DisplayName("Test External Errors 2") + void testExternalErrors2() { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); int nIn = 4; int nOut = 3; - - for(WorkspaceMode ws : WorkspaceMode.values()) { -// System.out.println("***** WORKSPACE: " + ws); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .updater(new Adam(0.01)) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .list() - .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.RELU).build()) - .layer(new 
ActivationLayer.Builder().activation(Activation.IDENTITY).build()) - .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) - .inputPreProcessor(1, new FeedForwardToRnnPreProcessor()) - .build(); - + for (WorkspaceMode ws : WorkspaceMode.values()) { + // System.out.println("***** WORKSPACE: " + ws); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(0.01)).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).list().layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.RELU).build()).layer(new ActivationLayer.Builder().activation(Activation.IDENTITY).build()).inputPreProcessor(0, new RnnToFeedForwardPreProcessor()).inputPreProcessor(1, new FeedForwardToRnnPreProcessor()).build(); MultiLayerNetwork graph = new MultiLayerNetwork(conf); graph.init(); - final int minibatch = 5; final int seqLen = 6; - - INDArray param = Nd4j.create(new double[]{0.54, 0.31, 0.98, -0.30, -0.66, -0.19, -0.29, -0.62, 0.13, -0.32, 0.01, -0.03, 0.00, 0.00, 0.00}).reshape(1, -1); + INDArray param = Nd4j.create(new double[] { 0.54, 0.31, 0.98, -0.30, -0.66, -0.19, -0.29, -0.62, 0.13, -0.32, 0.01, -0.03, 0.00, 0.00, 0.00 }).reshape(1, -1); graph.setParams(param); - - INDArray input = Nd4j.rand(new int[]{minibatch, nIn, seqLen}, 12); + INDArray input = Nd4j.rand(new int[] { minibatch, nIn, seqLen }, 12); INDArray expected = Nd4j.ones(minibatch, nOut, seqLen); - graph.setInput(input); INDArray output = graph.feedForward(false, false).get(2); INDArray error = output.sub(expected); - for (org.deeplearning4j.nn.api.Layer l : graph.getLayers()) { assertNotNull(l.input()); assertFalse(l.input().isAttached()); } - // Compute Gradient - Pair gradient = graph.backpropGradient(error, LayerWorkspaceMgr.noWorkspaces()); + Pair gradient = graph.backpropGradient(error, LayerWorkspaceMgr.noWorkspaces()); graph.getUpdater().update(graph, gradient.getFirst(), 0, 0, minibatch, LayerWorkspaceMgr.noWorkspaces()); - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); } - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); } @Test - public void testLayerSize(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - - .list() - .layer(new ConvolutionLayer.Builder().kernelSize(2,2).nOut(6).build()) - .layer(new SubsamplingLayer.Builder().kernelSize(2,2).build()) - .layer(new DenseLayer.Builder().nOut(30).build()) - .layer(new OutputLayer.Builder().nOut(13).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28,28,3)) - .build(); - + @DisplayName("Test Layer Size") + void testLayerSize() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(6).build()).layer(new SubsamplingLayer.Builder().kernelSize(2, 2).build()).layer(new DenseLayer.Builder().nOut(30).build()).layer(new OutputLayer.Builder().nOut(13).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 3)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(6, net.layerSize(0)); assertEquals(0, net.layerSize(1)); assertEquals(30, net.layerSize(2)); assertEquals(13, net.layerSize(3)); - assertEquals(3, net.layerInputSize(0)); assertEquals(0, net.layerInputSize(1)); - assertEquals(((FeedForwardLayer)net.getLayer(2).conf().getLayer()).getNIn(), net.layerInputSize(2)); + assertEquals(((FeedForwardLayer) net.getLayer(2).conf().getLayer()).getNIn(), net.layerInputSize(2)); assertEquals(30, 
net.layerInputSize(3)); } - @Test - public void testZeroParamNet() throws Exception { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new SubsamplingLayer.Builder().kernelSize(2,2).stride(2,2).build()) - .layer(new LossLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build()) - .setInputType(InputType.convolutionalFlat(28,28,1)) - .build(); - + @DisplayName("Test Zero Param Net") + void testZeroParamNet() throws Exception { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new SubsamplingLayer.Builder().kernelSize(2, 2).stride(2, 2).build()).layer(new LossLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSet ds = new MnistDataSetIterator(16, true, 12345).next(); - INDArray out = net.output(ds.getFeatures()); - INDArray labelTemp = Nd4j.create(out.shape()); ds.setLabels(labelTemp); - net.fit(ds); - MultiLayerNetwork net2 = TestUtils.testModelSerialization(net); INDArray out2 = net2.output(ds.getFeatures()); assertEquals(out, out2); } - @Test - public void testInputActivationGradient(){ + @DisplayName("Test Input Activation Gradient") + void testInputActivationGradient() { Nd4j.setDataType(DataType.DOUBLE); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .seed(12345) - .activation(Activation.TANH) - .list() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(new OutputLayer.Builder().nIn(10).nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).seed(12345).activation(Activation.TANH).list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(new OutputLayer.Builder().nIn(10).nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in = Nd4j.rand(1, 10); INDArray label = Nd4j.rand(1, 10); - - Pair<Gradient, INDArray> p = net.calculateGradients(in, label, null, null); - - //Quick gradient check: + Pair<Gradient, INDArray> p = net.calculateGradients(in, label, null, null); + // Quick gradient check: double eps = 1e-6; double maxRelError = 1e-5; - for( int i=0; i<10; i++ ){ + for (int i = 0; i < 10; i++) { double orig = in.getDouble(i); in.putScalar(i, orig + eps); double scorePlus = net.score(new DataSet(in, label)); in.putScalar(i, orig - eps); double scoreMinus = net.score(new DataSet(in, label)); in.putScalar(i, orig); - double expGrad = (scorePlus - scoreMinus) / (2.0 * eps); double actGrad = p.getSecond().getDouble(i); - double relError = (Math.abs(expGrad - actGrad)) / (Math.abs(expGrad) + Math.abs(actGrad)); - String str = i + " - " + relError + " - exp=" + expGrad + ", act=" + actGrad; - assertTrue(str, relError < maxRelError); + assertTrue(relError < maxRelError, str); } } - @Test - public void testMultiLayerConfigurationActivationTypes(){ - - NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() - .list() - .layer(new LSTM.Builder().nOut(6).build()) - .layer(new LSTM.Builder().nOut(7).build()) - .layer(new GlobalPoolingLayer()) - .layer(new OutputLayer.Builder().nOut(8).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(10)); - + @Test + @DisplayName("Test Multi Layer Configuration Activation"
Types") + void testMultiLayerConfigurationActivationTypes() { + NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder().list().layer(new LSTM.Builder().nOut(6).build()).layer(new LSTM.Builder().nOut(7).build()).layer(new GlobalPoolingLayer()).layer(new OutputLayer.Builder().nOut(8).activation(Activation.SOFTMAX).build()).setInputType(InputType.recurrent(10)); MultiLayerConfiguration conf = builder.build(); - List outBuilder = builder.getLayerActivationTypes(); List outConf = conf.getLayerActivationTypes(InputType.recurrent(10)); - - List exp = Arrays.asList( - InputType.recurrent(6), - InputType.recurrent(7), - InputType.feedForward(7), - InputType.feedForward(8) - ); - - + List exp = Arrays.asList(InputType.recurrent(6), InputType.recurrent(7), InputType.feedForward(7), InputType.feedForward(8)); assertEquals(exp, outBuilder); assertEquals(exp, outConf); } @Test - public void testMultipleEpochsSimple(){ - //Mainly a simple sanity check on the preconditions in the method... + @DisplayName("Test Multiple Epochs Simple") + void testMultipleEpochsSimple() { + // Mainly a simple sanity check on the preconditions in the method... DataSetIterator iter = new IrisDataSetIterator(10, 150); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build()) - .build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.fit(iter, 3); - ComputationGraph g = net.toComputationGraph(); g.fit(iter, 3); } - @Test - public void testPretrainFitMethods(){ - - //The fit methods should *not* do layerwise pretraining: - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - - .list() - .layer(new VariationalAutoencoder.Builder() - .nIn(10).nOut(10).encoderLayerSizes(10).decoderLayerSizes(10).build()) - .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()) - - .build(); - + @DisplayName("Test Pretrain Fit Methods") + void testPretrainFitMethods() { + // The fit methods should *not* do layerwise pretraining: + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new VariationalAutoencoder.Builder().nIn(10).nOut(10).encoderLayerSizes(10).decoderLayerSizes(10).build()).layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - Set> exp = new HashSet<>(); exp.add(MultiLayerNetwork.class); - CheckModelsListener listener = new CheckModelsListener(); net.setListeners(listener); - - INDArray f = Nd4j.create(1,10); - INDArray l = Nd4j.create(1,10); - DataSet ds = new DataSet(f,l); - MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f,l); - + INDArray f = Nd4j.create(1, 10); + INDArray l = Nd4j.create(1, 10); + DataSet ds = new DataSet(f, l); + MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f, l); DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds)); net.fit(iter); assertEquals(exp, listener.getModelClasses()); - net.fit(ds); assertEquals(exp, listener.getModelClasses()); - net.fit(f, l); assertEquals(exp, listener.getModelClasses()); - net.fit(f, l, null, null); assertEquals(exp, listener.getModelClasses()); - net.fit(mds); assertEquals(exp, listener.getModelClasses()); 
- net.fit(new SingletonMultiDataSetIterator(mds)); assertEquals(exp, listener.getModelClasses()); } @Test - public void testINDArrayConfigCloning(){ - //INDArrays in config should be cloned to avoid threading issues - + @DisplayName("Test IND Array Config Cloning") + void testINDArrayConfigCloning() { + // INDArrays in config should be cloned to avoid threading issues int mb = 3; int b = 4; int c = 3; int depth = b * (5 + c); int w = 6; int h = 6; - - INDArray bbPrior = Nd4j.rand(b, 2).muliRowVector(Nd4j.create(new double[]{w, h}).castTo(Nd4j.defaultFloatingPointType())); - - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .l2(0.01) - .list() - .layer(new ConvolutionLayer.Builder().nIn(depth).nOut(depth).kernelSize(1,1).build()) - .layer(new Yolo2OutputLayer.Builder() - .boundingBoxPriors(bbPrior) - .build()) - .build(); - + INDArray bbPrior = Nd4j.rand(b, 2).muliRowVector(Nd4j.create(new double[] { w, h }).castTo(Nd4j.defaultFloatingPointType())); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.01).list().layer(new ConvolutionLayer.Builder().nIn(depth).nOut(depth).kernelSize(1, 1).build()).layer(new Yolo2OutputLayer.Builder().boundingBoxPriors(bbPrior).build()).build(); MultiLayerConfiguration conf2 = conf.clone(); - - INDArray bb1 = ((Yolo2OutputLayer)conf.getConf(1).getLayer()).getBoundingBoxes(); - INDArray bb2 = ((Yolo2OutputLayer)conf2.getConf(1).getLayer()).getBoundingBoxes(); + INDArray bb1 = ((Yolo2OutputLayer) conf.getConf(1).getLayer()).getBoundingBoxes(); + INDArray bb2 = ((Yolo2OutputLayer) conf2.getConf(1).getLayer()).getBoundingBoxes(); assertFalse(bb1 == bb2); - assertEquals(bb1, bb2); } @Data + @DisplayName("Check Models Listener") public static class CheckModelsListener extends BaseTrainingListener { private Set<Class<?>> modelClasses = new HashSet<>(); @@ -1430,97 +935,79 @@ public class MultiLayerTest extends BaseDL4JTest { } } - @Test - public void testMLNUpdaterBlocks(){ - //Check that setting learning rate results in correct rearrangement of updater state within updater blocks - //https://github.com/deeplearning4j/deeplearning4j/issues/6809#issuecomment-463892644 - + @DisplayName("Test MLN Updater Blocks") + void testMLNUpdaterBlocks() { + // Check that setting learning rate results in correct rearrangement of updater state within updater blocks + // https://github.com/deeplearning4j/deeplearning4j/issues/6809#issuecomment-463892644 double lr = 1e-3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345) - .weightInit(WeightInit.XAVIER) - .updater(new Adam(lr)) - .list() - .layer(new DenseLayer.Builder().nIn(5).nOut(3).build()) - .layer(new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(2).nOut(1) - .activation(Activation.SIGMOID).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Adam(lr)).list().layer(new DenseLayer.Builder().nIn(5).nOut(3).build()).layer(new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(2).nOut(1).activation(Activation.SIGMOID).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in = Nd4j.rand(1, 5); - INDArray lbl = Nd4j.rand(1,1); - + INDArray lbl = Nd4j.rand(1, 1); net.fit(new DataSet(in, lbl)); - INDArray viewArray = net.getUpdater().getStateViewArray(); INDArray viewArrayCopy = viewArray.dup(); -
//Initially updater view array is set out like: - //[m0w, m0b, m1w, m1b, m2w, m2b][v0w, v0b, v1w, v1b, v2w, v2b] + // Initially updater view array is set out like: + // [m0w, m0b, m1w, m1b, m2w, m2b][v0w, v0b, v1w, v1b, v2w, v2b] long soFar = 0; - INDArray m0w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+5*3)).assign(0); //m0w - soFar += 5*3; - INDArray m0b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3)).assign(1); //m0b + // m0w + INDArray m0w = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 5 * 3)).assign(0); + soFar += 5 * 3; + // m0b + INDArray m0b = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 3)).assign(1); soFar += 3; - INDArray m1w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3*2)).assign(2); //m1w - soFar += 3*2; - INDArray m1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(3); //m1b + // m1w + INDArray m1w = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 3 * 2)).assign(2); + soFar += 3 * 2; + // m1b + INDArray m1b = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 2)).assign(3); soFar += 2; - INDArray m2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2*1)).assign(4); //m2w - soFar += 2*1; - INDArray m2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(5); //m2b + // m2w + INDArray m2w = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 2 * 1)).assign(4); + soFar += 2 * 1; + // m2b + INDArray m2b = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 1)).assign(5); soFar += 1; - - INDArray v0w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+5*3)).assign(6); //v0w - soFar += 5*3; - INDArray v0b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3)).assign(7); //v0b + // v0w + INDArray v0w = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 5 * 3)).assign(6); + soFar += 5 * 3; + // v0b + INDArray v0b = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 3)).assign(7); soFar += 3; - INDArray v1w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3*2)).assign(8); //v1w - soFar += 3*2; - INDArray v1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(9); //v1b + // v1w + INDArray v1w = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 3 * 2)).assign(8); + soFar += 3 * 2; + // v1b + INDArray v1b = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 2)).assign(9); soFar += 2; - INDArray v2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2*1)).assign(10); //v2w - soFar += 2*1; - INDArray v2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(11); //v2b + // v2w + INDArray v2w = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 2 * 1)).assign(10); + soFar += 2 * 1; + // v2b + INDArray v2b = viewArray.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(soFar, soFar + 1)).assign(11); soFar += 1; - - net.setLearningRate(0, 
0.0); - - //Expect new updater state to look like: - //[m0w, m0b][v0w,v0b], [m1w, m1b, m2w, m2b][v1w, v1b, v2w, v2b] - INDArray exp = Nd4j.concat(1, m0w, m0b, v0w, v0b, - m1w, m1b, m2w, m2b, v1w, v1b, v2w, v2b); - + // Expect new updater state to look like: + // [m0w, m0b][v0w,v0b], [m1w, m1b, m2w, m2b][v1w, v1b, v2w, v2b] + INDArray exp = Nd4j.concat(1, m0w, m0b, v0w, v0b, m1w, m1b, m2w, m2b, v1w, v1b, v2w, v2b); INDArray act = net.getUpdater().getStateViewArray(); -// System.out.println(exp); -// System.out.println(act); - + // System.out.println(exp); + // System.out.println(act); assertEquals(exp, act); - - //And set layer 1 LR: + // And set layer 1 LR: net.setLearningRate(1, 0.2); - exp = Nd4j.concat(1, m0w, m0b, v0w, v0b, - m1w, m1b, v1w, v1b, - m2w, m2b, v2w, v2b); + exp = Nd4j.concat(1, m0w, m0b, v0w, v0b, m1w, m1b, v1w, v1b, m2w, m2b, v2w, v2b); assertEquals(exp, net.getUpdater().getStateViewArray()); - - - //Set all back to original LR and check again: + // Set all back to original LR and check again: net.setLearningRate(1, lr); net.setLearningRate(0, lr); - exp = Nd4j.concat(1, m0w, m0b, m1w, m1b, m2w, m2b, v0w, v0b, v1w, v1b, v2w, v2b); assertEquals(exp, net.getUpdater().getStateViewArray()); - - - //Finally, training sanity check (if things are wrong, we get -ve values in adam V, which causes NaNs) + // Finally, training sanity check (if things are wrong, we get -ve values in adam V, which causes NaNs) net.getUpdater().getStateViewArray().assign(viewArrayCopy); net.setLearningRate(0, 0.0); - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); net.fit(new DataSet(in, lbl)); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java index 0503214d5..4217b3ed1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.transferlearning; import org.deeplearning4j.BaseDL4JTest; @@ -38,7 +37,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -49,62 +48,34 @@ import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.HashMap; import java.util.Map; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - -public class TransferLearningCompGraphTest extends BaseDL4JTest { +@DisplayName("Transfer Learning Comp Graph Test") +class TransferLearningCompGraphTest extends 
BaseDL4JTest { @Test - public void simpleFineTune() { - + @DisplayName("Simple Fine Tune") + void simpleFineTune() { long rng = 12345L; DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - //original conf - ComputationGraphConfiguration confToChange = new NeuralNetConfiguration.Builder().seed(rng) - .optimizationAlgo(OptimizationAlgorithm.LBFGS).updater(new Nesterovs(0.01, 0.99)) - .graphBuilder().addInputs("layer0In").setInputTypes(InputType.feedForward(4)) - .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In") - .addLayer("layer1", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build(), - "layer0") - .setOutputs("layer1").build(); - - //conf with learning parameters changed - ComputationGraphConfiguration expectedConf = new NeuralNetConfiguration.Builder().seed(rng) - .updater(new RmsProp(0.2)) - .graphBuilder().addInputs("layer0In") - .setInputTypes(InputType.feedForward(4)) - .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In") - .addLayer("layer1", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build(), - "layer0") - .setOutputs("layer1").build(); + // original conf + ComputationGraphConfiguration confToChange = new NeuralNetConfiguration.Builder().seed(rng).optimizationAlgo(OptimizationAlgorithm.LBFGS).updater(new Nesterovs(0.01, 0.99)).graphBuilder().addInputs("layer0In").setInputTypes(InputType.feedForward(4)).addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In").addLayer("layer1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0").setOutputs("layer1").build(); + // conf with learning parameters changed + ComputationGraphConfiguration expectedConf = new NeuralNetConfiguration.Builder().seed(rng).updater(new RmsProp(0.2)).graphBuilder().addInputs("layer0In").setInputTypes(InputType.feedForward(4)).addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In").addLayer("layer1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0").setOutputs("layer1").build(); ComputationGraph expectedModel = new ComputationGraph(expectedConf); expectedModel.init(); - ComputationGraph modelToFineTune = new ComputationGraph(expectedConf); modelToFineTune.init(); modelToFineTune.setParams(expectedModel.params()); - //model after applying changes with transfer learning - ComputationGraph modelNow = - new TransferLearning.GraphBuilder(modelToFineTune) - .fineTuneConfiguration(new FineTuneConfiguration.Builder().seed(rng) - .updater(new RmsProp(0.2)).build()) - .build(); - - //Check json + // model after applying changes with transfer learning + ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).fineTuneConfiguration(new FineTuneConfiguration.Builder().seed(rng).updater(new RmsProp(0.2)).build()).build(); + // Check json assertEquals(expectedConf.toJson(), modelNow.getConfiguration().toJson()); - - //Check params after fit + // Check params after fit modelNow.fit(randomData); expectedModel.fit(randomData); assertEquals(modelNow.score(), expectedModel.score(), 1e-8); @@ -112,66 +83,30 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { 
} @Test - public void testNoutChanges() { + @DisplayName("Test Nout Changes") + void testNoutChanges() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 2)); - - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .activation(Activation.IDENTITY); - FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)) - .activation(Activation.IDENTITY).build(); - - ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") - .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "layer0In") - .addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0") - .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1") - .addLayer("layer3", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build(), - "layer2") - .setOutputs("layer3").build()); + NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY); + FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY).build(); + ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2").setOutputs("layer3").build()); modelToFineTune.init(); - ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune) - .fineTuneConfiguration(fineTuneConfiguration).nOutReplace("layer3", 2, WeightInit.XAVIER) - .nOutReplace("layer0", 3, new NormalDistribution(1, 1e-1), WeightInit.XAVIER) - //.setOutputs("layer3") - .build(); - + ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).fineTuneConfiguration(fineTuneConfiguration).nOutReplace("layer3", 2, WeightInit.XAVIER).nOutReplace("layer0", 3, new NormalDistribution(1, 1e-1), WeightInit.XAVIER).build(); BaseLayer bl0 = ((BaseLayer) modelNow.getLayer("layer0").conf().getLayer()); BaseLayer bl1 = ((BaseLayer) modelNow.getLayer("layer1").conf().getLayer()); BaseLayer bl3 = ((BaseLayer) modelNow.getLayer("layer3").conf().getLayer()); assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1))); assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); - - ComputationGraph modelExpectedArch = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") - .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In") - .addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0") - .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1") - .addLayer("layer3", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(2) - .build(), - "layer2") - .setOutputs("layer3").build()); - + ComputationGraph modelExpectedArch = new 
ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(2).build(), "layer2").setOutputs("layer3").build()); modelExpectedArch.init(); - - //modelNow should have the same architecture as modelExpectedArch + // modelNow should have the same architecture as modelExpectedArch assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(), - modelNow.getLayer("layer0").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(), - modelNow.getLayer("layer1").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(), - modelNow.getLayer("layer2").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(), - modelNow.getLayer("layer3").params().shape()); - + assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(), modelNow.getLayer("layer0").params().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(), modelNow.getLayer("layer1").params().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(), modelNow.getLayer("layer2").params().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(), modelNow.getLayer("layer3").params().shape()); modelNow.setParams(modelExpectedArch.params()); - //fit should give the same results + // fit should give the same results modelExpectedArch.fit(randomData); modelNow.fit(randomData); assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8); @@ -179,65 +114,24 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { } @Test - public void testRemoveAndAdd() { + @DisplayName("Test Remove And Add") + void testRemoveAndAdd() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .activation(Activation.IDENTITY); - FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)) - .activation(Activation.IDENTITY).build(); - - ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") - .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "layer0In") - .addLayer("layer1", new DenseLayer.Builder().nIn(5).nOut(2).build(), "layer0") - .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1") - .addLayer("layer3", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build(), - "layer2") - .setOutputs("layer3").build()); + NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY); + FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY).build(); + ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new 
DenseLayer.Builder().nIn(4).nOut(5).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(5).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2").setOutputs("layer3").build()); modelToFineTune.init(); - - ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune) - .fineTuneConfiguration(fineTuneConfiguration) - .nOutReplace("layer0", 7, WeightInit.XAVIER, WeightInit.XAVIER) - .nOutReplace("layer2", 5, WeightInit.XAVIER).removeVertexKeepConnections("layer3") - .addLayer("layer3", - new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(3) - .activation(Activation.SOFTMAX).build(), - "layer2") - //.setOutputs("layer3") - .build(); - - ComputationGraph modelExpectedArch = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") - .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(7).build(), "layer0In") - .addLayer("layer1", new DenseLayer.Builder().nIn(7).nOut(2).build(), "layer0") - .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(5).build(), "layer1") - .addLayer("layer3", - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(5).nOut(3) - .build(), - "layer2") - .setOutputs("layer3").build()); - + ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).fineTuneConfiguration(fineTuneConfiguration).nOutReplace("layer0", 7, WeightInit.XAVIER, WeightInit.XAVIER).nOutReplace("layer2", 5, WeightInit.XAVIER).removeVertexKeepConnections("layer3").addLayer("layer3", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(3).activation(Activation.SOFTMAX).build(), "layer2").build(); + ComputationGraph modelExpectedArch = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(7).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(7).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(5).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(5).nOut(3).build(), "layer2").setOutputs("layer3").build()); modelExpectedArch.init(); - - //modelNow should have the same architecture as modelExpectedArch + // modelNow should have the same architecture as modelExpectedArch assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(), - modelNow.getLayer("layer0").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(), - modelNow.getLayer("layer1").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(), - modelNow.getLayer("layer2").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(), - modelNow.getLayer("layer3").params().shape()); - + assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(), modelNow.getLayer("layer0").params().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(), modelNow.getLayer("layer1").params().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(), 
modelNow.getLayer("layer2").params().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(), modelNow.getLayer("layer3").params().shape()); modelNow.setParams(modelExpectedArch.params()); - //fit should give the same results + // fit should give the same results modelExpectedArch.fit(randomData); modelNow.fit(randomData); assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8); @@ -245,145 +139,20 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { } @Test - public void testAllWithCNN() { - + @DisplayName("Test All With CNN") + void testAllWithCNN() { DataSet randomData = new DataSet(Nd4j.rand(10, 28 * 28 * 3).reshape(10, 3, 28, 28), Nd4j.rand(10, 10)); - ComputationGraph modelToFineTune = - new ComputationGraph( - new NeuralNetConfiguration.Builder().seed(123) - .weightInit(WeightInit.XAVIER) - .updater(new Nesterovs(0.01, 0.9)).graphBuilder() - .addInputs("layer0In") - .setInputTypes(InputType.convolutionalFlat(28, 28, - 3)) - .addLayer("layer0", - new ConvolutionLayer.Builder(5, 5).nIn(3) - .stride(1, 1).nOut(20) - .activation(Activation.IDENTITY) - .build(), - "layer0In") - .addLayer("layer1", - new SubsamplingLayer.Builder( - SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2) - .stride(2, 2) - .build(), - "layer0") - .addLayer("layer2", - new ConvolutionLayer.Builder(5, 5).stride(1, 1) - .nOut(50) - .activation(Activation.IDENTITY) - .build(), - "layer1") - .addLayer("layer3", - new SubsamplingLayer.Builder( - SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2) - .stride(2, 2) - .build(), - "layer2") - .addLayer("layer4", - new DenseLayer.Builder() - .activation(Activation.RELU) - .nOut(500).build(), - "layer3") - .addLayer("layer5", - new DenseLayer.Builder() - .activation(Activation.RELU) - .nOut(250).build(), - "layer4") - .addLayer("layer6", - new OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(100) - .activation(Activation.SOFTMAX) - .build(), - "layer5") - .setOutputs("layer6").build()); + ComputationGraph modelToFineTune = new ComputationGraph(new NeuralNetConfiguration.Builder().seed(123).weightInit(WeightInit.XAVIER).updater(new Nesterovs(0.01, 0.9)).graphBuilder().addInputs("layer0In").setInputTypes(InputType.convolutionalFlat(28, 28, 3)).addLayer("layer0", new ConvolutionLayer.Builder(5, 5).nIn(3).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build(), "layer0In").addLayer("layer1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build(), "layer0").addLayer("layer2", new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build(), "layer1").addLayer("layer3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build(), "layer2").addLayer("layer4", new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build(), "layer3").addLayer("layer5", new DenseLayer.Builder().activation(Activation.RELU).nOut(250).build(), "layer4").addLayer("layer6", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(100).activation(Activation.SOFTMAX).build(), "layer5").setOutputs("layer6").build()); modelToFineTune.init(); - - //this will override the learning configuration set in the model + // this will override the learning configuration set in the model NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().seed(456).updater(new Sgd(0.001)); - FineTuneConfiguration fineTuneConfiguration = new 
FineTuneConfiguration.Builder().seed(456).updater(new Sgd(0.001)) - .build(); - - ComputationGraph modelNow = - new TransferLearning.GraphBuilder(modelToFineTune).fineTuneConfiguration(fineTuneConfiguration) - .setFeatureExtractor("layer1").nOutReplace("layer4", 600, WeightInit.XAVIER) - .removeVertexAndConnections("layer5").removeVertexAndConnections("layer6") - .setInputs("layer0In").setInputTypes(InputType.convolutionalFlat(28, 28, 3)) - .addLayer("layer5", - new DenseLayer.Builder().activation(Activation.RELU).nIn(600) - .nOut(300).build(), - "layer4") - .addLayer("layer6", - new DenseLayer.Builder().activation(Activation.RELU).nIn(300) - .nOut(150).build(), - "layer5") - .addLayer("layer7", - new DenseLayer.Builder().activation(Activation.RELU).nIn(150) - .nOut(50).build(), - "layer6") - .addLayer("layer8", - new OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .activation(Activation.SOFTMAX) - .nIn(50).nOut(10).build(), - "layer7") - .setOutputs("layer8").build(); - - ComputationGraph modelExpectedArch = - new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") - .setInputTypes(InputType.convolutionalFlat(28,28, 3)) - .addLayer("layer0", - new FrozenLayer(new ConvolutionLayer.Builder(5, 5).nIn(3) - .stride(1, 1).nOut(20) - .activation(Activation.IDENTITY).build()), - "layer0In") - .addLayer("layer1", - new FrozenLayer(new SubsamplingLayer.Builder( - SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2).stride(2, 2) - .build()), - "layer0") - .addLayer("layer2", - new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50) - .activation(Activation.IDENTITY).build(), - "layer1") - .addLayer("layer3", - new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2).stride(2, 2).build(), - "layer2") - .addLayer("layer4", - new DenseLayer.Builder().activation(Activation.RELU).nOut(600) - .build(), - "layer3") - .addLayer("layer5", - new DenseLayer.Builder().activation(Activation.RELU).nOut(300) - .build(), - "layer4") - .addLayer("layer6", - new DenseLayer.Builder().activation(Activation.RELU).nOut(150) - .build(), - "layer5") - .addLayer("layer7", - new DenseLayer.Builder().activation(Activation.RELU).nOut(50) - .build(), - "layer6") - .addLayer("layer8", - new OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(10) - .activation(Activation.SOFTMAX) - .build(), - "layer7") - .setOutputs("layer8").build()); + FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().seed(456).updater(new Sgd(0.001)).build(); + ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).fineTuneConfiguration(fineTuneConfiguration).setFeatureExtractor("layer1").nOutReplace("layer4", 600, WeightInit.XAVIER).removeVertexAndConnections("layer5").removeVertexAndConnections("layer6").setInputs("layer0In").setInputTypes(InputType.convolutionalFlat(28, 28, 3)).addLayer("layer5", new DenseLayer.Builder().activation(Activation.RELU).nIn(600).nOut(300).build(), "layer4").addLayer("layer6", new DenseLayer.Builder().activation(Activation.RELU).nIn(300).nOut(150).build(), "layer5").addLayer("layer7", new DenseLayer.Builder().activation(Activation.RELU).nIn(150).nOut(50).build(), "layer6").addLayer("layer8", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nIn(50).nOut(10).build(), "layer7").setOutputs("layer8").build(); + ComputationGraph modelExpectedArch = new 
ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").setInputTypes(InputType.convolutionalFlat(28, 28, 3)).addLayer("layer0", new FrozenLayer(new ConvolutionLayer.Builder(5, 5).nIn(3).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()), "layer0In").addLayer("layer1", new FrozenLayer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()), "layer0").addLayer("layer2", new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build(), "layer1").addLayer("layer3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build(), "layer2").addLayer("layer4", new DenseLayer.Builder().activation(Activation.RELU).nOut(600).build(), "layer3").addLayer("layer5", new DenseLayer.Builder().activation(Activation.RELU).nOut(300).build(), "layer4").addLayer("layer6", new DenseLayer.Builder().activation(Activation.RELU).nOut(150).build(), "layer5").addLayer("layer7", new DenseLayer.Builder().activation(Activation.RELU).nOut(50).build(), "layer6").addLayer("layer8", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10).activation(Activation.SOFTMAX).build(), "layer7").setOutputs("layer8").build()); modelExpectedArch.init(); modelExpectedArch.getVertex("layer0").setLayerAsFrozen(); modelExpectedArch.getVertex("layer1").setLayerAsFrozen(); - assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); - modelNow.setParams(modelExpectedArch.params()); int i = 0; while (i < 5) { @@ -392,277 +161,119 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { i++; } assertEquals(modelExpectedArch.params(), modelNow.params()); - } - @Test - public void testTransferGlobalPool() { - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(0.1)) - .weightInit(WeightInit.XAVIER) - .graphBuilder().addInputs("in") - .addLayer("blstm1",new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10) - .activation(Activation.TANH).build(), - "in") - .addLayer("pool", new GlobalPoolingLayer.Builder().build(), "blstm1") - .addLayer("dense", new DenseLayer.Builder().nIn(10).nOut(10).build(), "pool") - .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.IDENTITY) - .lossFunction(LossFunctions.LossFunction.MSE).build(), "dense") - .setOutputs("out").build(); - + @DisplayName("Test Transfer Global Pool") + void testTransferGlobalPool() { + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(0.1)).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in").addLayer("blstm1", new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).activation(Activation.TANH).build(), "in").addLayer("pool", new GlobalPoolingLayer.Builder().build(), "blstm1").addLayer("dense", new DenseLayer.Builder().nIn(10).nOut(10).build(), "pool").addLayer("out", new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build(), "dense").setOutputs("out").build(); ComputationGraph g = new ComputationGraph(conf); g.init(); - - FineTuneConfiguration fineTuneConfiguration = - new FineTuneConfiguration.Builder().seed(12345).updater(new Sgd(0.01)).build(); - - ComputationGraph graph = new TransferLearning.GraphBuilder(g).fineTuneConfiguration(fineTuneConfiguration) - .removeVertexKeepConnections("out").setFeatureExtractor("dense") - .addLayer("out", new 
OutputLayer.Builder().updater(new Adam(0.1)) - .weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT) - .nIn(10).nOut(5).build(), "dense") - .build(); - - ComputationGraphConfiguration confExpected = new NeuralNetConfiguration.Builder().seed(12345) - .updater(new Sgd(0.01)) - .weightInit(WeightInit.XAVIER) - .graphBuilder().addInputs("in") - .addLayer("blstm1", - new FrozenLayer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10) - .activation(Activation.TANH).build()), - "in") - .addLayer("pool", new FrozenLayer(new GlobalPoolingLayer.Builder().build()), "blstm1") - .addLayer("dense", new FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10).build()), "pool") - .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(5).activation(Activation.SOFTMAX) - .updater(new Adam(0.1)) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "dense") - .setOutputs("out").build(); - + FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().seed(12345).updater(new Sgd(0.01)).build(); + ComputationGraph graph = new TransferLearning.GraphBuilder(g).fineTuneConfiguration(fineTuneConfiguration).removeVertexKeepConnections("out").setFeatureExtractor("dense").addLayer("out", new OutputLayer.Builder().updater(new Adam(0.1)).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10).nOut(5).build(), "dense").build(); + ComputationGraphConfiguration confExpected = new NeuralNetConfiguration.Builder().seed(12345).updater(new Sgd(0.01)).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in").addLayer("blstm1", new FrozenLayer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).activation(Activation.TANH).build()), "in").addLayer("pool", new FrozenLayer(new GlobalPoolingLayer.Builder().build()), "blstm1").addLayer("dense", new FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10).build()), "pool").addLayer("out", new OutputLayer.Builder().nIn(10).nOut(5).activation(Activation.SOFTMAX).updater(new Adam(0.1)).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "dense").setOutputs("out").build(); ComputationGraph modelExpected = new ComputationGraph(confExpected); modelExpected.init(); - - -// assertEquals(confExpected, graph.getConfiguration()); + // assertEquals(confExpected, graph.getConfiguration()); assertEquals(confExpected.toJson(), graph.getConfiguration().toJson()); } - @Test - public void testObjectOverrides(){ - //https://github.com/deeplearning4j/deeplearning4j/issues/4368 - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - .dropOut(0.5) - .weightNoise(new DropConnect(0.5)) - .l2(0.5) - .constrainWeights(new UnitNormConstraint()) - .graphBuilder() - .addInputs("in") - .addLayer("layer", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") - .setOutputs("layer") - .build(); - + @DisplayName("Test Object Overrides") + void testObjectOverrides() { + // https://github.com/deeplearning4j/deeplearning4j/issues/4368 + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(0.5).weightNoise(new DropConnect(0.5)).l2(0.5).constrainWeights(new UnitNormConstraint()).graphBuilder().addInputs("in").addLayer("layer", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").setOutputs("layer").build(); ComputationGraph orig = new ComputationGraph(conf); orig.init(); - - FineTuneConfiguration ftc = new FineTuneConfiguration.Builder() - .dropOut(0) - .weightNoise(null) - .constraints(null) - .l2(0.0) - 
.build(); - - ComputationGraph transfer = new TransferLearning.GraphBuilder(orig) - .fineTuneConfiguration(ftc) - .build(); - + FineTuneConfiguration ftc = new FineTuneConfiguration.Builder().dropOut(0).weightNoise(null).constraints(null).l2(0.0).build(); + ComputationGraph transfer = new TransferLearning.GraphBuilder(orig).fineTuneConfiguration(ftc).build(); DenseLayer l = (DenseLayer) transfer.getLayer(0).conf().getLayer(); - assertNull(l.getIDropout()); assertNull(l.getWeightNoise()); assertNull(l.getConstraints()); assertNull(TestUtils.getL2Reg(l)); } - @Test - public void testTransferLearningSubsequent() { + @DisplayName("Test Transfer Learning Subsequent") + void testTransferLearningSubsequent() { String inputName = "in"; String outputName = "out"; - final String firstConv = "firstConv"; final String secondConv = "secondConv"; - final INDArray input = Nd4j.create(6,6,6,6); - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() - .weightInit(new ConstantDistribution(666)) - .graphBuilder() - .addInputs(inputName) - .setOutputs(outputName) - .setInputTypes(InputType.inferInputTypes(input)) - .addLayer(firstConv, new Convolution2D.Builder(3, 3) - .nOut(10) - .build(), inputName) - .addLayer(secondConv, new Convolution2D.Builder(1, 1) - .nOut(3) - .build(), firstConv) - .addLayer(outputName, new OutputLayer.Builder() - .nOut(2) - .lossFunction(LossFunctions.LossFunction.MSE) - .build(), secondConv) - .build()); + final INDArray input = Nd4j.create(6, 6, 6, 6); + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder().weightInit(new ConstantDistribution(666)).graphBuilder().addInputs(inputName).setOutputs(outputName).setInputTypes(InputType.inferInputTypes(input)).addLayer(firstConv, new Convolution2D.Builder(3, 3).nOut(10).build(), inputName).addLayer(secondConv, new Convolution2D.Builder(1, 1).nOut(3).build(), firstConv).addLayer(outputName, new OutputLayer.Builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE).build(), secondConv).build()); graph.init(); - - final ComputationGraph newGraph = new TransferLearning - .GraphBuilder(graph) - .nOutReplace(firstConv, 7, new ConstantDistribution(333)) - .nOutReplace(secondConv, 3, new ConstantDistribution(111)) - .removeVertexAndConnections(outputName) - .addLayer(outputName, new OutputLayer.Builder() - .nIn(48).nOut(2) - .lossFunction(LossFunctions.LossFunction.MSE) - .build(), new CnnToFeedForwardPreProcessor(4,4,3), secondConv) - .setOutputs(outputName) - .build(); + final ComputationGraph newGraph = new TransferLearning.GraphBuilder(graph).nOutReplace(firstConv, 7, new ConstantDistribution(333)).nOutReplace(secondConv, 3, new ConstantDistribution(111)).removeVertexAndConnections(outputName).addLayer(outputName, new OutputLayer.Builder().nIn(48).nOut(2).lossFunction(LossFunctions.LossFunction.MSE).build(), new CnnToFeedForwardPreProcessor(4, 4, 3), secondConv).setOutputs(outputName).build(); newGraph.init(); - - assertEquals("Incorrect # inputs", 7, newGraph.layerInputSize(secondConv)); - + assertEquals(7, newGraph.layerInputSize(secondConv), "Incorrect # inputs"); newGraph.outputSingle(input); } - - @Test - public void testChangeNOutNIn() { + @DisplayName("Test Change N Out N In") + void testChangeNOutNIn() { final String inputName = "input"; final String changeNoutName = "changeNout"; final String poolName = "pool"; final String afterPoolName = "afterPool"; final String outputName = "output"; - final INDArray input = Nd4j.create(new long[] {1, 2, 4, 4}); - final 
ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() - .graphBuilder() - .addInputs(inputName) - .setOutputs(outputName) - .setInputTypes(InputType.inferInputTypes(input)) - .addLayer(changeNoutName, new Convolution2D.Builder(1, 1) - .nOut(10) - .build(), inputName) - .addLayer(poolName, new SubsamplingLayer.Builder(1,1).build(), changeNoutName) - .addLayer(afterPoolName, new Convolution2D.Builder(1, 1) - .nOut(7) - .build(), poolName) - .addLayer(outputName, new OutputLayer.Builder() - .activation(Activation.SOFTMAX) - .nOut(2) - .build(), afterPoolName) - .build()); + final INDArray input = Nd4j.create(new long[] { 1, 2, 4, 4 }); + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder().graphBuilder().addInputs(inputName).setOutputs(outputName).setInputTypes(InputType.inferInputTypes(input)).addLayer(changeNoutName, new Convolution2D.Builder(1, 1).nOut(10).build(), inputName).addLayer(poolName, new SubsamplingLayer.Builder(1, 1).build(), changeNoutName).addLayer(afterPoolName, new Convolution2D.Builder(1, 1).nOut(7).build(), poolName).addLayer(outputName, new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(2).build(), afterPoolName).build()); graph.init(); - - final ComputationGraph newGraph = new TransferLearning.GraphBuilder(graph) - .nOutReplace(changeNoutName, 5, WeightInit.XAVIER) - .nInReplace(afterPoolName, 5, WeightInit.XAVIER) - .build(); - + final ComputationGraph newGraph = new TransferLearning.GraphBuilder(graph).nOutReplace(changeNoutName, 5, WeightInit.XAVIER).nInReplace(afterPoolName, 5, WeightInit.XAVIER).build(); newGraph.init(); - - assertEquals("Incorrect number of outputs!", 5 , newGraph.layerSize(changeNoutName)); - assertEquals("Incorrect number of inputs!", 5, newGraph.layerInputSize(afterPoolName)); + assertEquals(5, newGraph.layerSize(changeNoutName), "Incorrect number of outputs!"); + assertEquals(5, newGraph.layerInputSize(afterPoolName), "Incorrect number of inputs!"); newGraph.output(input); } - - - @Test - public void testTransferLearningSameDiffLayersGraph(){ - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - - .graphBuilder() - .addInputs("in") - .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in") - .layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0") - .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1") - .setOutputs("out") - .build(); - + @DisplayName("Test Transfer Learning Same Diff Layers Graph") + void testTransferLearningSameDiffLayersGraph() { + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in").layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0").layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1").setOutputs("out").build(); ComputationGraph cg = new ComputationGraph(conf); cg.init(); - INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray out = cg.output(arr)[0]; - - - ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out") - .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build()) - .removeVertexAndConnections("out") - .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1") - .setOutputs("newOut") 
- .build(); - + ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out").fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build()).removeVertexAndConnections("out").addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1").setOutputs("newOut").build(); cg2.output(arr); - - Map<String,INDArray> m = new HashMap<>(cg.paramTable()); + Map<String, INDArray> m = new HashMap<>(cg.paramTable()); m.put("newOut_W", m.remove("out_W")); m.put("newOut_b", m.remove("out_b")); cg2.setParamTable(m); - - Map<String,INDArray> p1 = cg.paramTable(); - Map<String,INDArray> p2 = cg2.paramTable(); - for(String s : p1.keySet()){ + Map<String, INDArray> p1 = cg.paramTable(); + Map<String, INDArray> p2 = cg2.paramTable(); + for (String s : p1.keySet()) { INDArray i1 = p1.get(s); INDArray i2 = p2.get(s.replaceAll("out", "newOut")); - assertEquals(s, i1, i2); + assertEquals(i1, i2, s); } - INDArray out2 = cg2.outputSingle(arr); assertEquals(out, out2); } @Test - public void testTransferLearningSameDiffLayersGraphVertex(){ - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() - - .graphBuilder() - .addInputs("in") - .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in") - .addVertex("l1", new AttentionVertex.Builder().nHeads(1).headSize(5).nInKeys(5).nInQueries(5).nInValues(5).nOut(5).build(), "l0", "l0", "l0") - .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1") - .setOutputs("out") - .build(); - + @DisplayName("Test Transfer Learning Same Diff Layers Graph Vertex") + void testTransferLearningSameDiffLayersGraphVertex() { + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in").addVertex("l1", new AttentionVertex.Builder().nHeads(1).headSize(5).nInKeys(5).nInQueries(5).nInValues(5).nOut(5).build(), "l0", "l0", "l0").layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1").setOutputs("out").build(); ComputationGraph cg = new ComputationGraph(conf); cg.init(); - INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray out = cg.output(arr)[0]; - - - ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out") - .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build()) - .removeVertexAndConnections("out") - .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1") - .setOutputs("newOut") - .build(); - + ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out").fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build()).removeVertexAndConnections("out").addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1").setOutputs("newOut").build(); cg2.output(arr); - - Map<String,INDArray> m = new HashMap<>(cg.paramTable()); + Map<String, INDArray> m = new HashMap<>(cg.paramTable()); m.put("newOut_W", m.remove("out_W")); m.put("newOut_b", m.remove("out_b")); cg2.setParamTable(m); - - Map<String,INDArray> p1 = cg.paramTable(); - Map<String,INDArray> p2 = cg2.paramTable(); - for(String s : p1.keySet()){ + Map<String, INDArray> p1 = cg.paramTable(); + Map<String, INDArray> p2 = cg2.paramTable(); + for (String s : p1.keySet()) { INDArray i1 = p1.get(s); INDArray i2 = p2.get(s.replaceAll("out", "newOut")); - assertEquals(s, i1, i2); + assertEquals(i1, i2, s); } - INDArray out2 = cg2.outputSingle(arr); assertEquals(out, out2); } diff --git 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java index 8305879e5..e38f8ba4d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.transferlearning; import lombok.extern.slf4j.Slf4j; @@ -31,7 +30,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -39,20 +38,19 @@ import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.List; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class TransferLearningHelperTest extends BaseDL4JTest { +@DisplayName("Transfer Learning Helper Test") +class TransferLearningHelperTest extends BaseDL4JTest { @Test - public void tesUnfrozenSubset() { - - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().seed(124) - .activation(Activation.IDENTITY) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)); + @DisplayName("Tes Unfrozen Subset") + void tesUnfrozenSubset() { + NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().seed(124).activation(Activation.IDENTITY).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)); /* (inCentre) (inRight) | | @@ -67,185 +65,80 @@ public class TransferLearningHelperTest extends BaseDL4JTest { (outLeft) (outCentre) (outRight) */ - - ComputationGraphConfiguration conf = overallConf.graphBuilder().addInputs("inCentre", "inRight") - .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre") - .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0") - .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1") - .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2") - .addLayer("outCentre", - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), - "denseCentre3") - .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1") - .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft") - .addLayer("outLeft", - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), - "denseLeft0") - .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2") - .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight") - 
.addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0") - .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight") - .addLayer("outRight", - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), - "denseRight1") - .setOutputs("outLeft", "outCentre", "outRight").build(); - + ComputationGraphConfiguration conf = overallConf.graphBuilder().addInputs("inCentre", "inRight").addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre").addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0").addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1").addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2").addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3").addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1").addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft").addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0").addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2").addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight").addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0").addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight").addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1").setOutputs("outLeft", "outCentre", "outRight").build(); ComputationGraph modelToTune = new ComputationGraph(conf); modelToTune.init(); - TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2"); - ComputationGraph modelSubset = helper.unfrozenGraph(); - - ComputationGraphConfiguration expectedConf = - overallConf.graphBuilder().addInputs("denseCentre1", "denseCentre2", "inRight") //inputs are in sorted order - .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), - "denseCentre2") - .addLayer("outCentre", - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7) - .nOut(4).build(), - "denseCentre3") - .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1") - .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), - "subsetLeft") - .addLayer("outLeft", - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5) - .nOut(6).build(), - "denseLeft0") - .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), - "denseCentre2") - .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), - "inRight") - .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0") - .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), - "mergeRight") - .addLayer("outRight", - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5) - .nOut(5).build(), - "denseRight1") - .setOutputs("outLeft", "outCentre", "outRight").build(); + ComputationGraphConfiguration expectedConf = // inputs are in sorted order + overallConf.graphBuilder().addInputs("denseCentre1", "denseCentre2", "inRight").addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2").addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3").addVertex("subsetLeft", new SubsetVertex(0, 3), 
"denseCentre1").addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft").addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0").addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2").addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight").addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0").addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight").addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1").setOutputs("outLeft", "outCentre", "outRight").build(); ComputationGraph expectedModel = new ComputationGraph(expectedConf); expectedModel.init(); assertEquals(expectedConf.toJson(), modelSubset.getConfiguration().toJson()); } @Test - public void testFitUnFrozen() { - - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.9)).seed(124) - .activation(Activation.IDENTITY) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); - - ComputationGraphConfiguration conf = overallConf.graphBuilder().addInputs("inCentre", "inRight") - .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre") - .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0") - .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1") - .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2") - .addLayer("outCentre", - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), - "denseCentre3") - .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1") - .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft") - .addLayer("outLeft", - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), - "denseLeft0") - .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2") - .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight") - .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0") - .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight") - .addLayer("outRight", - new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), - "denseRight1") - .setOutputs("outLeft", "outCentre", "outRight").build(); - + @DisplayName("Test Fit Un Frozen") + void testFitUnFrozen() { + NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.9)).seed(124).activation(Activation.IDENTITY).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); + ComputationGraphConfiguration conf = overallConf.graphBuilder().addInputs("inCentre", "inRight").addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre").addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0").addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1").addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2").addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3").addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1").addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), 
"subsetLeft").addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0").addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2").addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight").addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0").addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight").addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1").setOutputs("outLeft", "outCentre", "outRight").build(); ComputationGraph modelToTune = new ComputationGraph(conf); modelToTune.init(); - INDArray inRight = Nd4j.rand(10, 2); INDArray inCentre = Nd4j.rand(10, 10); INDArray outLeft = Nd4j.rand(10, 6); INDArray outRight = Nd4j.rand(10, 5); INDArray outCentre = Nd4j.rand(10, 4); - MultiDataSet origData = new MultiDataSet(new INDArray[] {inCentre, inRight}, - new INDArray[] {outLeft, outCentre, outRight}); + MultiDataSet origData = new MultiDataSet(new INDArray[] { inCentre, inRight }, new INDArray[] { outLeft, outCentre, outRight }); ComputationGraph modelIdentical = modelToTune.clone(); modelIdentical.getVertex("denseCentre0").setLayerAsFrozen(); modelIdentical.getVertex("denseCentre1").setLayerAsFrozen(); modelIdentical.getVertex("denseCentre2").setLayerAsFrozen(); - TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2"); MultiDataSet featurizedDataSet = helper.featurize(origData); - assertEquals(modelIdentical.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params()); modelIdentical.fit(origData); helper.fitFeaturized(featurizedDataSet); - assertEquals(modelIdentical.getLayer("denseCentre0").params(), modelToTune.getLayer("denseCentre0").params()); assertEquals(modelIdentical.getLayer("denseCentre1").params(), modelToTune.getLayer("denseCentre1").params()); assertEquals(modelIdentical.getLayer("denseCentre2").params(), modelToTune.getLayer("denseCentre2").params()); assertEquals(modelIdentical.getLayer("denseCentre3").params(), modelToTune.getLayer("denseCentre3").params()); assertEquals(modelIdentical.getLayer("outCentre").params(), modelToTune.getLayer("outCentre").params()); - assertEquals(modelIdentical.getLayer("denseRight").conf().toJson(), - modelToTune.getLayer("denseRight").conf().toJson()); + assertEquals(modelIdentical.getLayer("denseRight").conf().toJson(), modelToTune.getLayer("denseRight").conf().toJson()); assertEquals(modelIdentical.getLayer("denseRight").params(), modelToTune.getLayer("denseRight").params()); - assertEquals(modelIdentical.getLayer("denseRight0").conf().toJson(), - modelToTune.getLayer("denseRight0").conf().toJson()); - //assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params()); + assertEquals(modelIdentical.getLayer("denseRight0").conf().toJson(), modelToTune.getLayer("denseRight0").conf().toJson()); + // assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params()); assertEquals(modelIdentical.getLayer("denseRight1").params(), modelToTune.getLayer("denseRight1").params()); assertEquals(modelIdentical.getLayer("outRight").params(), modelToTune.getLayer("outRight").params()); assertEquals(modelIdentical.getLayer("denseLeft0").params(), modelToTune.getLayer("denseLeft0").params()); assertEquals(modelIdentical.getLayer("outLeft").params(), modelToTune.getLayer("outLeft").params()); - -// 
log.info(modelIdentical.summary()); -// log.info(helper.unfrozenGraph().summary()); + // log.info(modelIdentical.summary()); + // log.info(helper.unfrozenGraph().summary()); modelIdentical.summary(); helper.unfrozenGraph().summary(); } @Test - public void testMLN() { + @DisplayName("Test MLN") + void testMLN() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .activation(Activation.IDENTITY); - - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); - + NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).activation(Activation.IDENTITY); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build()); modelToFineTune.init(); MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).setFeatureExtractor(1).build(); List<INDArray> ff = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false); INDArray asFrozenFeatures = ff.get(2); - TransferLearningHelper helper = new TransferLearningHelper(modelToFineTune, 1); - - INDArray paramsLastTwoLayers = - Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); - MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build(), paramsLastTwoLayers); - + INDArray paramsLastTwoLayers = Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); + MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build(), paramsLastTwoLayers); assertEquals(asFrozenFeatures, helper.featurize(randomData).getFeatures()); assertEquals(randomData.getLabels(), helper.featurize(randomData).getLabels()); - for (int i = 0; i < 5; i++) { notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); helper.fitFeaturized(helper.featurize(randomData)); modelNow.fit(randomData); } - - INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), - notFrozen.params()); + INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), notFrozen.params()); INDArray 
act = modelNow.params(); assertEquals(expected, act); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java index 64478feb4..9417abcdd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.transferlearning; import lombok.extern.slf4j.Slf4j; @@ -43,7 +42,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitRelu; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -52,71 +51,43 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.*; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.shade.jackson.core.JsonProcessingException; - import java.util.Map; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class TransferLearningMLNTest extends BaseDL4JTest { +@DisplayName("Transfer Learning MLN Test") +class TransferLearningMLNTest extends BaseDL4JTest { @Test - public void simpleFineTune() { - + @DisplayName("Simple Fine Tune") + void simpleFineTune() { long rng = 12345L; Nd4j.getRandom().setSeed(rng); DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 4), TestUtils.randomOneHot(DataType.FLOAT, 10, 3)); - //original conf - NeuralNetConfiguration.Builder confToChange = - new NeuralNetConfiguration.Builder().seed(rng).optimizationAlgo(OptimizationAlgorithm.LBFGS) - .updater(new Nesterovs(0.01, 0.99)); - - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(confToChange.list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); + // original conf + NeuralNetConfiguration.Builder confToChange = new NeuralNetConfiguration.Builder().seed(rng).optimizationAlgo(OptimizationAlgorithm.LBFGS).updater(new Nesterovs(0.01, 0.99)); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(confToChange.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build()); modelToFineTune.init(); - - //model after applying changes with transfer learning - MultiLayerNetwork modelNow = - new TransferLearning.Builder(modelToFineTune) - .fineTuneConfiguration(new FineTuneConfiguration.Builder().seed(rng) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new RmsProp(0.5)) //Intent: override both weight and bias LR, unless bias LR is manually set also - .l2(0.4).build()) - 
.build(); - + // model after applying changes with transfer learning + MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(new FineTuneConfiguration.Builder().seed(rng).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(// Intent: override both weight and bias LR, unless bias LR is manually set also + new RmsProp(0.5)).l2(0.4).build()).build(); for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) { BaseLayer bl = ((BaseLayer) l.conf().getLayer()); assertEquals(new RmsProp(0.5), bl.getIUpdater()); } - - - NeuralNetConfiguration.Builder confSet = new NeuralNetConfiguration.Builder().seed(rng) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new RmsProp(0.5)).l2(0.4); - - MultiLayerNetwork expectedModel = new MultiLayerNetwork(confSet.list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); + NeuralNetConfiguration.Builder confSet = new NeuralNetConfiguration.Builder().seed(rng).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new RmsProp(0.5)).l2(0.4); + MultiLayerNetwork expectedModel = new MultiLayerNetwork(confSet.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build()); expectedModel.init(); expectedModel.setParams(modelToFineTune.params().dup()); - assertEquals(expectedModel.params(), modelNow.params()); - - //Check json + // Check json MultiLayerConfiguration expectedConf = expectedModel.getLayerWiseConfigurations(); assertEquals(expectedConf.toJson(), modelNow.getLayerWiseConfigurations().toJson()); - - //Check params after fit + // Check params after fit modelNow.fit(randomData); expectedModel.fit(randomData); - assertEquals(modelNow.score(), expectedModel.score(), 1e-6); INDArray pExp = expectedModel.params(); INDArray pNow = modelNow.params(); @@ -124,115 +95,64 @@ public class TransferLearningMLNTest extends BaseDL4JTest { } @Test - public void testNoutChanges() { + @DisplayName("Test Nout Changes") + void testNoutChanges() { Nd4j.getRandom().setSeed(12345); - DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 4), TestUtils.randomOneHot(DataType.FLOAT,10, 2)); - + DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 4), TestUtils.randomOneHot(DataType.FLOAT, 10, 2)); NeuralNetConfiguration.Builder equivalentConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)); - FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)) - .build(); - - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(equivalentConf.list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(5).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); + FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build(); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(equivalentConf.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(5).build()).layer(1, 
new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build()); modelToFineTune.init(); - MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(overallConf) - .nOutReplace(3, 2, WeightInit.XAVIER, WeightInit.XAVIER) - .nOutReplace(0, 3, WeightInit.XAVIER, new NormalDistribution(1, 1e-1)).build(); - - MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(equivalentConf.list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(2) - .build()) - .build()); + MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(overallConf).nOutReplace(3, 2, WeightInit.XAVIER, WeightInit.XAVIER).nOutReplace(0, 3, WeightInit.XAVIER, new NormalDistribution(1, 1e-1)).build(); + MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(equivalentConf.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(2).build()).build()); modelExpectedArch.init(); - - //Will fail - expected because of dist and weight init changes - //assertEquals(modelExpectedArch.getLayerWiseConfigurations().toJson(), modelNow.getLayerWiseConfigurations().toJson()); - + // Will fail - expected because of dist and weight init changes + // assertEquals(modelExpectedArch.getLayerWiseConfigurations().toJson(), modelNow.getLayerWiseConfigurations().toJson()); BaseLayer bl0 = ((BaseLayer) modelNow.getLayerWiseConfigurations().getConf(0).getLayer()); BaseLayer bl1 = ((BaseLayer) modelNow.getLayerWiseConfigurations().getConf(1).getLayer()); BaseLayer bl3 = ((BaseLayer) modelNow.getLayerWiseConfigurations().getConf(3).getLayer()); assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class); try { - assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()), - JsonMappers.getMapper().writeValueAsString(new WeightInitDistribution(new NormalDistribution(1, 1e-1)))); + assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()), JsonMappers.getMapper().writeValueAsString(new WeightInitDistribution(new NormalDistribution(1, 1e-1)))); } catch (JsonProcessingException e) { throw new RuntimeException(e); } assertEquals(bl3.getWeightInitFn(), new WeightInitXavier()); - - //modelNow should have the same architecture as modelExpectedArch + // modelNow should have the same architecture as modelExpectedArch assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), 
modelNow.getLayer(3).params().shape()); - modelNow.setParams(modelExpectedArch.params()); - //fit should give the same results + // fit should give the same results modelExpectedArch.fit(randomData); modelNow.fit(randomData); assertEquals(modelExpectedArch.score(), modelNow.score(), 0.000001); assertEquals(modelExpectedArch.params(), modelNow.params()); } - @Test - public void testRemoveAndAdd() { + @DisplayName("Test Remove And Add") + void testRemoveAndAdd() { Nd4j.getRandom().setSeed(12345); - DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT,10, 4), TestUtils.randomOneHot(DataType.FLOAT, 10, 3)); - + DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 4), TestUtils.randomOneHot(DataType.FLOAT, 10, 3)); NeuralNetConfiguration.Builder equivalentConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)); FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build(); - - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(//overallConf.list() - equivalentConf.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(5).build()) - .layer(1, new DenseLayer.Builder().nIn(5).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); + MultiLayerNetwork modelToFineTune = new // overallConf.list() + MultiLayerNetwork(equivalentConf.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(5).build()).layer(1, new DenseLayer.Builder().nIn(5).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build()); modelToFineTune.init(); - - MultiLayerNetwork modelNow = - new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(overallConf) - .nOutReplace(0, 7, WeightInit.XAVIER, WeightInit.XAVIER) - .nOutReplace(2, 5, WeightInit.XAVIER).removeOutputLayer() - .addLayer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5) - .nOut(3).updater(new Sgd(0.5)).activation(Activation.SOFTMAX) - .build()) - .build(); - - MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(equivalentConf.list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(7).build()) - .layer(1, new DenseLayer.Builder().nIn(7).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(5).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX) - .updater(new Sgd(0.5)).nIn(5).nOut(3).build()) - .build()); + MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(overallConf).nOutReplace(0, 7, WeightInit.XAVIER, WeightInit.XAVIER).nOutReplace(2, 5, WeightInit.XAVIER).removeOutputLayer().addLayer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(3).updater(new Sgd(0.5)).activation(Activation.SOFTMAX).build()).build(); + MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(equivalentConf.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(7).build()).layer(1, new DenseLayer.Builder().nIn(7).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(5).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).updater(new 
Sgd(0.5)).nIn(5).nOut(3).build()).build()); modelExpectedArch.init(); - - //modelNow should have the same architecture as modelExpectedArch + // modelNow should have the same architecture as modelExpectedArch assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape()); - modelNow.setParams(modelExpectedArch.params()); - //fit should give the same results + // fit should give the same results modelExpectedArch.fit(randomData); modelNow.fit(randomData); double scoreExpected = modelExpectedArch.score(); @@ -242,218 +162,67 @@ public class TransferLearningMLNTest extends BaseDL4JTest { } @Test - public void testRemoveAndProcessing() { - + @DisplayName("Test Remove And Processing") + void testRemoveAndProcessing() { int V_WIDTH = 130; int V_HEIGHT = 130; int V_NFRAMES = 150; - - MultiLayerConfiguration confForArchitecture = - new NeuralNetConfiguration.Builder().seed(12345).l2(0.001) //l2 regularization on all layers - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new AdaGrad(0.4)).list() - .layer(0, new ConvolutionLayer.Builder(10, 10).nIn(3) //3 channels: RGB - .nOut(30).stride(4, 4).activation(Activation.RELU).weightInit( - WeightInit.RELU).build()) //Output: (130-10+0)/4+1 = 31 -> 31*31*30 - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) - .kernelSize(3, 3).stride(2, 2).build()) //(31-3+0)/2+1 = 15 - .layer(2, new ConvolutionLayer.Builder(3, 3).nIn(30).nOut(10).stride(2, 2) - .activation(Activation.RELU).weightInit(WeightInit.RELU) - .build()) //Output: (15-3+0)/2+1 = 7 -> 7*7*10 = 490 - .layer(3, new DenseLayer.Builder().activation(Activation.RELU).nIn(490).nOut(50) - .weightInit(WeightInit.RELU).updater(new AdaGrad(0.5)) - .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .layer(4, new GravesLSTM.Builder().activation(Activation.SOFTSIGN).nIn(50) - .nOut(50).weightInit(WeightInit.XAVIER).updater(new AdaGrad(0.6)) - .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .layer(5, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(50).nOut(4) //4 possible shapes: circle, square, arc, line - .weightInit(WeightInit.XAVIER) - .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .inputPreProcessor(0, new RnnToCnnPreProcessor(V_HEIGHT, V_WIDTH, 3)) - .inputPreProcessor(3, new CnnToFeedForwardPreProcessor(7, 7, 10)) - .inputPreProcessor(4, new FeedForwardToRnnPreProcessor()) - .backpropType(BackpropType.TruncatedBPTT) - .tBPTTForwardLength(V_NFRAMES / 5).tBPTTBackwardLength(V_NFRAMES / 5).build(); + MultiLayerConfiguration confForArchitecture = // l2 regularization on all layers + new NeuralNetConfiguration.Builder().seed(12345).l2(0.001).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new AdaGrad(0.4)).list().layer(0, // 3 channels: RGB + new ConvolutionLayer.Builder(10, 10).nIn(3).nOut(30).stride(4, 
4).activation(Activation.RELU).weightInit(WeightInit.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(3, 3).stride(2, 2).build()).layer(2, new ConvolutionLayer.Builder(3, 3).nIn(30).nOut(10).stride(2, 2).activation(Activation.RELU).weightInit(WeightInit.RELU).build()).layer(3, new DenseLayer.Builder().activation(Activation.RELU).nIn(490).nOut(50).weightInit(WeightInit.RELU).updater(new AdaGrad(0.5)).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).layer(4, new GravesLSTM.Builder().activation(Activation.SOFTSIGN).nIn(50).nOut(50).weightInit(WeightInit.XAVIER).updater(new AdaGrad(0.6)).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).layer(5, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(50).nOut(// 4 possible shapes: circle, square, arc, line + 4).weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).inputPreProcessor(0, new RnnToCnnPreProcessor(V_HEIGHT, V_WIDTH, 3)).inputPreProcessor(3, new CnnToFeedForwardPreProcessor(7, 7, 10)).inputPreProcessor(4, new FeedForwardToRnnPreProcessor()).backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(V_NFRAMES / 5).tBPTTBackwardLength(V_NFRAMES / 5).build(); MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(confForArchitecture); modelExpectedArch.init(); - - MultiLayerNetwork modelToTweak = - new MultiLayerNetwork( - new NeuralNetConfiguration.Builder().seed(12345) - .updater(new RmsProp(0.1)) - .list() - .layer(0, new ConvolutionLayer.Builder(10, 10) //Only keep the first layer the same - .nIn(3) //3 channels: RGB - .nOut(30).stride(4, 4) - .activation(Activation.RELU) - .weightInit(WeightInit.RELU) - .updater(new AdaGrad(0.1)).build()) //Output: (130-10+0)/4+1 = 31 -> 31*31*30 - .layer(1, new SubsamplingLayer.Builder( - SubsamplingLayer.PoolingType.MAX) //change kernel size - .kernelSize(5, 5).stride(2, 2) - .build()) //(31-5+0)/2+1 = 14 - .layer(2, new ConvolutionLayer.Builder(6, 6) //change here - .nIn(30).nOut(10).stride(2, 2) - .activation(Activation.RELU) - .weightInit(WeightInit.RELU).build()) //Output: (14-6+0)/2+1 = 5 -> 5*5*10 = 250 - .layer(3, new DenseLayer.Builder() //change here - .activation(Activation.RELU).nIn(250).nOut(50) - .weightInit(WeightInit.RELU) - .gradientNormalization( - GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10) - .updater(new RmsProp(0.01)).build()) - .layer(4, new GravesLSTM.Builder() //change here - .activation(Activation.SOFTSIGN).nIn(50) - .nOut(25).weightInit(WeightInit.XAVIER) - .build()) - .layer(5, new RnnOutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX) - .nIn(25).nOut(4) - .weightInit(WeightInit.XAVIER) - .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10) - .build()) - .inputPreProcessor(0,new RnnToCnnPreProcessor(V_HEIGHT, V_WIDTH, 3)) - .inputPreProcessor(3,new CnnToFeedForwardPreProcessor(5, 5, 10)) - .inputPreProcessor(4, new FeedForwardToRnnPreProcessor()) - - .backpropType(BackpropType.TruncatedBPTT) - .tBPTTForwardLength(V_NFRAMES / 5) - .tBPTTBackwardLength(V_NFRAMES / 5).build()); + MultiLayerNetwork modelToTweak = new MultiLayerNetwork(new NeuralNetConfiguration.Builder().seed(12345).updater(new 
RmsProp(0.1)).list().layer(0, // Only keep the first layer the same + new ConvolutionLayer.Builder(10, 10).nIn(// 3 channels: RGB + 3).nOut(30).stride(4, 4).activation(Activation.RELU).weightInit(WeightInit.RELU).updater(new AdaGrad(0.1)).build()).layer(1, new SubsamplingLayer.Builder(// change kernel size + SubsamplingLayer.PoolingType.MAX).kernelSize(5, 5).stride(2, 2).build()).layer(2, // change here + new ConvolutionLayer.Builder(6, 6).nIn(30).nOut(10).stride(2, 2).activation(Activation.RELU).weightInit(WeightInit.RELU).build()).layer(3, // change here + new DenseLayer.Builder().activation(Activation.RELU).nIn(250).nOut(50).weightInit(WeightInit.RELU).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).updater(new RmsProp(0.01)).build()).layer(4, // change here + new GravesLSTM.Builder().activation(Activation.SOFTSIGN).nIn(50).nOut(25).weightInit(WeightInit.XAVIER).build()).layer(5, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(25).nOut(4).weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).inputPreProcessor(0, new RnnToCnnPreProcessor(V_HEIGHT, V_WIDTH, 3)).inputPreProcessor(3, new CnnToFeedForwardPreProcessor(5, 5, 10)).inputPreProcessor(4, new FeedForwardToRnnPreProcessor()).backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(V_NFRAMES / 5).tBPTTBackwardLength(V_NFRAMES / 5).build()); modelToTweak.init(); - - MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToTweak) - .fineTuneConfiguration( - new FineTuneConfiguration.Builder().seed(12345).l2(0.001) //l2 regularization on all layers - .updater(new AdaGrad(0.4)) - .weightInit(WeightInit.RELU).build()) - .removeLayersFromOutput(5) - .addLayer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(3, 3) - .stride(2, 2).build()) - .addLayer(new ConvolutionLayer.Builder(3, 3).nIn(30).nOut(10).stride(2, 2) - .activation(Activation.RELU).weightInit(WeightInit.RELU).build()) - .addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(490).nOut(50) - .weightInit(WeightInit.RELU).updater(new AdaGrad(0.5)) - .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .addLayer(new GravesLSTM.Builder().activation(Activation.SOFTSIGN).nIn(50).nOut(50) - .weightInit(WeightInit.XAVIER).updater(new AdaGrad(0.6)) - .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .addLayer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(50).nOut(4) //4 possible shapes: circle, square, arc, line - .weightInit(WeightInit.XAVIER) - .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .setInputPreProcessor(3, new CnnToFeedForwardPreProcessor(7, 7, 10)) - .setInputPreProcessor(4, new FeedForwardToRnnPreProcessor()).build(); - - //modelNow should have the same architecture as modelExpectedArch - assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(0).toJson(), - modelNow.getLayerWiseConfigurations().getConf(0).toJson()); - //some learning related info the subsampling layer will not be overwritten - //assertTrue(modelExpectedArch.getLayerWiseConfigurations().getConf(1).toJson().equals(modelNow.getLayerWiseConfigurations().getConf(1).toJson())); - 
assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(2).toJson(), - modelNow.getLayerWiseConfigurations().getConf(2).toJson()); - assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(3).toJson(), - modelNow.getLayerWiseConfigurations().getConf(3).toJson()); - assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(4).toJson(), - modelNow.getLayerWiseConfigurations().getConf(4).toJson()); - assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(5).toJson(), - modelNow.getLayerWiseConfigurations().getConf(5).toJson()); - + MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToTweak).fineTuneConfiguration(// l2 regularization on all layers + new FineTuneConfiguration.Builder().seed(12345).l2(0.001).updater(new AdaGrad(0.4)).weightInit(WeightInit.RELU).build()).removeLayersFromOutput(5).addLayer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(3, 3).stride(2, 2).build()).addLayer(new ConvolutionLayer.Builder(3, 3).nIn(30).nOut(10).stride(2, 2).activation(Activation.RELU).weightInit(WeightInit.RELU).build()).addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(490).nOut(50).weightInit(WeightInit.RELU).updater(new AdaGrad(0.5)).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).addLayer(new GravesLSTM.Builder().activation(Activation.SOFTSIGN).nIn(50).nOut(50).weightInit(WeightInit.XAVIER).updater(new AdaGrad(0.6)).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).addLayer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(50).nOut(// 4 possible shapes: circle, square, arc, line + 4).weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).build()).setInputPreProcessor(3, new CnnToFeedForwardPreProcessor(7, 7, 10)).setInputPreProcessor(4, new FeedForwardToRnnPreProcessor()).build(); + // modelNow should have the same architecture as modelExpectedArch + assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(0).toJson(), modelNow.getLayerWiseConfigurations().getConf(0).toJson()); + // some learning related info the subsampling layer will not be overwritten + // assertTrue(modelExpectedArch.getLayerWiseConfigurations().getConf(1).toJson().equals(modelNow.getLayerWiseConfigurations().getConf(1).toJson())); + assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(2).toJson(), modelNow.getLayerWiseConfigurations().getConf(2).toJson()); + assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(3).toJson(), modelNow.getLayerWiseConfigurations().getConf(3).toJson()); + assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(4).toJson(), modelNow.getLayerWiseConfigurations().getConf(4).toJson()); + assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(5).toJson(), modelNow.getLayerWiseConfigurations().getConf(5).toJson()); assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); - //subsampling has no params - //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); + // subsampling has no params + // assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); 
assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(4).params().shape(), modelNow.getLayer(4).params().shape()); assertArrayEquals(modelExpectedArch.getLayer(5).params().shape(), modelNow.getLayer(5).params().shape()); - } @Test - public void testAllWithCNN() { + @DisplayName("Test All With CNN") + void testAllWithCNN() { Nd4j.getRandom().setSeed(12345); - - DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 28 * 28 * 3).reshape(10, 3, 28, 28), TestUtils.randomOneHot(DataType.FLOAT,10, 10)); - MultiLayerNetwork modelToFineTune = - new MultiLayerNetwork( - new NeuralNetConfiguration.Builder().seed(123) - .weightInit(WeightInit.XAVIER) - .updater(new Nesterovs(0.01, 0.9)) - .list() - .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(3).stride(1, 1) - .nOut(20).activation(Activation.IDENTITY) - .build()) - .layer(1, new SubsamplingLayer.Builder( - SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2).stride(2, 2) - .build()) - .layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1) - .nOut(50).activation(Activation.IDENTITY) - .build()) - .layer(3, new SubsamplingLayer.Builder( - SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2).stride(2, 2) - .build()) - .layer(4, new DenseLayer.Builder().activation(Activation.RELU) - .nOut(500).build()) - .layer(5, new DenseLayer.Builder().activation(Activation.RELU) - .nOut(250).build()) - .layer(6, new OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(100) - .activation(Activation.SOFTMAX) - .build()) - .setInputType(InputType.convolutionalFlat(28, 28, 3)) - .build()); + DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 28 * 28 * 3).reshape(10, 3, 28, 28), TestUtils.randomOneHot(DataType.FLOAT, 10, 10)); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(new NeuralNetConfiguration.Builder().seed(123).weightInit(WeightInit.XAVIER).updater(new Nesterovs(0.01, 0.9)).list().layer(0, new ConvolutionLayer.Builder(5, 5).nIn(3).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()).layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()).layer(5, new DenseLayer.Builder().activation(Activation.RELU).nOut(250).build()).layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(100).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 3)).build()); modelToFineTune.init(); - INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2); //10x20x12x12 - - NeuralNetConfiguration.Builder equivalentConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.2)) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); - - FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.2)) - .build(); - - MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(overallConf) - .setFeatureExtractor(1).nOutReplace(4, 600, 
WeightInit.XAVIER).removeLayersFromOutput(2) - .addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(600).nOut(300).build()) - .addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(300).nOut(150).build()) - .addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(150).nOut(50).build()) - .addLayer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .activation(Activation.SOFTMAX).nIn(50).nOut(10).build()) - .build(); - - MultiLayerNetwork notFrozen = new MultiLayerNetwork(equivalentConf.list() - .layer(0, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50) - .activation(Activation.IDENTITY).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) - .stride(2, 2).build()) - .layer(2, new DenseLayer.Builder().activation(Activation.RELU).nOut(600).build()) - .layer(3, new DenseLayer.Builder().activation(Activation.RELU).nOut(300).build()) - .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(150).build()) - .layer(5, new DenseLayer.Builder().activation(Activation.RELU).nOut(50).build()) - .layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10) - .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(12, 12, 20)).build()); + // 10x20x12x12 + INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2); + NeuralNetConfiguration.Builder equivalentConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.2)).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); + FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.2)).build(); + MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(overallConf).setFeatureExtractor(1).nOutReplace(4, 600, WeightInit.XAVIER).removeLayersFromOutput(2).addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(600).nOut(300).build()).addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(300).nOut(150).build()).addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(150).nOut(50).build()).addLayer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nIn(50).nOut(10).build()).build(); + MultiLayerNetwork notFrozen = new MultiLayerNetwork(equivalentConf.list().layer(0, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(2, new DenseLayer.Builder().activation(Activation.RELU).nOut(600).build()).layer(3, new DenseLayer.Builder().activation(Activation.RELU).nOut(300).build()).layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(150).build()).layer(5, new DenseLayer.Builder().activation(Activation.RELU).nOut(50).build()).layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(12, 12, 20)).build()); notFrozen.init(); - assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); - //subsampling has no params - //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); + // subsampling has no params + // assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), 
modelNow.getLayer(1).params().shape()); assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape()); modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params()); - //subsampling has no params - //assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape()); + // subsampling has no params + // assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape()); assertArrayEquals(notFrozen.getLayer(2).params().shape(), modelNow.getLayer(4).params().shape()); modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params()); assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape()); @@ -464,129 +233,69 @@ public class TransferLearningMLNTest extends BaseDL4JTest { modelNow.getLayer(7).setParams(notFrozen.getLayer(5).params()); assertArrayEquals(notFrozen.getLayer(6).params().shape(), modelNow.getLayer(8).params().shape()); modelNow.getLayer(8).setParams(notFrozen.getLayer(6).params()); - int i = 0; while (i < 3) { notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); modelNow.fit(randomData); i++; } - INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params()); assertEquals(expectedParams, modelNow.params()); } - @Test - public void testFineTuneOverride() { - //Check that fine-tune overrides are selective - i.e., if I only specify a new LR, only the LR should be modified - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Adam(1e-4)) - .activation(Activation.TANH).weightInit(WeightInit.RELU) - .l1(0.1).l2(0.2).list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(5).build()).layer(1, - new OutputLayer.Builder().nIn(5).nOut(4) - .activation(Activation.HARDSIGMOID).build()) - .build(); - + @DisplayName("Test Fine Tune Override") + void testFineTuneOverride() { + // Check that fine-tune overrides are selective - i.e., if I only specify a new LR, only the LR should be modified + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(1e-4)).activation(Activation.TANH).weightInit(WeightInit.RELU).l1(0.1).l2(0.2).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(5).build()).layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.HARDSIGMOID).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - MultiLayerNetwork net2 = new TransferLearning.Builder(net) - .fineTuneConfiguration(new FineTuneConfiguration.Builder().updater(new Adam(2e-2)) - .backpropType(BackpropType.TruncatedBPTT) //Should be set on MLC - .build()) - .build(); - - - //Check original net isn't modified: + MultiLayerNetwork net2 = new TransferLearning.Builder(net).fineTuneConfiguration(new FineTuneConfiguration.Builder().updater(new Adam(2e-2)).backpropType(// Should be set on MLC + BackpropType.TruncatedBPTT).build()).build(); + // Check original net isn't modified: BaseLayer l0 = (BaseLayer) net.getLayer(0).conf().getLayer(); assertEquals(new Adam(1e-4), l0.getIUpdater()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(0.1, TestUtils.getL1(l0), 1e-6); - BaseLayer l1 = (BaseLayer) net.getLayer(1).conf().getLayer(); assertEquals(new Adam(1e-4), l1.getIUpdater()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); assertEquals(0.2, 
TestUtils.getL2(l1), 1e-6); - assertEquals(BackpropType.Standard, conf.getBackpropType()); - - //Check new net has only the appropriate things modified (i.e., LR) + // Check new net has only the appropriate things modified (i.e., LR) l0 = (BaseLayer) net2.getLayer(0).conf().getLayer(); assertEquals(new Adam(2e-2), l0.getIUpdater()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(0.1, TestUtils.getL1(l0), 1e-6); - l1 = (BaseLayer) net2.getLayer(1).conf().getLayer(); assertEquals(new Adam(2e-2), l1.getIUpdater()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); assertEquals(0.2, TestUtils.getL2(l1), 1e-6); - assertEquals(BackpropType.TruncatedBPTT, net2.getLayerWiseConfigurations().getBackpropType()); } @Test - public void testAllWithCNNNew() { + @DisplayName("Test All With CNN New") + void testAllWithCNNNew() { Nd4j.getRandom().setSeed(12345); - - DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT,10, 28 * 28 * 3).reshape(10, 3, 28, 28), TestUtils.randomOneHot(10, 10)); - MultiLayerNetwork modelToFineTune = - new MultiLayerNetwork( - new NeuralNetConfiguration.Builder().seed(123) - .weightInit(WeightInit.XAVIER) - .updater(new Nesterovs(0.01, 0.9)) - .list() - .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(3).stride(1, 1) - .nOut(20).activation(Activation.IDENTITY).build()) - .layer(1, new SubsamplingLayer.Builder(PoolingType.MAX) - .kernelSize(2, 2).stride(2, 2).build()) - .layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1) - .nOut(50).activation(Activation.IDENTITY).build()) - .layer(3, new SubsamplingLayer.Builder(PoolingType.MAX) - .kernelSize(2, 2).stride(2, 2).build()) - .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) - .layer(5, new DenseLayer.Builder().activation(Activation.RELU).nOut(250).build()) - .layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(100).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 3)) //See note below - .build()); + DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 28 * 28 * 3).reshape(10, 3, 28, 28), TestUtils.randomOneHot(10, 10)); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(new NeuralNetConfiguration.Builder().seed(123).weightInit(WeightInit.XAVIER).updater(new Nesterovs(0.01, 0.9)).list().layer(0, new ConvolutionLayer.Builder(5, 5).nIn(3).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()).layer(1, new SubsamplingLayer.Builder(PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()).layer(3, new SubsamplingLayer.Builder(PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()).layer(5, new DenseLayer.Builder().activation(Activation.RELU).nOut(250).build()).layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(100).activation(Activation.SOFTMAX).build()).setInputType(// See note below + InputType.convolutionalFlat(28, 28, 3)).build()); modelToFineTune.init(); - INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2); //10x20x12x12 - + // 10x20x12x12 + INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, 
randomData.getFeatures(), false).get(2); NeuralNetConfiguration.Builder equivalentConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.2)); FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.2)).build(); - - MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(overallConf) - .setFeatureExtractor(1).removeLayersFromOutput(5) - .addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(12 * 12 * 20).nOut(300) - .build()) - .addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(300).nOut(150).build()) - .addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(150).nOut(50).build()) - .addLayer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .activation(Activation.SOFTMAX).nIn(50).nOut(10).build()) - .setInputPreProcessor(2, new CnnToFeedForwardPreProcessor(12, 12, 20)).build(); - - - MultiLayerNetwork notFrozen = new MultiLayerNetwork(equivalentConf.list() - .layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(12 * 12 * 20).nOut(300) - .build()) - .layer(1, new DenseLayer.Builder().activation(Activation.RELU).nIn(300).nOut(150).build()) - .layer(2, new DenseLayer.Builder().activation(Activation.RELU).nIn(150).nOut(50).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(50) - .nOut(10).activation(Activation.SOFTMAX).build()) - .inputPreProcessor(0, new CnnToFeedForwardPreProcessor(12, 12, 20)) - .build()); + MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(overallConf).setFeatureExtractor(1).removeLayersFromOutput(5).addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(12 * 12 * 20).nOut(300).build()).addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(300).nOut(150).build()).addLayer(new DenseLayer.Builder().activation(Activation.RELU).nIn(150).nOut(50).build()).addLayer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nIn(50).nOut(10).build()).setInputPreProcessor(2, new CnnToFeedForwardPreProcessor(12, 12, 20)).build(); + MultiLayerNetwork notFrozen = new MultiLayerNetwork(equivalentConf.list().layer(0, new DenseLayer.Builder().activation(Activation.RELU).nIn(12 * 12 * 20).nOut(300).build()).layer(1, new DenseLayer.Builder().activation(Activation.RELU).nIn(300).nOut(150).build()).layer(2, new DenseLayer.Builder().activation(Activation.RELU).nIn(150).nOut(50).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(50).nOut(10).activation(Activation.SOFTMAX).build()).inputPreProcessor(0, new CnnToFeedForwardPreProcessor(12, 12, 20)).build()); notFrozen.init(); - assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); - //subsampling has no params - //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); + // subsampling has no params + // assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape()); modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params()); assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape()); @@ -595,154 +304,76 @@ public class TransferLearningMLNTest extends BaseDL4JTest { 
modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params()); assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape()); modelNow.getLayer(5).setParams(notFrozen.getLayer(3).params()); - int i = 0; while (i < 3) { notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels())); modelNow.fit(randomData); i++; } - INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params()); assertEquals(expectedParams, modelNow.params()); } @Test - public void testObjectOverrides(){ - //https://github.com/deeplearning4j/deeplearning4j/issues/4368 - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dropOut(0.5) - .weightNoise(new DropConnect(0.5)) - .l2(0.5) - .constrainWeights(new UnitNormConstraint()) - .list() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) - .build(); - + @DisplayName("Test Object Overrides") + void testObjectOverrides() { + // https://github.com/deeplearning4j/deeplearning4j/issues/4368 + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(0.5).weightNoise(new DropConnect(0.5)).l2(0.5).constrainWeights(new UnitNormConstraint()).list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).build(); MultiLayerNetwork orig = new MultiLayerNetwork(conf); orig.init(); - - FineTuneConfiguration ftc = new FineTuneConfiguration.Builder() - .dropOut(0) - .weightNoise(null) - .constraints(null) - .l2(0.0) - .build(); - - MultiLayerNetwork transfer = new TransferLearning.Builder(orig) - .fineTuneConfiguration(ftc) - .build(); - + FineTuneConfiguration ftc = new FineTuneConfiguration.Builder().dropOut(0).weightNoise(null).constraints(null).l2(0.0).build(); + MultiLayerNetwork transfer = new TransferLearning.Builder(orig).fineTuneConfiguration(ftc).build(); DenseLayer l = (DenseLayer) transfer.getLayer(0).conf().getLayer(); - assertNull(l.getIDropout()); assertNull(l.getWeightNoise()); assertNull(l.getConstraints()); assertNull(TestUtils.getL2Reg(l)); } - @Test - public void testTransferLearningSubsequent() { - final INDArray input = Nd4j.create(6,6,6,6); - final MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder() - .weightInit(new ConstantDistribution(666)) - .list() - .setInputType(InputType.inferInputTypes(input)[0]) - .layer(new Convolution2D.Builder(3, 3).nOut(10).build()) - .layer(new Convolution2D.Builder(1, 1).nOut(3).build()) - .layer(new OutputLayer.Builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE) - .build()).build()); + @DisplayName("Test Transfer Learning Subsequent") + void testTransferLearningSubsequent() { + final INDArray input = Nd4j.create(6, 6, 6, 6); + final MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder().weightInit(new ConstantDistribution(666)).list().setInputType(InputType.inferInputTypes(input)[0]).layer(new Convolution2D.Builder(3, 3).nOut(10).build()).layer(new Convolution2D.Builder(1, 1).nOut(3).build()).layer(new OutputLayer.Builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE).build()).build()); net.init(); - - MultiLayerNetwork newGraph = new TransferLearning - .Builder(net) - .fineTuneConfiguration(new FineTuneConfiguration.Builder().build()) - .nOutReplace(0, 7, new ConstantDistribution(333)) - .nOutReplace(1, 3, new ConstantDistribution(111)) - .removeLayersFromOutput(1) - .addLayer(new OutputLayer.Builder() - .nIn(48).nOut(2) - .lossFunction(LossFunctions.LossFunction.MSE) - .build()) - .setInputPreProcessor(2, new 
CnnToFeedForwardPreProcessor(4,4,3)) - .build(); + MultiLayerNetwork newGraph = new TransferLearning.Builder(net).fineTuneConfiguration(new FineTuneConfiguration.Builder().build()).nOutReplace(0, 7, new ConstantDistribution(333)).nOutReplace(1, 3, new ConstantDistribution(111)).removeLayersFromOutput(1).addLayer(new OutputLayer.Builder().nIn(48).nOut(2).lossFunction(LossFunctions.LossFunction.MSE).build()).setInputPreProcessor(2, new CnnToFeedForwardPreProcessor(4, 4, 3)).build(); newGraph.init(); - - assertEquals("Incorrect # inputs", 7, newGraph.layerInputSize(1)); - + assertEquals(7, newGraph.layerInputSize(1), "Incorrect # inputs"); newGraph.output(input); } @Test - public void testChangeNOutNIn() { - INDArray input = Nd4j.create(new long[] {1, 2, 4, 4}); - MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder() - .list() - .setInputType(InputType.inferInputTypes(input)[0]) - .layer(new Convolution2D.Builder(1, 1).nOut(10).build()) - .layer(new SubsamplingLayer.Builder(1,1).build()) - .layer(new Convolution2D.Builder(1, 1).nOut(7).build()) - .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(2).build()) - .build()); + @DisplayName("Test Change N Out N In") + void testChangeNOutNIn() { + INDArray input = Nd4j.create(new long[] { 1, 2, 4, 4 }); + MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder().list().setInputType(InputType.inferInputTypes(input)[0]).layer(new Convolution2D.Builder(1, 1).nOut(10).build()).layer(new SubsamplingLayer.Builder(1, 1).build()).layer(new Convolution2D.Builder(1, 1).nOut(7).build()).layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(2).build()).build()); net.init(); - - final MultiLayerNetwork newNet = new TransferLearning.Builder(net) - .nOutReplace(0, 5, WeightInit.XAVIER) - .nInReplace(2, 5, WeightInit.XAVIER) - .build(); - + final MultiLayerNetwork newNet = new TransferLearning.Builder(net).nOutReplace(0, 5, WeightInit.XAVIER).nInReplace(2, 5, WeightInit.XAVIER).build(); newNet.init(); - - assertEquals("Incorrect number of outputs!", 5 , newNet.layerSize(0)); - assertEquals("Incorrect number of inputs!", 5, newNet.layerInputSize(2)); + assertEquals(5, newNet.layerSize(0), "Incorrect number of outputs!"); + assertEquals(5, newNet.layerInputSize(2), "Incorrect number of inputs!"); newNet.output(input); } - @Test - public void testTransferLearningSameDiffLayers(){ - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .activation(Activation.TANH) - .updater(new Adam(0.01)) - .weightInit(WeightInit.XAVIER) - .list() - .layer(new LSTM.Builder().nOut(8).build()) - .layer( new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build()) - .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) - .layer(new OutputLayer.Builder().nOut(2).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(4)) - .build(); - + @DisplayName("Test Transfer Learning Same Diff Layers") + void testTransferLearningSameDiffLayers() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new Adam(0.01)).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(8).build()).layer(new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()).layer(new 
OutputLayer.Builder().nOut(2).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(4)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in = Nd4j.rand(DataType.FLOAT, 3, 4, 5); INDArray out = net.output(in); - - MultiLayerNetwork net2 = new TransferLearning.Builder(net) - .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build()) - .removeLayersFromOutput(1) - .addLayer(new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .build(); - + MultiLayerNetwork net2 = new TransferLearning.Builder(net).fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build()).removeLayersFromOutput(1).addLayer(new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).build(); net2.setParam("3_W", net.getParam("3_W")); net2.setParam("3_b", net.getParam("3_b")); - - Map p1 = net.paramTable(); - Map p2 = net2.paramTable(); - for(String s : p1.keySet()){ + Map p1 = net.paramTable(); + Map p2 = net2.paramTable(); + for (String s : p1.keySet()) { INDArray i1 = p1.get(s); INDArray i2 = p2.get(s); - assertEquals(s, i1, i2); + assertEquals(i1, i2,s); } - INDArray out2 = net2.output(in); - assertEquals(out, out2); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/LegacyWeightInitTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/LegacyWeightInitTest.java index 669ca7692..b9e9f3376 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/LegacyWeightInitTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/LegacyWeightInitTest.java @@ -17,50 +17,43 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.weights; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.distribution.*; import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.RandomFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; - import java.io.IOException; import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - - -public class LegacyWeightInitTest extends BaseDL4JTest { +@DisplayName("Legacy Weight Init Test") +class LegacyWeightInitTest extends BaseDL4JTest { private RandomFactory prevFactory; + private final static int SEED = 666; - private final static List distributions = Arrays.asList( - new LogNormalDistribution(12.3, 4.56), - new BinomialDistribution(3, 0.3), - new NormalDistribution(0.666, 0.333), - new UniformDistribution(-1.23, 4.56), - new OrthogonalDistribution(3.45), - new TruncatedNormalDistribution(0.456, 0.123), - new ConstantDistribution(666)); + private final static List distributions = Arrays.asList(new LogNormalDistribution(12.3, 4.56), 
new BinomialDistribution(3, 0.3), new NormalDistribution(0.666, 0.333), new UniformDistribution(-1.23, 4.56), new OrthogonalDistribution(3.45), new TruncatedNormalDistribution(0.456, 0.123), new ConstantDistribution(666)); - @Before - public void setRandomFactory() { + @BeforeEach + void setRandomFactory() { prevFactory = Nd4j.randomFactory; Nd4j.randomFactory = new FixedSeedRandomFactory(prevFactory); } - @After - public void resetRandomFactory() { + @AfterEach + void resetRandomFactory() { Nd4j.randomFactory = prevFactory; } @@ -68,24 +61,22 @@ public class LegacyWeightInitTest extends BaseDL4JTest { * Test that param init is identical to legacy implementation */ @Test - public void initParams() { - final long[] shape = {5, 5}; // To make identity happy + @DisplayName("Init Params") + void initParams() { + // To make identity happy + final long[] shape = { 5, 5 }; final long fanIn = shape[0]; final long fanOut = shape[1]; - final INDArray inLegacy = Nd4j.create(fanIn * fanOut); final INDArray inTest = inLegacy.dup(); for (WeightInit legacyWi : WeightInit.values()) { if (legacyWi != WeightInit.DISTRIBUTION) { Nd4j.getRandom().setSeed(SEED); final INDArray expected = WeightInitUtil.initWeights(fanIn, fanOut, shape, legacyWi, null, inLegacy); - Nd4j.getRandom().setSeed(SEED); - final INDArray actual = legacyWi.getWeightInitFunction() - .init(fanIn, fanOut, shape, WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, inTest); - assertArrayEquals("Incorrect shape for " + legacyWi + "!", shape, actual.shape()); - - assertEquals("Incorrect weight initialization for " + legacyWi + "!", expected, actual); + final INDArray actual = legacyWi.getWeightInitFunction().init(fanIn, fanOut, shape, WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, inTest); + assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + legacyWi + "!"); + assertEquals( expected, actual,"Incorrect weight initialization for " + legacyWi + "!"); } } } @@ -94,34 +85,20 @@ public class LegacyWeightInitTest extends BaseDL4JTest { * Test that param init is identical to legacy implementation */ @Test - public void initParamsFromDistribution() { - final long[] shape = {3, 7}; // To make identity happy + @DisplayName("Init Params From Distribution") + void initParamsFromDistribution() { + // To make identity happy + final long[] shape = { 3, 7 }; final long fanIn = shape[0]; final long fanOut = shape[1]; - final INDArray inLegacy = Nd4j.create(fanIn * fanOut); final INDArray inTest = inLegacy.dup(); - for (Distribution dist : distributions) { - Nd4j.getRandom().setSeed(SEED); - final INDArray expected = WeightInitUtil.initWeights( - fanIn, - fanOut, - shape, - WeightInit.DISTRIBUTION, - Distributions.createDistribution(dist), - inLegacy); - - final INDArray actual = new WeightInitDistribution(dist).init( - fanIn, - fanOut, - shape, - WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, - inTest); - assertArrayEquals("Incorrect shape for " + dist.getClass().getSimpleName() + "!", shape, actual.shape()); - - assertEquals("Incorrect weight initialization for " + dist.getClass().getSimpleName() + "!", expected, actual); + final INDArray expected = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.DISTRIBUTION, Distributions.createDistribution(dist), inLegacy); + final INDArray actual = new WeightInitDistribution(dist).init(fanIn, fanOut, shape, WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, inTest); + assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + dist.getClass().getSimpleName() + "!"); + assertEquals( expected, actual,"Incorrect weight 
initialization for " + dist.getClass().getSimpleName() + "!"); } } @@ -129,30 +106,27 @@ public class LegacyWeightInitTest extends BaseDL4JTest { * Test that weight inits can be serialized and de-serialized in JSON format */ @Test - public void serializeDeserializeJson() throws IOException { - final long[] shape = {5, 5}; // To make identity happy + @DisplayName("Serialize Deserialize Json") + void serializeDeserializeJson() throws IOException { + // To make identity happy + final long[] shape = { 5, 5 }; final long fanIn = shape[0]; final long fanOut = shape[1]; - final ObjectMapper mapper = JsonMappers.getMapper(); final INDArray inBefore = Nd4j.create(fanIn * fanOut); final INDArray inAfter = inBefore.dup(); - // Just use to enum to loop over all strategies for (WeightInit legacyWi : WeightInit.values()) { if (legacyWi != WeightInit.DISTRIBUTION) { Nd4j.getRandom().setSeed(SEED); final IWeightInit before = legacyWi.getWeightInitFunction(); final INDArray expected = before.init(fanIn, fanOut, shape, inBefore.ordering(), inBefore); - final String json = mapper.writeValueAsString(before); final IWeightInit after = mapper.readValue(json, IWeightInit.class); - Nd4j.getRandom().setSeed(SEED); final INDArray actual = after.init(fanIn, fanOut, shape, inAfter.ordering(), inAfter); - - assertArrayEquals("Incorrect shape for " + legacyWi + "!", shape, actual.shape()); - assertEquals("Incorrect weight initialization for " + legacyWi + "!", expected, actual); + assertArrayEquals( shape, actual.shape(),"Incorrect shape for " + legacyWi + "!"); + assertEquals(expected, actual,"Incorrect weight initialization for " + legacyWi + "!"); } } } @@ -161,35 +135,25 @@ public class LegacyWeightInitTest extends BaseDL4JTest { * Test that distribution can be serialized and de-serialized in JSON format */ @Test - public void serializeDeserializeDistributionJson() throws IOException { - final long[] shape = {3, 7}; // To make identity happy + @DisplayName("Serialize Deserialize Distribution Json") + void serializeDeserializeDistributionJson() throws IOException { + // To make identity happy + final long[] shape = { 3, 7 }; final long fanIn = shape[0]; final long fanOut = shape[1]; - final ObjectMapper mapper = JsonMappers.getMapper(); final INDArray inBefore = Nd4j.create(fanIn * fanOut); final INDArray inAfter = inBefore.dup(); - for (Distribution dist : distributions) { - Nd4j.getRandom().setSeed(SEED); final IWeightInit before = new WeightInitDistribution(dist); - final INDArray expected = before.init( - fanIn, - fanOut, - shape, - inBefore.ordering(), - inBefore); - + final INDArray expected = before.init(fanIn, fanOut, shape, inBefore.ordering(), inBefore); final String json = mapper.writeValueAsString(before); final IWeightInit after = mapper.readValue(json, IWeightInit.class); - Nd4j.getRandom().setSeed(SEED); final INDArray actual = after.init(fanIn, fanOut, shape, inAfter.ordering(), inAfter); - - assertArrayEquals("Incorrect shape for " + dist.getClass().getSimpleName() + "!", shape, actual.shape()); - - assertEquals("Incorrect weight initialization for " + dist.getClass().getSimpleName() + "!", expected, actual); + assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + dist.getClass().getSimpleName() + "!"); + assertEquals(expected, actual,"Incorrect weight initialization for " + dist.getClass().getSimpleName() + "!"); } } @@ -197,21 +161,22 @@ public class LegacyWeightInitTest extends BaseDL4JTest { * Test equals and hashcode implementation. Redundant as one can trust Lombok on this?? 
*/ @Test - public void equalsAndHashCode() { - WeightInit lastInit = WeightInit.values()[WeightInit.values().length-1]; + @DisplayName("Equals And Hash Code") + void equalsAndHashCode() { + WeightInit lastInit = WeightInit.values()[WeightInit.values().length - 1]; for (WeightInit legacyWi : WeightInit.values()) { - if(legacyWi != WeightInit.DISTRIBUTION) { - assertEquals("Shall be equal!", legacyWi.getWeightInitFunction(), legacyWi.getWeightInitFunction()); - assertNotEquals("Shall not be equal!", lastInit.getWeightInitFunction(), legacyWi.getWeightInitFunction()); + if (legacyWi != WeightInit.DISTRIBUTION) { + assertEquals(legacyWi.getWeightInitFunction(), legacyWi.getWeightInitFunction(), "Shall be equal!"); + assertNotEquals(lastInit.getWeightInitFunction(), legacyWi.getWeightInitFunction(), "Shall not be equal!"); if (legacyWi != WeightInit.NORMAL && legacyWi != WeightInit.LECUN_NORMAL) { lastInit = legacyWi; } } } Distribution lastDist = distributions.get(distributions.size() - 1); - for(Distribution distribution: distributions) { - assertEquals("Shall be equal!", new WeightInitDistribution(distribution), new WeightInitDistribution(distribution.clone())); - assertNotEquals("Shall not be equal!", new WeightInitDistribution(lastDist), new WeightInitDistribution(distribution)); + for (Distribution distribution : distributions) { + assertEquals(new WeightInitDistribution(distribution), new WeightInitDistribution(distribution.clone()), "Shall be equal!"); + assertNotEquals(new WeightInitDistribution(lastDist), new WeightInitDistribution(distribution), "Shall not be equal!"); lastDist = distribution; } } @@ -219,9 +184,10 @@ public class LegacyWeightInitTest extends BaseDL4JTest { /** * Assumes RandomFactory will only call no-args constructor while this test runs */ + @DisplayName("Fixed Seed Random Factory") private static class FixedSeedRandomFactory extends RandomFactory { - private final RandomFactory factory; + private final RandomFactory factory; private FixedSeedRandomFactory(RandomFactory factory) { super(factory.getRandom().getClass()); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java index be0a1c471..8dd8b9c2f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.weights; import org.deeplearning4j.BaseDL4JTest; @@ -27,98 +26,63 @@ import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class 
WeightInitIdentityTest extends BaseDL4JTest { +@DisplayName("Weight Init Identity Test") +class WeightInitIdentityTest extends BaseDL4JTest { /** * Test identity mapping for 1d convolution */ @Test - @Ignore("Ignore for now. Underlying logic changed. Gradient checker passes so implementatin is valid.") - public void testIdConv1D() { - final INDArray input = Nd4j.randn(DataType.FLOAT, 1,5,7); + @Disabled("Ignore for now. Underlying logic changed. Gradient checker passes so implementatin is valid.") + @DisplayName("Test Id Conv 1 D") + void testIdConv1D() { + final INDArray input = Nd4j.randn(DataType.FLOAT, 1, 5, 7); final String inputName = "input"; final String conv = "conv"; final String output = "output"; - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() - .graphBuilder() - .addInputs(inputName) - .setOutputs(output) - .layer(conv, new Convolution1DLayer.Builder(7) - .convolutionMode(ConvolutionMode.Same) - .nOut(input.size(1)) - .weightInit(new WeightInitIdentity()) - .activation(new ActivationIdentity()) - .build(), inputName) - .layer(output, new RnnLossLayer.Builder().activation(new ActivationIdentity()).build(), conv) - .setInputTypes(InputType.recurrent(5,7,RNNFormat.NCW)) - .build()); + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder().graphBuilder().addInputs(inputName).setOutputs(output).layer(conv, new Convolution1DLayer.Builder(7).convolutionMode(ConvolutionMode.Same).nOut(input.size(1)).weightInit(new WeightInitIdentity()).activation(new ActivationIdentity()).build(), inputName).layer(output, new RnnLossLayer.Builder().activation(new ActivationIdentity()).build(), conv).setInputTypes(InputType.recurrent(5, 7, RNNFormat.NCW)).build()); graph.init(); - INDArray reshape = graph.outputSingle(input).reshape(input.shape()); - assertEquals("Mapping was not identity!", input, reshape); + assertEquals(input, reshape, "Mapping was not identity!"); } /** * Test identity mapping for 2d convolution */ @Test - public void testIdConv2D() { - final INDArray input = Nd4j.randn(DataType.FLOAT,1,5,7,11); + @DisplayName("Test Id Conv 2 D") + void testIdConv2D() { + final INDArray input = Nd4j.randn(DataType.FLOAT, 1, 5, 7, 11); final String inputName = "input"; final String conv = "conv"; final String output = "output"; - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() - .graphBuilder() - .setInputTypes(InputType.inferInputType(input)) - .addInputs(inputName) - .setOutputs(output) - .layer(conv, new ConvolutionLayer.Builder(3,5) - .convolutionMode(ConvolutionMode.Same) - .nOut(input.size(1)) - .weightInit(new WeightInitIdentity()) - .activation(new ActivationIdentity()) - .build(), inputName) - .layer(output, new CnnLossLayer.Builder().activation(new ActivationIdentity()).build(), conv) - .build()); + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder().graphBuilder().setInputTypes(InputType.inferInputType(input)).addInputs(inputName).setOutputs(output).layer(conv, new ConvolutionLayer.Builder(3, 5).convolutionMode(ConvolutionMode.Same).nOut(input.size(1)).weightInit(new WeightInitIdentity()).activation(new ActivationIdentity()).build(), inputName).layer(output, new CnnLossLayer.Builder().activation(new ActivationIdentity()).build(), conv).build()); graph.init(); - - assertEquals("Mapping was not identity!", input, graph.outputSingle(input)); + assertEquals(input, graph.outputSingle(input), "Mapping was not identity!"); } /** * Test 
identity mapping for 3d convolution */ @Test - public void testIdConv3D() { - final INDArray input = Nd4j.randn(DataType.FLOAT, 1,5,7,11,13); + @DisplayName("Test Id Conv 3 D") + void testIdConv3D() { + final INDArray input = Nd4j.randn(DataType.FLOAT, 1, 5, 7, 11, 13); final String inputName = "input"; final String conv = "conv"; final String output = "output"; - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() - .graphBuilder() - .setInputTypes(InputType.inferInputType(input)) - .addInputs(inputName) - .setOutputs(output) - .layer(conv, new Convolution3D.Builder(3,7,5) - .convolutionMode(ConvolutionMode.Same) - .dataFormat(Convolution3D.DataFormat.NCDHW) - .nOut(input.size(1)) - .weightInit(new WeightInitIdentity()) - .activation(new ActivationIdentity()) - .build(), inputName) - .layer(output, new Cnn3DLossLayer.Builder(Convolution3D.DataFormat.NCDHW).activation(new ActivationIdentity()).build(), conv) - .build()); + final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder().graphBuilder().setInputTypes(InputType.inferInputType(input)).addInputs(inputName).setOutputs(output).layer(conv, new Convolution3D.Builder(3, 7, 5).convolutionMode(ConvolutionMode.Same).dataFormat(Convolution3D.DataFormat.NCDHW).nOut(input.size(1)).weightInit(new WeightInitIdentity()).activation(new ActivationIdentity()).build(), inputName).layer(output, new Cnn3DLossLayer.Builder(Convolution3D.DataFormat.NCDHW).activation(new ActivationIdentity()).build(), conv).build()); graph.init(); - - assertEquals("Mapping was not identity!", input, graph.outputSingle(input)); + assertEquals(input, graph.outputSingle(input), "Mapping was not identity!"); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitUtilTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitUtilTest.java index ef137bc67..47dbfcbe7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitUtilTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitUtilTest.java @@ -17,136 +17,129 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.weights; import org.apache.commons.math3.util.FastMath; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.distribution.Distributions; import org.deeplearning4j.nn.conf.distribution.GaussianDistribution; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.distribution.Distribution; import org.nd4j.linalg.factory.Nd4j; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; +@DisplayName("Weight Init Util Test") +class WeightInitUtilTest extends BaseDL4JTest { -public class WeightInitUtilTest extends BaseDL4JTest { protected int fanIn = 3; + protected int fanOut = 2; - protected int[] shape = new int[] {fanIn, fanOut}; + + protected int[] shape = new int[] { fanIn, fanOut }; + protected Distribution dist = Distributions.createDistribution(new GaussianDistribution(0.0, 0.1)); - @Before - public void doBefore() { + @BeforeEach + void doBefore() { 
Nd4j.getRandom().setSeed(123); } @Test - public void testDistribution() { + @DisplayName("Test Distribution") + void testDistribution() { INDArray params = Nd4j.create(shape, 'f'); - INDArray weightsActual = WeightInitUtil.initWeights(-1, -1, shape, WeightInit.DISTRIBUTION, dist, params); //fan in/out not used - + // fan in/out not used + INDArray weightsActual = WeightInitUtil.initWeights(-1, -1, shape, WeightInit.DISTRIBUTION, dist, params); // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = dist.sample(params); - assertEquals(weightsExpected, weightsActual); } @Test - public void testRelu() { + @DisplayName("Test Relu") + void testRelu() { INDArray params = Nd4j.create(shape, 'f'); INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.RELU, dist, params); - // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = Nd4j.randn('f', shape).muli(FastMath.sqrt(2.0 / fanIn)); - assertEquals(weightsExpected, weightsActual); } @Test - public void testSigmoidUniform() { + @DisplayName("Test Sigmoid Uniform") + void testSigmoidUniform() { INDArray params = Nd4j.create(shape, 'f'); - INDArray weightsActual = - WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.SIGMOID_UNIFORM, dist, params); - + INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.SIGMOID_UNIFORM, dist, params); // expected calculation Nd4j.getRandom().setSeed(123); double min = -4.0 * Math.sqrt(6.0 / (double) (shape[0] + shape[1])); double max = 4.0 * Math.sqrt(6.0 / (double) (shape[0] + shape[1])); INDArray weightsExpected = Nd4j.getDistributions().createUniform(min, max).sample(Nd4j.createUninitialized(shape, 'f')); - assertEquals(weightsExpected, weightsActual); } @Test - public void testUniform() { + @DisplayName("Test Uniform") + void testUniform() { INDArray params = Nd4j.create(shape, 'f'); INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.UNIFORM, dist, params); - // expected calculation Nd4j.getRandom().setSeed(123); double a = 1.0 / Math.sqrt(fanIn); INDArray weightsExpected = Nd4j.getDistributions().createUniform(-a, a).sample(Nd4j.create(shape, 'f')); - assertEquals(weightsExpected, weightsActual); } @Test - public void testXavier() { + @DisplayName("Test Xavier") + void testXavier() { Nd4j.getRandom().setSeed(123); INDArray params = Nd4j.create(shape, 'f'); INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER, dist, params); - // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = Nd4j.randn('f', shape); weightsExpected.muli(FastMath.sqrt(2.0 / (fanIn + fanOut))); - assertEquals(weightsExpected, weightsActual); } @Test - public void testXavierFanIn() { + @DisplayName("Test Xavier Fan In") + void testXavierFanIn() { INDArray params = Nd4j.create(shape, 'f'); - INDArray weightsActual = - WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER_FAN_IN, dist, params); - + INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER_FAN_IN, dist, params); // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = Nd4j.randn('f', shape); weightsExpected.divi(FastMath.sqrt(fanIn)); - assertEquals(weightsExpected, weightsActual); } @Test - public void testXavierLegacy() { + @DisplayName("Test Xavier Legacy") + void testXavierLegacy() { INDArray params = Nd4j.create(shape, 'f'); - INDArray weightsActual = - 
WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER_LEGACY, dist, params); - + INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.XAVIER_LEGACY, dist, params); // expected calculation Nd4j.getRandom().setSeed(123); INDArray weightsExpected = Nd4j.randn('f', shape); weightsExpected.muli(FastMath.sqrt(1.0 / (fanIn + fanOut))); - assertEquals(weightsExpected, weightsActual); } @Test - public void testZero() { + @DisplayName("Test Zero") + void testZero() { INDArray params = Nd4j.create(shape, 'f'); INDArray weightsActual = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.ZERO, dist, params); - // expected calculation INDArray weightsExpected = Nd4j.create(shape, 'f'); - assertEquals(weightsExpected, weightsActual); } - - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java index fd333f42c..6e2542c07 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.optimize.solver; import lombok.val; @@ -36,8 +35,8 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.solvers.BackTrackLineSearch; import org.deeplearning4j.optimize.stepfunctions.DefaultStepFunction; import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -45,21 +44,24 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.Collections; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Adam Gibson */ -public class BackTrackLineSearchTest extends BaseDL4JTest { +@DisplayName("Back Track Line Search Test") +class BackTrackLineSearchTest extends BaseDL4JTest { + private DataSetIterator irisIter; + private DataSet irisData; - @Before - public void before() { + @BeforeEach + void before() { if (irisIter == null) { irisIter = new IrisDataSetIterator(5, 5); } @@ -69,59 +71,48 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { } } - - @Test - public void testSingleMinLineSearch() throws Exception { - OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD); - int nParams = (int)layer.numParams(); + @DisplayName("Test Single Min Line Search") + void testSingleMinLineSearch() throws Exception { + OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD); 
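// Not part of the original patch: a self-contained JUnit 5 sketch of the migration patterns applied
// throughout these tests (lifecycle annotations renamed, failure message moved to the last argument,
// @DisplayName added). The class name and values below are hypothetical.
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

@DisplayName("Migration Pattern Sketch")
class MigrationPatternSketchTest {

    @BeforeEach // replaces JUnit 4's @Before; likewise @After -> @AfterEach, @Ignore -> @Disabled
    void setUp() {
        // per-test setup would go here
    }

    @Test
    @DisplayName("Message Is Last Argument")
    void messageIsLastArgument() {
        // JUnit 4: assertEquals("message", expected, actual) and assertTrue("message", condition)
        // JUnit 5: the failure message trails the asserted values
        assertEquals(4, 2 + 2, "sum should be 4");
        assertTrue(2 + 2 > 3, "sum should exceed 3");
    }
}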
+ int nParams = (int) layer.numParams(); layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams)); layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setLabels(irisData.getLabels()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, layer.getOptimizer()); double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); - assertEquals(1.0, step, 1e-3); } @Test - public void testSingleMaxLineSearch() throws Exception { + @DisplayName("Test Single Max Line Search") + void testSingleMaxLineSearch() throws Exception { double score1, score2; - - OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD); - int nParams = (int)layer.numParams(); + OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD); + int nParams = (int) layer.numParams(); layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams)); layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setLabels(irisData.getLabels()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); score1 = layer.score(); - - BackTrackLineSearch lineSearch = - new BackTrackLineSearch(layer, new NegativeDefaultStepFunction(), layer.getOptimizer()); + BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, new NegativeDefaultStepFunction(), layer.getOptimizer()); double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); - assertEquals(1.0, step, 1e-3); } - @Test - public void testMultMinLineSearch() throws Exception { + @DisplayName("Test Mult Min Line Search") + void testMultMinLineSearch() throws Exception { double score1, score2; - - OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD); - int nParams = (int)layer.numParams(); + OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD); + int nParams = (int) layer.numParams(); layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams)); layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setLabels(irisData.getLabels()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); score1 = layer.score(); INDArray origGradient = layer.gradient().gradient().dup(); - NegativeDefaultStepFunction sf = new NegativeDefaultStepFunction(); BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer()); double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); @@ -129,71 +120,54 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { sf.step(currParams, origGradient, step); layer.setParams(currParams); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - score2 = layer.score(); - - assertTrue("score1=" + score1 + ", score2=" + score2, score1 > score2); - + assertTrue(score1 > score2,"score1=" + score1 + ", score2=" + score2); } @Test - public void testMultMaxLineSearch() throws Exception { + @DisplayName("Test Mult Max Line Search") + void testMultMaxLineSearch() throws Exception { double score1, score2; - 
irisData.normalizeZeroMeanZeroUnitVariance(); OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, LossFunctions.LossFunction.MCXENT); - int nParams = (int)layer.numParams(); + int nParams = (int) layer.numParams(); layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams)); layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setLabels(irisData.getLabels()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); score1 = layer.score(); INDArray origGradient = layer.gradient().gradient().dup(); - DefaultStepFunction sf = new DefaultStepFunction(); BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer()); - double step = lineSearch.optimize(layer.params().dup(), layer.gradient().gradient().dup(), - layer.gradient().gradient().dup(), LayerWorkspaceMgr.noWorkspacesImmutable()); - + double step = lineSearch.optimize(layer.params().dup(), layer.gradient().gradient().dup(), layer.gradient().gradient().dup(), LayerWorkspaceMgr.noWorkspacesImmutable()); INDArray currParams = layer.params(); sf.step(currParams, origGradient, step); layer.setParams(currParams); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); score2 = layer.score(); - - assertTrue("score1 = " + score1 + ", score2 = " + score2, score1 < score2); + assertTrue(score1 < score2,"score1 = " + score1 + ", score2 = " + score2); } - private static OutputLayer getIrisLogisticLayerConfig(Activation activationFunction, int maxIterations, - LossFunctions.LossFunction lossFunction) { - NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L).miniBatch(true) - .maxNumLineSearchIterations(maxIterations) - .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(lossFunction) - .nIn(4).nOut(3).activation(activationFunction) - .weightInit(WeightInit.XAVIER).build()) - .build(); - + private static OutputLayer getIrisLogisticLayerConfig(Activation activationFunction, int maxIterations, LossFunctions.LossFunction lossFunction) { + NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).miniBatch(true).maxNumLineSearchIterations(maxIterations).layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(lossFunction).nIn(4).nOut(3).activation(activationFunction).weightInit(WeightInit.XAVIER).build()).build(); val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); return (OutputLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); } - /////////////////////////////////////////////////////////////////////////// - + // ///////////////////////////////////////////////////////////////////////// @Test - public void testBackTrackLineGradientDescent() { + @DisplayName("Test Back Track Line Gradient Descent") + void testBackTrackLineGradientDescent() { OptimizationAlgorithm optimizer = OptimizationAlgorithm.LINE_GRADIENT_DESCENT; - DataSetIterator irisIter = new IrisDataSetIterator(1, 1); DataSet data = irisIter.next(); - MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer)); network.init(); TrainingListener listener = new ScoreIterationListener(10); network.setListeners(Collections.singletonList(listener)); double oldScore = network.score(data); - for( int i=0; i<100; i++ ) { + for (int i = 0; i < 100; i++) { network.fit(data.getFeatures(), data.getLabels()); } double score = network.score(); @@ -201,9 +175,9 @@ public class BackTrackLineSearchTest extends 
BaseDL4JTest { } @Test - public void testBackTrackLineCG() { + @DisplayName("Test Back Track Line CG") + void testBackTrackLineCG() { OptimizationAlgorithm optimizer = OptimizationAlgorithm.CONJUGATE_GRADIENT; - DataSet data = irisIter.next(); data.normalizeZeroMeanZeroUnitVariance(); MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer)); @@ -211,17 +185,16 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { TrainingListener listener = new ScoreIterationListener(10); network.setListeners(Collections.singletonList(listener)); double firstScore = network.score(data); - - for( int i=0; i<5; i++ ) { + for (int i = 0; i < 5; i++) { network.fit(data.getFeatures(), data.getLabels()); } double score = network.score(); assertTrue(score < firstScore); - } @Test - public void testBackTrackLineLBFGS() { + @DisplayName("Test Back Track Line LBFGS") + void testBackTrackLineLBFGS() { OptimizationAlgorithm optimizer = OptimizationAlgorithm.LBFGS; DataSet data = irisIter.next(); data.normalizeZeroMeanZeroUnitVariance(); @@ -230,28 +203,15 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { TrainingListener listener = new ScoreIterationListener(10); network.setListeners(Collections.singletonList(listener)); double oldScore = network.score(data); - - for( int i=0; i<5; i++ ) { + for (int i = 0; i < 5; i++) { network.fit(data.getFeatures(), data.getLabels()); } double score = network.score(); assertTrue(score < oldScore); - } private static MultiLayerConfiguration getIrisMultiLayerConfig(Activation activationFunction, OptimizationAlgorithm optimizer) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(optimizer) - .updater(new Adam(0.01)).seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER) - .activation(activationFunction).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - .build(); - - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(optimizer).updater(new Adam(0.01)).seed(12345L).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER).activation(activationFunction).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).build(); return conf; } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java index 15379a37d..8cbeaf234 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.optimize.solver.accumulation; import lombok.extern.slf4j.Slf4j; @@ -26,18 +25,20 @@ import org.deeplearning4j.BaseDL4JTest; import 
org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator; import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler; import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.FixedThresholdAlgorithm; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.util.PrintAffinity; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.nativeblas.OpaqueDataBuffer; - -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { +@DisplayName("Encoded Gradients Accumulator Test") +class EncodedGradientsAccumulatorTest extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { @@ -49,29 +50,25 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { * @throws Exception */ @Test - public void testStore1() throws Exception { + @DisplayName("Test Store 1") + void testStore1() throws Exception { int numParams; int[] workers; - if(isIntegrationTests()){ + if (isIntegrationTests()) { numParams = 100000; - workers = new int[] {2, 4, 8}; + workers = new int[] { 2, 4, 8 }; } else { numParams = 10000; - workers = new int[] {2, 3}; + workers = new int[] { 2, 3 }; } - for (int numWorkers : workers) { - EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3),null, null, false); - + EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, null, false); val bufferSize = EncodedGradientsAccumulator.getOptimalBufferSize(numParams, numWorkers, 2); log.info("Workers: {}; Buffer size: {} bytes", numWorkers, bufferSize); - EncodedGradientsAccumulator accumulator = - new EncodedGradientsAccumulator(numWorkers, handler, bufferSize, 2, null, false); - + EncodedGradientsAccumulator accumulator = new EncodedGradientsAccumulator(numWorkers, handler, bufferSize, 2, null, false); for (int e = 10; e < numParams / 10; e++) { INDArray encoded = handler.encodeUpdates(0, 0, getGradients(numParams, e, 2e-3)); accumulator.receiveUpdate(encoded); - // just purge updates, like they were consumed for (int i = 0; i < accumulator.getMessages().size(); i++) { accumulator.getMessages().get(i).clear(); @@ -80,45 +77,35 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { } } - /** * Here we ensure that no matter how dense/sparse our updates are - we're never going above 1/16 of original elements of gradients array * * @throws Exception */ @Test - public void testEncodingLimits1() throws Exception { + @DisplayName("Test Encoding Limits 1") + void testEncodingLimits1() throws Exception { int numParams; - if(isIntegrationTests()){ + if (isIntegrationTests()) { numParams = 100000; } else { numParams = 10000; } - - EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, Integer.MAX_VALUE, false); for (int e = 10; e < numParams / 5; e++) { - val gradients = getGradients(numParams, e, 2e-3); val encoded = handler.encodeUpdates(0, 0, gradients); - - assertNotNull("Failed with e == " + e, encoded); - + assertNotNull(encoded,"Failed with e == " + e); int encFormat = encoded.data().getInt(3); - - assertTrue("Failed for E = " + 
e + "; Format: " + encFormat + "; Length: " + encoded.data().length(), - encoded.data().length() < numParams / 16 + 6); + assertTrue( encoded.data().length() < numParams / 16 + 6,"Failed for E = " + e + "; Format: " + encFormat + "; Length: " + encoded.data().length()); } } - protected INDArray getGradients(int length, int numPositives, double value) { INDArray grad = Nd4j.create(length); - for (int i = 0; i < numPositives; i++) { grad.putScalar(i, value); } - return grad; } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java index 8e8f81dea..28cd85b35 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.optimize.solver.accumulation; import lombok.extern.slf4j.Slf4j; @@ -25,230 +24,184 @@ import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.optimize.solvers.accumulation.IndexedTail; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.factory.Nd4j; - import java.util.ArrayList; import java.util.concurrent.atomic.AtomicInteger; - -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class IndexedTailTest extends BaseDL4JTest { +@DisplayName("Indexed Tail Test") +class IndexedTailTest extends BaseDL4JTest { @Test - public void testDeltas_1() throws Exception { + @DisplayName("Test Deltas _ 1") + void testDeltas_1() throws Exception { val tail = new IndexedTail(2); - assertFalse(tail.hasAnything(11)); assertFalse(tail.hasAnything(22)); - // 3 updates in queue tail.put(Nd4j.create(5, 5)); tail.put(Nd4j.create(5, 5)); tail.put(Nd4j.create(5, 5)); - assertEquals(3, tail.getDelta(11)); assertEquals(3, tail.getDelta(22)); - - tail.drainTo(22, Nd4j.create(5, 5)); - assertEquals(3, tail.getDelta(11)); assertEquals(0, tail.getDelta(22)); - tail.put(Nd4j.create(5, 5)); - assertEquals(4, tail.getDelta(11)); assertEquals(1, tail.getDelta(22)); - tail.drainTo(22, Nd4j.create(5, 5)); tail.drainTo(11, Nd4j.create(5, 5)); - assertEquals(0, tail.getDelta(11)); assertEquals(0, tail.getDelta(22)); - - tail.put(Nd4j.create(5, 5)); tail.put(Nd4j.create(5, 5)); - assertEquals(2, tail.getDelta(11)); assertEquals(2, tail.getDelta(22)); - tail.drainTo(22, Nd4j.create(5, 5)); - assertEquals(2, tail.getDelta(11)); assertEquals(0, tail.getDelta(22)); } - @Test - public void testMaxAppliedIndex_1() { + @DisplayName("Test Max Applied Index _ 1") + void testMaxAppliedIndex_1() { val tail = new IndexedTail(3); - // "registering" 3 consumers assertFalse(tail.hasAnything(11)); assertFalse(tail.hasAnything(22)); assertFalse(tail.hasAnything(33)); - // putting 10 updates in for (int e = 0; e < 10; e++) { tail.put(Nd4j.create(5, 5)); } - assertEquals(10, tail.updatesSize()); - assertEquals(-1, tail.maxAppliedIndexEverywhere()); - // 2 consumers consumed 2 
elements, and 1 consumer consumed 3 elements tail.getPositions().get(11L).set(2); tail.getPositions().get(22L).set(2); tail.getPositions().get(33L).set(3); - // all elements including this index are safe to remove, because they were consumed everywhere assertEquals(2, tail.maxAppliedIndexEverywhere()); - // only updates starting from 4 are safe to collapse, because 3 was consumed by one consumer assertEquals(4, tail.firstNotAppliedIndexEverywhere()); - // truncating stuff tail.maintenance(); - assertEquals(8, tail.updatesSize()); } @Test - public void testFirstNotApplied_1() { + @DisplayName("Test First Not Applied _ 1") + void testFirstNotApplied_1() { val tail = new IndexedTail(1); tail.hasAnything(); - assertEquals(-1, tail.firstNotAppliedIndexEverywhere()); - - tail.put(Nd4j.createUninitialized(5,5)); - + tail.put(Nd4j.createUninitialized(5, 5)); assertEquals(0, tail.firstNotAppliedIndexEverywhere()); - - tail.put(Nd4j.createUninitialized(5,5)); - tail.put(Nd4j.createUninitialized(5,5)); - + tail.put(Nd4j.createUninitialized(5, 5)); + tail.put(Nd4j.createUninitialized(5, 5)); assertEquals(0, tail.firstNotAppliedIndexEverywhere()); - assertTrue(tail.drainTo(Nd4j.create(5, 5))); - assertEquals(4, tail.firstNotAppliedIndexEverywhere()); } - @Test - public void testSingleThreaded_1() throws Exception { + @DisplayName("Test Single Threaded _ 1") + void testSingleThreaded_1() throws Exception { val tail = new IndexedTail(1); - for (int e = 0; e < 100; e++) { val orig = Nd4j.create(5, 5).assign(e); tail.put(orig); Nd4j.getExecutioner().commit(); - assertTrue(tail.hasAnything()); - val temp = Nd4j.create(5, 5); val status = tail.drainTo(temp); - assertTrue(status); assertArrayEquals(orig.shape(), temp.shape()); assertEquals(orig, temp); } - assertEquals(0, tail.updatesSize()); } @Test - public void testSingleThreaded_2() throws Exception { + @DisplayName("Test Single Threaded _ 2") + void testSingleThreaded_2() throws Exception { val tail = new IndexedTail(1); - for (int e = 0; e < 100; e++) { int numUpdates = RandomUtils.nextInt(1, 10); int sum = 0; - for (int f = 1; f <= numUpdates; f++) { sum += f; val orig = Nd4j.create(5, 5).assign(f); tail.put(orig); } Nd4j.getExecutioner().commit(); - assertTrue(tail.hasAnything()); - val temp = Nd4j.create(5, 5); val status = tail.drainTo(temp); - assertTrue(status); assertEquals(sum, temp.meanNumber().intValue()); } - assertEquals(0, tail.updatesSize()); } @Test - public void testSingleThreaded_3() throws Exception { - val tail = new IndexedTail(2, true, new long[]{5, 5}); + @DisplayName("Test Single Threaded _ 3") + void testSingleThreaded_3() throws Exception { + val tail = new IndexedTail(2, true, new long[] { 5, 5 }); assertFalse(tail.hasAnything()); assertFalse(tail.hasAnything(11)); - int sum = 0; for (int e = 0; e < 64; e++) { - sum += (e+1); - tail.put(Nd4j.createUninitialized(5,5).assign(e+1)); + sum += (e + 1); + tail.put(Nd4j.createUninitialized(5, 5).assign(e + 1)); Nd4j.getExecutioner().commit(); } - assertTrue(tail.getCollapsedMode().get()); assertEquals(1, tail.updatesSize()); - val array = tail.getUpdates().get(32L); assertNotNull(array); assertEquals(sum, (int) array.getDouble(0)); } - @Test - public void testPseudoMultiThreaded_1() throws Exception { + @DisplayName("Test Pseudo Multi Threaded _ 1") + void testPseudoMultiThreaded_1() throws Exception { val tail = new IndexedTail(2); - for (int e = 0; e < 100; e++) { // putting in one thread val orig = Nd4j.create(5, 5).assign(e); tail.put(orig); Nd4j.getExecutioner().commit(); - for 
(int t = 0; t < 2; t++) { assertTrue(tail.hasAnything(t)); - val temp = Nd4j.create(5, 5); val status = tail.drainTo(t, temp); - assertTrue(status); assertArrayEquals(orig.shape(), temp.shape()); assertEquals(orig, temp); } } - assertEquals(0, tail.updatesSize()); } - - @Test - @Ignore("AB 2019/05/21 - Failing sometimes on linux-x86_64-cpu - Issue #7657") - public void testMultiThreaded_1() throws Exception { + @Disabled("AB 2019/05/21 - Failing sometimes on linux-x86_64-cpu - Issue #7657") + @DisplayName("Test Multi Threaded _ 1") + void testMultiThreaded_1() throws Exception { val numReaders = 4; final val tail = new IndexedTail(numReaders); - final long[] sums = new long[numReaders]; val readers = new ArrayList(); for (int e = 0; e < numReaders; e++) { final int f = e; val t = new Thread(new Runnable() { + @Override public void run() { sums[f] = 0; @@ -262,48 +215,37 @@ public class IndexedTailTest extends BaseDL4JTest { } } }); - t.setName("reader thread " + f); t.start(); readers.add(t); } - - int sum = 0; for (int e = 0; e < 10000; e++) { - val array = Nd4j.create(5, 5).assign(e+1); + val array = Nd4j.create(5, 5).assign(e + 1); Nd4j.getExecutioner().commit(); - - sum += (e+1); + sum += (e + 1); tail.put(array); } // just wait till everything consumed Thread.sleep(2000); tail.notifyDead(); - - - for (val t:readers) - t.join(); - - - for (int e = 0; e < numReaders; e++) - assertEquals("Failed for reader [" + e + "]",sum, sums[e]); - - + for (val t : readers) t.join(); + for (int e = 0; e < numReaders; e++) assertEquals(sum, sums[e],"Failed for reader [" + e + "]"); assertEquals(0, tail.updatesSize()); } @Test - public void testMultiThreaded_2() throws Exception { + @DisplayName("Test Multi Threaded _ 2") + void testMultiThreaded_2() throws Exception { val numReaders = 4; val numWriters = 4; final val tail = new IndexedTail(numReaders); - final long[] sums = new long[numReaders]; val readers = new ArrayList(); for (int e = 0; e < numReaders; e++) { final int f = e; val t = new Thread(new Runnable() { + @Override public void run() { sums[f] = 0; @@ -317,67 +259,51 @@ public class IndexedTailTest extends BaseDL4JTest { } } }); - t.setName("reader thread " + f); t.start(); readers.add(t); } - val writers = new ArrayList(); for (int e = 0; e < numWriters; e++) { val f = e; val t = new Thread(new Runnable() { + @Override public void run() { int sum = 0; for (int e = 0; e < 1000; e++) { - val array = Nd4j.create(5, 5).assign(e+1); + val array = Nd4j.create(5, 5).assign(e + 1); Nd4j.getExecutioner().commit(); - - sum += (e+1); + sum += (e + 1); tail.put(array); } } }); - t.setName("writer thread " + f); t.start(); writers.add(t); } - - - - for (val t:writers) - t.join(); - + for (val t : writers) t.join(); // just wait till everything consumed Thread.sleep(2000); tail.notifyDead(); - - - - for (val t:readers) - t.join(); - - - for (int e = 0; e < numReaders; e++) - assertEquals("Failed for reader [" + e + "]",500500 * numWriters, sums[e]); - - + for (val t : readers) t.join(); + for (int e = 0; e < numReaders; e++) assertEquals(500500 * numWriters, sums[e],"Failed for reader [" + e + "]"); assertEquals(0, tail.updatesSize()); } @Test - public void testMultiThreaded_3() throws Exception { + @DisplayName("Test Multi Threaded _ 3") + void testMultiThreaded_3() throws Exception { val numReaders = 4; val numWriters = 4; - final val tail = new IndexedTail(numReaders, true, new long[]{5, 5}); - + final val tail = new IndexedTail(numReaders, true, new long[] { 5, 5 }); final long[] sums = new 
long[numReaders]; val readers = new ArrayList(); for (int e = 0; e < numReaders; e++) { final int f = e; val t = new Thread(new Runnable() { + @Override public void run() { sums[f] = 0; @@ -391,52 +317,37 @@ public class IndexedTailTest extends BaseDL4JTest { } } }); - t.setName("reader thread " + f); t.start(); readers.add(t); } - final AtomicInteger sum = new AtomicInteger(0); val writers = new ArrayList(); for (int e = 0; e < numWriters; e++) { val f = e; val t = new Thread(new Runnable() { + @Override public void run() { for (int i = 0; i < 256; i++) { - - val array = Nd4j.create(5, 5).assign(i+1); + val array = Nd4j.create(5, 5).assign(i + 1); Nd4j.getExecutioner().commit(); - - sum.addAndGet(i+1); + sum.addAndGet(i + 1); tail.put(array); } } }); - t.setName("writer thread " + f); t.start(); writers.add(t); } - - - for (val t:writers) - t.join(); - + for (val t : writers) t.join(); // just wait till everything consumed Thread.sleep(3000); tail.notifyDead(); - - for (val t:readers) - t.join(); - + for (val t : readers) t.join(); log.info("Readers results: {}", sums); - - for (int e = 0; e < numReaders; e++) - assertEquals("Failed for reader [" + e + "]",sum.get(), sums[e]); - - + for (int e = 0; e < numReaders; e++) assertEquals(sum.get(), sums[e],"Failed for reader [" + e + "]"); assertEquals(0, tail.updatesSize()); } -} \ No newline at end of file +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java index 5b9dd8c0a..adeb00d93 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.optimize.solver.accumulation; import lombok.extern.slf4j.Slf4j; @@ -25,178 +24,168 @@ import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.common.util.ThreadUtils; import org.nd4j.linalg.factory.Nd4j; - import java.util.ArrayList; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CyclicBarrier; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import static java.time.Duration.ofMillis; +import static org.junit.jupiter.api.Assertions.assertTimeout; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; +@Slf4j +@Disabled("AB 2019/05/21 - Failing (stuck, causing timeouts) - Issue #7657") +@DisplayName("Smart Fancy Blocking Queue Test") +class SmartFancyBlockingQueueTest extends BaseDL4JTest { -@Slf4j @Ignore("AB 2019/05/21 - Failing (stuck, causing timeouts) - Issue #7657") -public class SmartFancyBlockingQueueTest extends BaseDL4JTest { - @Test(timeout = 120000L) - public void testSFBQ_1() throws Exception { - val queue = new SmartFancyBlockingQueue(8, Nd4j.create(5, 5)); - - val array = Nd4j.create(5, 5); - - for (int e = 0; e 
< 6; e++) { - queue.put(Nd4j.create(5, 5).assign(e)); - }; - - assertEquals(6, queue.size()); - - for (int e = 6; e < 10; e++) { - queue.put(Nd4j.create(5, 5).assign(e)); - } - - assertEquals(1, queue.size()); - } - - @Test(timeout = 120000L) - public void testSFBQ_2() throws Exception { - final val queue = new SmartFancyBlockingQueue(1285601, Nd4j.create(5, 5)); - final val barrier = new CyclicBarrier(4); - - val threads = new ArrayList(); - for (int e = 0; e< 4; e++) { - val f = e; - val t = new Thread(new Runnable() { - @Override - public void run() { - int cnt = 0; - while (true) { - while (cnt < 1000) { - if (!queue.isEmpty()) { - if (cnt % 50 == 0) - log.info("Thread {}: [{}]", f, cnt); - - val arr = queue.poll(); - - assertNotNull(arr); - val local = arr.unsafeDuplication(true); - - assertEquals(cnt, local.meanNumber().intValue()); - cnt++; - } - - - try { - barrier.await(); - - if (f == 0) - queue.registerConsumers(4); - - barrier.await(); - } catch (InterruptedException e1) { - e1.printStackTrace(); - } catch (BrokenBarrierException e1) { - e1.printStackTrace(); - } - } - break; - } - - - } - }); - t.setName("reader thread " + f); - t.start(); - threads.add(t); - } - - for (int e = 0; e < 1000; e++) { - queue.put(Nd4j.create(5, 5).assign(e)); - Nd4j.getExecutioner().commit(); - } - - - for (val t: threads) - t.join(); - } - - - @Test(timeout = 120000L) - public void testSFBQ_3() throws Exception { - final val queue = new SmartFancyBlockingQueue(1285601, Nd4j.create(5, 5)); - - val threads = new ArrayList(); - for (int e = 0; e< 4; e++) { - val f = e; - val t = new Thread(new Runnable() { - @Override - public void run() { - int cnt = 0; - while (true) { - while (cnt < 1000) { - if (!queue.isEmpty()) { - if (cnt % 50 == 0) - log.info("Thread {}: [{}]", f, cnt); - - val arr = queue.poll(); - - assertNotNull(arr); - val local = arr.unsafeDuplication(true); - cnt++; - } - } - break; - } - } - }); - t.start(); - threads.add(t); - } - - val b = new Thread(new Runnable() { - @Override - public void run() { - while (true) { - queue.registerConsumers(4); - ThreadUtils.uncheckedSleep(30); - } + @Test + @DisplayName("Test SFBQ _ 1") + void testSFBQ_1() { + assertTimeout(ofMillis(120000), () -> { + val queue = new SmartFancyBlockingQueue(8, Nd4j.create(5, 5)); + val array = Nd4j.create(5, 5); + for (int e = 0; e < 6; e++) { + queue.put(Nd4j.create(5, 5).assign(e)); } + ; + assertEquals(6, queue.size()); + for (int e = 6; e < 10; e++) { + queue.put(Nd4j.create(5, 5).assign(e)); + } + assertEquals(1, queue.size()); }); + } - b.setDaemon(true); - b.start(); + @Test + @DisplayName("Test SFBQ _ 2") + void testSFBQ_2() { + assertTimeout(ofMillis(120000), () -> { + final val queue = new SmartFancyBlockingQueue(1285601, Nd4j.create(5, 5)); + final val barrier = new CyclicBarrier(4); + val threads = new ArrayList(); + for (int e = 0; e < 4; e++) { + val f = e; + val t = new Thread(new Runnable() { + + @Override + public void run() { + int cnt = 0; + while (true) { + while (cnt < 1000) { + if (!queue.isEmpty()) { + if (cnt % 50 == 0) + log.info("Thread {}: [{}]", f, cnt); + val arr = queue.poll(); + assertNotNull(arr); + val local = arr.unsafeDuplication(true); + assertEquals(cnt, local.meanNumber().intValue()); + cnt++; + } + try { + barrier.await(); + if (f == 0) + queue.registerConsumers(4); + barrier.await(); + } catch (InterruptedException e1) { + e1.printStackTrace(); + } catch (BrokenBarrierException e1) { + e1.printStackTrace(); + } + } + break; + } + } + }); + t.setName("reader thread " + f); + 
t.start(); + threads.add(t); + } + for (int e = 0; e < 1000; e++) { + queue.put(Nd4j.create(5, 5).assign(e)); + Nd4j.getExecutioner().commit(); + } + for (val t : threads) t.join(); + }); + } + + @Test + @DisplayName("Test SFBQ _ 3") + void testSFBQ_3() { + assertTimeout(ofMillis(120000), () -> { + final val queue = new SmartFancyBlockingQueue(1285601, Nd4j.create(5, 5)); + val threads = new ArrayList(); + for (int e = 0; e < 4; e++) { + val f = e; + val t = new Thread(new Runnable() { + + @Override + public void run() { + int cnt = 0; + while (true) { + while (cnt < 1000) { + if (!queue.isEmpty()) { + if (cnt % 50 == 0) + log.info("Thread {}: [{}]", f, cnt); + val arr = queue.poll(); + assertNotNull(arr); + val local = arr.unsafeDuplication(true); + cnt++; + } + } + break; + } + } + }); + t.start(); + threads.add(t); + } + val b = new Thread(new Runnable() { - val writers = new ArrayList(); - for (int e = 0; e < 4; e++) { - val t = new Thread(new Runnable() { @Override public void run() { - for (int e = 0; e <250; e++) { - try { - queue.put(Nd4j.createUninitialized(5, 5).assign(e)); - Thread.sleep(30); - } catch (Exception ex) { - throw new RuntimeException(ex); - } + while (true) { + queue.registerConsumers(4); + ThreadUtils.uncheckedSleep(30); } } }); + b.setDaemon(true); + b.start(); + val writers = new ArrayList(); + for (int e = 0; e < 4; e++) { + val t = new Thread(new Runnable() { - writers.add(t); - t.start(); - } - - for (val t: writers) - t.join(); - - for (val t: threads) - t.join(); + @Override + public void run() { + for (int e = 0; e < 250; e++) { + try { + queue.put(Nd4j.createUninitialized(5, 5).assign(e)); + Thread.sleep(30); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + } + }); + writers.add(t); + t.start(); + } + for (val t : writers) t.join(); + for (val t : threads) t.join(); + }); } - @Test(timeout = 120000L) - public void testSFBQ_4() throws Exception { - final val queue = new SmartFancyBlockingQueue(16, Nd4j.create(5, 5)); - final val barrier = new CyclicBarrier(4); -/* + @Test + @DisplayName("Test SFBQ _ 4") + void testSFBQ_4() { + assertTimeout(ofMillis(120000), () -> { + final val queue = new SmartFancyBlockingQueue(16, Nd4j.create(5, 5)); + final val barrier = new CyclicBarrier(4); + /* val m = new Thread(new Runnable() { @Override public void run() { @@ -212,145 +201,126 @@ public class SmartFancyBlockingQueueTest extends BaseDL4JTest { m.setDaemon(true); m.start(); */ + val threads = new ArrayList(); + for (int e = 0; e < 4; e++) { + val f = e; + val t = new Thread(new Runnable() { - val threads = new ArrayList(); - for (int e = 0; e < 4; e++) { - val f= e; - val t = new Thread(new Runnable() { - @Override - public void run() { - try { - for (int e = 0; e < 100; e++) { - - log.info("[Thread {}]: fill phase {}", f, e); - val numUpdates = RandomUtils.nextInt(8, 128); - for (int p = 0; p < numUpdates; p++) { - queue.put(Nd4j.createUninitialized(5, 5)); - } - - if (f == 0) - queue.registerConsumers(4); - - barrier.await(); - log.info("[Thread {}]: read phase {}", f, e); - while (!queue.isEmpty()) { - val arr = queue.poll(); - - assertNotNull(arr); - } - - barrier.await(); - - } - } catch (InterruptedException e) { - throw new RuntimeException(e); - } catch (BrokenBarrierException e) { - throw new RuntimeException(e); - } - } - }); - - t.setName("worker thread " + f); - t.start(); - threads.add(t); - } - - for (val t:threads) - t.join(); - } - - - @Test(timeout = 120000L) - public void testSFBQ_5() throws Exception { - final val queue = 
new SmartFancyBlockingQueue(16, Nd4j.create(5, 5)); - final val barrier = new CyclicBarrier(4); - - // writers are just spamming updates every X ms - val writers = new ArrayList(); - for (int e = 0; e < 4; e++) { - val w = new Thread(new Runnable() { - @Override - public void run() { - while (true) { + @Override + public void run() { try { - val n = RandomUtils.nextInt(8, 64); - for (int i = 1; i < n+1; i++) { - val arr = Nd4j.createUninitialized(5, 5).assign(i); - Nd4j.getExecutioner().commit(); - queue.put(arr); - } - - ThreadUtils.uncheckedSleep(10); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - } - }); - - w.setName("writer thread " + e); - w.setDaemon(true); - w.start(); - writers.add(w); - } - - // each reader will read 250 updates. supposedly equal :) - final long[] means = new long[4]; - val readers = new ArrayList(); - for (int e = 0; e < 4; e++) { - final int f = e; - means[f] = 0; - val t = new Thread(new Runnable() { - @Override - public void run() { - try { - int cnt = 0; - int fnt = 0; - while (cnt < 1000) { - - if (!queue.isEmpty()) { + for (int e = 0; e < 100; e++) { + log.info("[Thread {}]: fill phase {}", f, e); + val numUpdates = RandomUtils.nextInt(8, 128); + for (int p = 0; p < numUpdates; p++) { + queue.put(Nd4j.createUninitialized(5, 5)); + } + if (f == 0) + queue.registerConsumers(4); + barrier.await(); + log.info("[Thread {}]: read phase {}", f, e); while (!queue.isEmpty()) { - val m = queue.poll(); - - val arr = m.unsafeDuplication(true); - val mean = arr.meanNumber().longValue(); - assertNotEquals("Failed at cycle: " + cnt,0, mean); - means[f] += mean; - - cnt++; + val arr = queue.poll(); + assertNotNull(arr); } barrier.await(); } - - barrier.await(); - - if (f == 0) { - log.info("Read cycle finished"); - queue.registerConsumers(4); - } - - barrier.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (BrokenBarrierException e) { + throw new RuntimeException(e); } - } catch (InterruptedException e) { - throw new RuntimeException(e); - } catch (BrokenBarrierException e) { - throw new RuntimeException(e); } - } - }); - - t.setName("reader thread " + f); - t.start(); - readers.add(t); - } - - - for (val t:readers) - t.join(); - - // all messages should be the same - assertEquals(means[0], means[1]); - assertEquals(means[0], means[2]); - assertEquals(means[0], means[3]); + }); + t.setName("worker thread " + f); + t.start(); + threads.add(t); + } + for (val t : threads) t.join(); + }); } -} \ No newline at end of file + + @Test + @DisplayName("Test SFBQ _ 5") + void testSFBQ_5() { + assertTimeout(ofMillis(120000), () -> { + final val queue = new SmartFancyBlockingQueue(16, Nd4j.create(5, 5)); + final val barrier = new CyclicBarrier(4); + // writers are just spamming updates every X ms + val writers = new ArrayList(); + for (int e = 0; e < 4; e++) { + val w = new Thread(new Runnable() { + + @Override + public void run() { + while (true) { + try { + val n = RandomUtils.nextInt(8, 64); + for (int i = 1; i < n + 1; i++) { + val arr = Nd4j.createUninitialized(5, 5).assign(i); + Nd4j.getExecutioner().commit(); + queue.put(arr); + } + ThreadUtils.uncheckedSleep(10); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + }); + w.setName("writer thread " + e); + w.setDaemon(true); + w.start(); + writers.add(w); + } + // each reader will read 250 updates. 
supposedly equal :) + final long[] means = new long[4]; + val readers = new ArrayList(); + for (int e = 0; e < 4; e++) { + final int f = e; + means[f] = 0; + val t = new Thread(new Runnable() { + + @Override + public void run() { + try { + int cnt = 0; + int fnt = 0; + while (cnt < 1000) { + if (!queue.isEmpty()) { + while (!queue.isEmpty()) { + val m = queue.poll(); + val arr = m.unsafeDuplication(true); + val mean = arr.meanNumber().longValue(); + assertNotEquals(0, mean,"Failed at cycle: " + cnt); + means[f] += mean; + cnt++; + } + barrier.await(); + } + barrier.await(); + if (f == 0) { + log.info("Read cycle finished"); + queue.registerConsumers(4); + } + barrier.await(); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (BrokenBarrierException e) { + throw new RuntimeException(e); + } + } + }); + t.setName("reader thread " + f); + t.start(); + readers.add(t); + } + for (val t : readers) t.join(); + // all messages should be the same + assertEquals(means[0], means[1]); + assertEquals(means[0], means[2]); + assertEquals(means[0], means[3]); + }); + } +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java index d7d2f6cce..1edd152c2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java @@ -17,104 +17,96 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.optimizer.listener; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener; -import org.junit.Ignore; -import org.junit.Test; - +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; + +@DisplayName("Score Stat Test") +class ScoreStatTest extends BaseDL4JTest { -public class ScoreStatTest extends BaseDL4JTest { @Test - public void testScoreStatSmall() { + @DisplayName("Test Score Stat Small") + void testScoreStatSmall() { CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat(); for (int i = 0; i < CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH; ++i) { - double score = (double)i; + double score = (double) i; statTest.addScore(i, score); } - List indexes = statTest.getIndexes(); List scores = statTest.getScores(); - assertTrue(indexes.size() == 1); assertTrue(scores.size() == 1); - assertTrue(indexes.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH); assertTrue(scores.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH); - assertEquals(indexes.get(0)[indexes.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1); - assertEquals(scores.get(0)[scores.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1, 1e-4); + assertEquals(indexes.get(0)[indexes.get(0).length - 1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH - 1); + assertEquals(scores.get(0)[scores.get(0).length - 1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH - 1, 1e-4); 
} @Test - public void testScoreStatAverage() { + @DisplayName("Test Score Stat Average") + void testScoreStatAverage() { int dataSize = 1000000; long[] indexes = new long[dataSize]; double[] scores = new double[dataSize]; - for (int i = 0; i < dataSize; ++i) { indexes[i] = i; scores[i] = i; } - CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat(); for (int i = 0; i < dataSize; ++i) { statTest.addScore(indexes[i], scores[i]); } - long[] indexesStored = statTest.getIndexes().get(0); double[] scoresStored = statTest.getScores().get(0); - assertArrayEquals(indexes, indexesStored); assertArrayEquals(scores, scoresStored, 1e-4); } @Test - public void testScoresClean() { - int dataSize = 10256; // expected to be placed in 2 buckets of 10k elements size + @DisplayName("Test Scores Clean") + void testScoresClean() { + // expected to be placed in 2 buckets of 10k elements size + int dataSize = 10256; long[] indexes = new long[dataSize]; double[] scores = new double[dataSize]; - for (int i = 0; i < dataSize; ++i) { indexes[i] = i; scores[i] = i; } - CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat(); for (int i = 0; i < dataSize; ++i) { statTest.addScore(indexes[i], scores[i]); } - long[] indexesEffective = statTest.getEffectiveIndexes(); double[] scoresEffective = statTest.getEffectiveScores(); - assertArrayEquals(indexes, indexesEffective); assertArrayEquals(scores, scoresEffective, 1e-4); } - @Ignore + @Disabled @Test - public void testScoreStatBig() { + @DisplayName("Test Score Stat Big") + void testScoreStatBig() { CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat(); - long bigLength = (long)Integer.MAX_VALUE + 5; + long bigLength = (long) Integer.MAX_VALUE + 5; for (long i = 0; i < bigLength; ++i) { - double score = (double)i; + double score = (double) i; statTest.addScore(i, score); } - List indexes = statTest.getIndexes(); List scores = statTest.getScores(); - assertTrue(indexes.size() == 2); assertTrue(scores.size() == 2); - for (int i = 0; i < 5; ++i) { assertTrue(indexes.get(1)[i] == Integer.MAX_VALUE + i); assertTrue(scores.get(1)[i] == Integer.MAX_VALUE + i); - } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/AsyncIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/AsyncIteratorTest.java index 5bfde3fa2..4cc240643 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/AsyncIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/AsyncIteratorTest.java @@ -17,26 +17,26 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.parallelism; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.parallelism.AsyncIterator; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class AsyncIteratorTest extends BaseDL4JTest { +@DisplayName("Async Iterator Test") +class AsyncIteratorTest extends BaseDL4JTest { @Test - public void hasNext() throws Exception { + @DisplayName("Has Next") + void hasNext() throws Exception { 
ArrayList integers = new ArrayList<>(); for (int x = 0; x < 100000; x++) { integers.add(x); } - AsyncIterator iterator = new AsyncIterator<>(integers.iterator(), 512); int cnt = 0; Integer val = null; @@ -45,10 +45,7 @@ public class AsyncIteratorTest extends BaseDL4JTest { assertEquals(cnt, val.intValue()); cnt++; } - System.out.println("Last val: " + val); - assertEquals(integers.size(), cnt); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/MultiBooleanTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/MultiBooleanTest.java index 9abc3e8a7..54a3a099e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/MultiBooleanTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/MultiBooleanTest.java @@ -17,89 +17,73 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.parallelism; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.parallel.MultiBoolean; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -public class MultiBooleanTest extends BaseDL4JTest { +@DisplayName("Multi Boolean Test") +class MultiBooleanTest extends BaseDL4JTest { @Test - public void testBoolean1() throws Exception { + @DisplayName("Test Boolean 1") + void testBoolean1() throws Exception { MultiBoolean bool = new MultiBoolean(5); - assertTrue(bool.allFalse()); assertFalse(bool.allTrue()); } - @Test - public void testBoolean2() throws Exception { + @DisplayName("Test Boolean 2") + void testBoolean2() throws Exception { MultiBoolean bool = new MultiBoolean(5); - bool.set(true, 2); - assertFalse(bool.allFalse()); assertFalse(bool.allTrue()); } @Test - public void testBoolean3() throws Exception { + @DisplayName("Test Boolean 3") + void testBoolean3() throws Exception { MultiBoolean bool = new MultiBoolean(5); - bool.set(true, 0); bool.set(true, 1); bool.set(true, 2); - - bool.set(true, 3); - assertFalse(bool.allTrue()); - bool.set(true, 4); - assertFalse(bool.allFalse()); assertTrue(bool.allTrue()); - bool.set(false, 2); - assertFalse(bool.allTrue()); - bool.set(true, 2); - assertTrue(bool.allTrue()); } @Test - public void testBoolean4() throws Exception { + @DisplayName("Test Boolean 4") + void testBoolean4() throws Exception { MultiBoolean bool = new MultiBoolean(5, true); - - assertTrue(bool.get(1)); - bool.set(false, 1); - assertFalse(bool.get(1)); } - @Test - public void testBoolean5() throws Exception { + @DisplayName("Test Boolean 5") + void testBoolean5() throws Exception { MultiBoolean bool = new MultiBoolean(5, true, true); - for (int i = 0; i < 5; i++) { bool.set(false, i); } - for (int i = 0; i < 5; i++) { bool.set(true, i); } - assertTrue(bool.allFalse()); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/ParallelExistingMiniBatchDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/ParallelExistingMiniBatchDataSetIteratorTest.java index 918d4aace..aa8c5984f 100644 --- 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/ParallelExistingMiniBatchDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/parallelism/ParallelExistingMiniBatchDataSetIteratorTest.java @@ -17,77 +17,157 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.parallelism; import lombok.extern.slf4j.Slf4j; import org.junit.Rule; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.callbacks.DataSetDeserializer; import org.deeplearning4j.datasets.iterator.parallel.FileSplitParallelDataSetIterator; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.common.primitives.Pair; - import java.io.File; import java.util.ArrayList; import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import static java.time.Duration.ofMillis; +import static org.junit.jupiter.api.Assertions.assertTimeout; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class ParallelExistingMiniBatchDataSetIteratorTest extends BaseDL4JTest { +/* + @Test + public void testSimpleLoop1() throws Exception { + ParallelExistingMiniBatchDataSetIterator iterator = new ParallelExistingMiniBatchDataSetIterator(rootFolder,"mnist-train-%d.bin", 4); + ExistingMiniBatchDataSetIterator test = new ExistingMiniBatchDataSetIterator(rootFolder,"mnist-train-%d.bin"); + + + List> pairs = new ArrayList<>(); + + int cnt = 0; + long time1 = System.nanoTime(); + while (iterator.hasNext()) { + DataSet ds = iterator.next(); + long time2 = System.nanoTime(); + assertNotNull(ds); + assertEquals(64, ds.numExamples()); + pairs.add(new Pair(time2 - time1, 0L)); + cnt++; + time1 = System.nanoTime(); + } + assertEquals(26, cnt); + + cnt = 0; + time1 = System.nanoTime(); + while (test.hasNext()) { + DataSet ds = test.next(); + long time2 = System.nanoTime(); + assertNotNull(ds); + assertEquals(64, ds.numExamples()); + pairs.get(cnt).setSecond(time2 - time1); + cnt++; + time1 = System.nanoTime(); + } + + assertEquals(26, cnt); + + for (Pair times: pairs) { + log.info("Parallel: {} ns; Simple: {} ns", times.getFirst(), times.getSecond()); + } + } + + @Test + public void testReset1() throws Exception { + ParallelExistingMiniBatchDataSetIterator iterator = new ParallelExistingMiniBatchDataSetIterator(rootFolder,"mnist-train-%d.bin", 8); + + int cnt = 0; + long time1 = System.nanoTime(); + while (iterator.hasNext()) { + DataSet ds = iterator.next(); + long time2 = System.nanoTime(); + assertNotNull(ds); + assertEquals(64, ds.numExamples()); + cnt++; + + if (cnt == 10) + iterator.reset(); + + time1 = System.nanoTime(); + } + assertEquals(36, cnt); + } + + @Test + public void testWithAdsi1() throws Exception { + ParallelExistingMiniBatchDataSetIterator iterator = new ParallelExistingMiniBatchDataSetIterator(rootFolder,"mnist-train-%d.bin", 8); + AsyncDataSetIterator adsi = new AsyncDataSetIterator(iterator, 8, true); + + int cnt = 0; + 
long time1 = System.nanoTime(); + while (adsi.hasNext()) { + DataSet ds = adsi.next(); + long time2 = System.nanoTime(); + assertNotNull(ds); + assertEquals(64, ds.numExamples()); + cnt++; + + if (cnt == 10) + adsi.reset(); + + time1 = System.nanoTime(); + } + assertEquals(36, cnt); + } + */ +@DisplayName("Parallel Existing Mini Batch Data Set Iterator Test") +class ParallelExistingMiniBatchDataSetIteratorTest extends BaseDL4JTest { + + @TempDir + public Path tempDir; - @Rule - public TemporaryFolder tempDir = new TemporaryFolder(); private static File rootFolder; - @Before - public void setUp() throws Exception { + @BeforeEach + void setUp() throws Exception { if (rootFolder == null) { - rootFolder = tempDir.newFolder(); - for( int i=0; i<26; i++){ + rootFolder = tempDir.toFile(); + for (int i = 0; i < 26; i++) { new ClassPathResource("/datasets/mnist/mnist-train-" + i + ".bin").getTempFileFromArchive(rootFolder); } } } - - @Test(timeout = 30000L) - public void testNewSimpleLoop1() throws Exception { - FileSplitParallelDataSetIterator fspdsi = new FileSplitParallelDataSetIterator(rootFolder, "mnist-train-%d.bin", - new DataSetDeserializer()); - - List> pairs = new ArrayList<>(); - - - long time1 = System.nanoTime(); - int cnt = 0; - while (fspdsi.hasNext()) { - DataSet ds = fspdsi.next(); - long time2 = System.nanoTime(); - pairs.add(new Pair(time2 - time1, 0L)); - assertNotNull(ds); - - // imitating processing here - Thread.sleep(10); - - cnt++; - time1 = System.nanoTime(); - } - - assertEquals(26, cnt); - - for (Pair times : pairs) { - log.info("Parallel: {} ns; Simple: {} ns", times.getFirst(), times.getSecond()); - } + @Test + @DisplayName("Test New Simple Loop 1") + void testNewSimpleLoop1() { + assertTimeout(ofMillis(30000), () -> { + FileSplitParallelDataSetIterator fspdsi = new FileSplitParallelDataSetIterator(rootFolder, "mnist-train-%d.bin", new DataSetDeserializer()); + List> pairs = new ArrayList<>(); + long time1 = System.nanoTime(); + int cnt = 0; + while (fspdsi.hasNext()) { + DataSet ds = fspdsi.next(); + long time2 = System.nanoTime(); + pairs.add(new Pair(time2 - time1, 0L)); + assertNotNull(ds); + // imitating processing here + Thread.sleep(10); + cnt++; + time1 = System.nanoTime(); + } + assertEquals(26, cnt); + for (Pair times : pairs) { + log.info("Parallel: {} ns; Simple: {} ns", times.getFirst(), times.getSecond()); + } + }); } - - /* @Test public void testSimpleLoop1() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java index 94d344c39..97ef03af9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java @@ -17,51 +17,45 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.perf.listener; import org.apache.commons.io.FileUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.listener.HardwareMetric; import org.deeplearning4j.core.listener.SystemPolling; -import org.junit.Ignore; +import org.junit.jupiter.api.Disabled; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import 
org.nd4j.linalg.factory.Nd4j; - import java.io.File; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +@Disabled("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657") +@DisplayName("System Polling Test") +class SystemPollingTest extends BaseDL4JTest { -@Ignore("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657") -public class SystemPollingTest extends BaseDL4JTest { - - @Rule - public TemporaryFolder tempDir = new TemporaryFolder(); + @TempDir + public Path tempDir; @Test - public void testPolling() throws Exception { + @DisplayName("Test Polling") + void testPolling() throws Exception { Nd4j.create(1); - File tmpDir = tempDir.newFolder(); - - SystemPolling systemPolling = new SystemPolling.Builder() - .outputDirectory(tmpDir).pollEveryMillis(1000) - .build(); + File tmpDir = tempDir.toFile(); + SystemPolling systemPolling = new SystemPolling.Builder().outputDirectory(tmpDir).pollEveryMillis(1000).build(); systemPolling.run(); - Thread.sleep(8000); - systemPolling.stopPolling(); - File[] files = tmpDir.listFiles(); assertTrue(files != null && files.length > 0); - //System.out.println(Arrays.toString(files)); - + // System.out.println(Arrays.toString(files)); String yaml = FileUtils.readFileToString(files[0]); HardwareMetric fromYaml = HardwareMetric.fromYaml(yaml); System.out.println(fromYaml); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/ui/UiConnectionInfoTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/ui/UiConnectionInfoTest.java index 0e2a71c5c..cf8984bfa 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/ui/UiConnectionInfoTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/ui/UiConnectionInfoTest.java @@ -17,107 +17,97 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.ui; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.ui.UiConnectionInfo; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class UiConnectionInfoTest extends BaseDL4JTest { - - @Before - public void setUp() throws Exception { +@DisplayName("Ui Connection Info Test") +class UiConnectionInfoTest extends BaseDL4JTest { + @BeforeEach + void setUp() throws Exception { } @Test - public void testGetFirstPart1() throws Exception { + @DisplayName("Test Get First Part 1") + void testGetFirstPart1() throws Exception { UiConnectionInfo info = new UiConnectionInfo.Builder().setPort(8080).build(); - - assertEquals("http://localhost:8080", info.getFirstPart()); + assertEquals(info.getFirstPart(), "http://localhost:8080"); } @Test - public void testGetFirstPart2() throws Exception { + @DisplayName("Test Get First Part 2") + void testGetFirstPart2() throws Exception { UiConnectionInfo info = new 
UiConnectionInfo.Builder().enableHttps(true).setPort(8080).build(); - - assertEquals("https://localhost:8080", info.getFirstPart()); + assertEquals(info.getFirstPart(), "https://localhost:8080"); } @Test - public void testGetFirstPart3() throws Exception { - UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082) - .build(); - - assertEquals("https://192.168.1.1:8082", info.getFirstPart()); - } - - - @Test - public void testGetSecondPart1() throws Exception { - UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082) - .setPath("www-data").build(); - - assertEquals("/www-data/", info.getSecondPart()); + @DisplayName("Test Get First Part 3") + void testGetFirstPart3() throws Exception { + UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082).build(); + assertEquals(info.getFirstPart(), "https://192.168.1.1:8082"); } @Test - public void testGetSecondPart2() throws Exception { - UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082) - .setPath("/www-data/tmp/").build(); - - assertEquals("/www-data/tmp/", info.getSecondPart()); + @DisplayName("Test Get Second Part 1") + void testGetSecondPart1() throws Exception { + UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082).setPath("www-data").build(); + assertEquals(info.getSecondPart(), "/www-data/"); } @Test - public void testGetSecondPart3() throws Exception { - UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082) - .setPath("/www-data/tmp").build(); - - assertEquals("/www-data/tmp/", info.getSecondPart()); + @DisplayName("Test Get Second Part 2") + void testGetSecondPart2() throws Exception { + UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082).setPath("/www-data/tmp/").build(); + assertEquals(info.getSecondPart(), "/www-data/tmp/"); } @Test - public void testGetSecondPart4() throws Exception { - UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082) - .setPath("/www-data//tmp").build(); - - assertEquals("/www-data/tmp/", info.getSecondPart()); + @DisplayName("Test Get Second Part 3") + void testGetSecondPart3() throws Exception { + UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082).setPath("/www-data/tmp").build(); + assertEquals(info.getSecondPart(), "/www-data/tmp/"); } @Test - public void testGetSecondPart5() throws Exception { - UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082) - .setPath("/www-data//tmp").build(); - - assertEquals("/www-data/tmp/alpha/", info.getSecondPart("alpha")); + @DisplayName("Test Get Second Part 4") + void testGetSecondPart4() throws Exception { + UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082).setPath("/www-data//tmp").build(); + assertEquals(info.getSecondPart(), "/www-data/tmp/"); } @Test - public void testGetSecondPart6() throws Exception { - UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082) - .setPath("//www-data//tmp").build(); - - assertEquals("/www-data/tmp/alpha/", info.getSecondPart("/alpha/")); 
+ @DisplayName("Test Get Second Part 5") + void testGetSecondPart5() throws Exception { + UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082).setPath("/www-data//tmp").build(); + assertEquals(info.getSecondPart("alpha"), "/www-data/tmp/alpha/"); } @Test - public void testGetSecondPart7() throws Exception { - UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082) - .setPath("//www-data//tmp").build(); - - assertEquals("/www-data/tmp/alpha/beta/", info.getSecondPart("/alpha//beta/")); + @DisplayName("Test Get Second Part 6") + void testGetSecondPart6() throws Exception { + UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082).setPath("//www-data//tmp").build(); + assertEquals(info.getSecondPart("/alpha/"), "/www-data/tmp/alpha/"); } @Test - public void testGetSecondPart8() throws Exception { - UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(false) - .setPort(8082).setPath("/www-data//tmp").build(); + @DisplayName("Test Get Second Part 7") + void testGetSecondPart7() throws Exception { + UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(true).setPort(8082).setPath("//www-data//tmp").build(); + assertEquals(info.getSecondPart("/alpha//beta/"), "/www-data/tmp/alpha/beta/"); + } - assertEquals("http://192.168.1.1:8082/www-data/tmp/", info.getFullAddress()); + @Test + @DisplayName("Test Get Second Part 8") + void testGetSecondPart8() throws Exception { + UiConnectionInfo info = new UiConnectionInfo.Builder().setAddress("192.168.1.1").enableHttps(false).setPort(8082).setPath("/www-data//tmp").build(); + assertEquals(info.getFullAddress(), "http://192.168.1.1:8082/www-data/tmp/"); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ArrayUtilTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ArrayUtilTest.java index f1f36cfae..a35377962 100755 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ArrayUtilTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ArrayUtilTest.java @@ -17,55 +17,48 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.util; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.util.ArrayUtil; - import java.util.Arrays; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** - * */ -public class ArrayUtilTest extends BaseDL4JTest { +@DisplayName("Array Util Test") +class ArrayUtilTest extends BaseDL4JTest { @Test - public void testRange() { + @DisplayName("Test Range") + void testRange() { int[] range = ArrayUtil.range(0, 2); - int[] test = {0, 1}; + int[] test = { 0, 1 }; assertEquals(true, Arrays.equals(test, range)); - - int[] test2 = {-1, 0}; + int[] test2 = { -1, 0 }; int[] range2 = ArrayUtil.range(-1, 1); assertEquals(true, Arrays.equals(test2, range2)); - } @Test - public void testStrides() { - int[] shape = {5, 4, 3}; - int[] cStyleStride = {12, 3, 1}; - int[] fortranStyleStride = {1, 5, 20}; + @DisplayName("Test Strides") + void testStrides() { + int[] 
shape = { 5, 4, 3 }; + int[] cStyleStride = { 12, 3, 1 }; + int[] fortranStyleStride = { 1, 5, 20 }; int[] fortranStyleTest = ArrayUtil.calcStridesFortran(shape); int[] cStyleTest = ArrayUtil.calcStrides(shape); assertEquals(true, Arrays.equals(cStyleStride, cStyleTest)); assertEquals(true, Arrays.equals(fortranStyleStride, fortranStyleTest)); - - int[] shape2 = {2, 2}; - int[] cStyleStride2 = {2, 1}; - int[] fortranStyleStride2 = {1, 2}; + int[] shape2 = { 2, 2 }; + int[] cStyleStride2 = { 2, 1 }; + int[] fortranStyleStride2 = { 1, 2 }; int[] cStyleTest2 = ArrayUtil.calcStrides(shape2); int[] fortranStyleTest2 = ArrayUtil.calcStridesFortran(shape2); assertEquals(true, Arrays.equals(cStyleStride2, cStyleTest2)); assertEquals(true, Arrays.equals(fortranStyleStride2, fortranStyleTest2)); - - - } - - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java index 8c6752e6b..ebf8510ce 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.util; import org.apache.commons.io.FileUtils; @@ -35,46 +34,48 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.junit.After; +import org.junit.jupiter.api.AfterEach; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.io.File; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - -public class CrashReportingUtilTest extends BaseDL4JTest { +@DisplayName("Crash Reporting Util Test") +class CrashReportingUtilTest extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { return 120000; } - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Override - public DataType getDataType(){ + public DataType getDataType() { return DataType.FLOAT; } - @After - public void after(){ - //Reset dir + @AfterEach + void after() { + // Reset dir CrashReportingUtil.crashDumpOutputDirectory(null); } @Test - public void test() throws Exception { - File dir = testDir.newFolder(); + @DisplayName("Test") + void test() throws Exception { + File dir = testDir.toFile(); CrashReportingUtil.crashDumpOutputDirectory(dir); - int kernel = 2; int stride = 1; int padding = 0; @@ -82,57 +83,28 @@ public class CrashReportingUtilTest extends BaseDL4JTest { int inputDepth = 1; int height = 28; int width = 28; - - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) - - .dist(new 
NormalDistribution(0, 1)) - .list().layer(0, - new ConvolutionLayer.Builder() - .kernelSize(kernel, kernel) - .stride(stride, stride) - .padding(padding, padding) - .nIn(inputDepth) - .nOut(3).build()) - .layer(1, new SubsamplingLayer.Builder(poolingType) - .kernelSize(kernel, kernel) - .stride(stride, stride) - .padding(padding, padding) - .build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX) - .nOut(10).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()).dist(new NormalDistribution(0, 1)).list().layer(0, new ConvolutionLayer.Builder().kernelSize(kernel, kernel).stride(stride, stride).padding(padding, padding).nIn(inputDepth).nOut(3).build()).layer(1, new SubsamplingLayer.Builder(poolingType).kernelSize(kernel, kernel).stride(stride, stride).padding(padding, padding).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); net.addListeners(new ScoreIterationListener(1)); - - //Test net that hasn't been trained yet + // Test net that hasn't been trained yet Exception e = new Exception(); CrashReportingUtil.writeMemoryCrashDump(net, e); - File[] list = dir.listFiles(); assertNotNull(list); assertEquals(1, list.length); String str = FileUtils.readFileToString(list[0]); -// System.out.println(str); + // System.out.println(str); assertTrue(str.contains("Network Information")); assertTrue(str.contains("Layer Helpers")); assertTrue(str.contains("JavaCPP")); assertTrue(str.contains("ScoreIterationListener")); - - - //Train: + // Train: DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, true, 12345), 5); net.fit(iter); - dir = testDir.newFolder(); + dir = testDir.toFile(); CrashReportingUtil.crashDumpOutputDirectory(dir); CrashReportingUtil.writeMemoryCrashDump(net, e); - list = dir.listFiles(); assertNotNull(list); assertEquals(1, list.length); @@ -141,36 +113,26 @@ public class CrashReportingUtilTest extends BaseDL4JTest { assertTrue(str.contains("Layer Helpers")); assertTrue(str.contains("JavaCPP")); assertTrue(str.contains("ScoreIterationListener(1)")); - -// System.out.println("///////////////////////////////////////////////////////////"); -// System.out.println(str); -// System.out.println("///////////////////////////////////////////////////////////"); - - - //Also test manual memory info + // System.out.println("///////////////////////////////////////////////////////////"); + // System.out.println(str); + // System.out.println("///////////////////////////////////////////////////////////"); + // Also test manual memory info String mlnMemoryInfo = net.memoryInfo(32, InputType.convolutionalFlat(28, 28, 1)); -// System.out.println("///////////////////////////////////////////////////////////"); -// System.out.println(mlnMemoryInfo); -// System.out.println("///////////////////////////////////////////////////////////"); - + // System.out.println("///////////////////////////////////////////////////////////"); + // System.out.println(mlnMemoryInfo); + // System.out.println("///////////////////////////////////////////////////////////"); assertTrue(mlnMemoryInfo.contains("Network Information")); assertTrue(mlnMemoryInfo.contains("Layer Helpers")); 
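A note on the TemporaryFolder-to-@TempDir change in CrashReportingUtilTest: the JUnit 4 rule handed out a fresh directory on every testDir.newFolder() call, whereas an injected @TempDir Path is a single directory for the duration of the test method, so the repeated dir = testDir.toFile() assignments all point at the same folder; if each crash dump is written under a distinct name, the later assertEquals(1, list.length) checks would then see every dump written so far. A sketch of one way to keep the per-dump isolation; the newFolder helper is illustrative and not part of this patch:

    import java.io.File;
    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import org.junit.jupiter.api.io.TempDir;

    @TempDir
    Path testDir;

    // mirrors TemporaryFolder.newFolder(): a fresh sub-directory per crash dump
    private File newFolder() throws IOException {
        return Files.createTempDirectory(testDir, "crash-dump").toFile();
    }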
assertTrue(mlnMemoryInfo.contains("JavaCPP")); assertTrue(mlnMemoryInfo.contains("ScoreIterationListener(1)")); - - - - //////////////////////////////////////// - //Same thing on ComputationGraph: - dir = testDir.newFolder(); + // ////////////////////////////////////// + // Same thing on ComputationGraph: + dir = testDir.toFile(); CrashReportingUtil.crashDumpOutputDirectory(dir); - ComputationGraph cg = net.toComputationGraph(); cg.setListeners(new ScoreIterationListener(1)); - - //Test net that hasn't been trained yet + // Test net that hasn't been trained yet CrashReportingUtil.writeMemoryCrashDump(cg, e); - list = dir.listFiles(); assertNotNull(list); assertEquals(1, list.length); @@ -179,13 +141,11 @@ public class CrashReportingUtilTest extends BaseDL4JTest { assertTrue(str.contains("Layer Helpers")); assertTrue(str.contains("JavaCPP")); assertTrue(str.contains("ScoreIterationListener(1)")); - - //Train: + // Train: cg.fit(iter); - dir = testDir.newFolder(); + dir = testDir.toFile(); CrashReportingUtil.crashDumpOutputDirectory(dir); CrashReportingUtil.writeMemoryCrashDump(cg, e); - list = dir.listFiles(); assertNotNull(list); assertEquals(1, list.length); @@ -194,24 +154,17 @@ public class CrashReportingUtilTest extends BaseDL4JTest { assertTrue(str.contains("Layer Helpers")); assertTrue(str.contains("JavaCPP")); assertTrue(str.contains("ScoreIterationListener(1)")); - -// System.out.println("///////////////////////////////////////////////////////////"); -// System.out.println(str); -// System.out.println("///////////////////////////////////////////////////////////"); - - - //Also test manual memory info + // System.out.println("///////////////////////////////////////////////////////////"); + // System.out.println(str); + // System.out.println("///////////////////////////////////////////////////////////"); + // Also test manual memory info String cgMemoryInfo = cg.memoryInfo(32, InputType.convolutionalFlat(28, 28, 1)); -// System.out.println("///////////////////////////////////////////////////////////"); -// System.out.println(cgMemoryInfo); -// System.out.println("///////////////////////////////////////////////////////////"); - + // System.out.println("///////////////////////////////////////////////////////////"); + // System.out.println(cgMemoryInfo); + // System.out.println("///////////////////////////////////////////////////////////"); assertTrue(cgMemoryInfo.contains("Network Information")); assertTrue(cgMemoryInfo.contains("Layer Helpers")); assertTrue(cgMemoryInfo.contains("JavaCPP")); assertTrue(cgMemoryInfo.contains("ScoreIterationListener(1)")); - } - - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java index a63a0eb34..fdac1af4e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.util; import org.apache.commons.compress.utils.IOUtils; @@ -31,10 +30,10 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Ignore; +import 
org.junit.jupiter.api.Disabled; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.junit.rules.Timeout; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; @@ -45,25 +44,27 @@ import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.resources.Resources; - import java.io.*; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.Assume.assumeNotNull; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -@Ignore -public class ModelGuesserTest extends BaseDL4JTest { +@Disabled +@DisplayName("Model Guesser Test") +class ModelGuesserTest extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Rule public Timeout timeout = Timeout.seconds(300); - @Test - public void testModelGuessFile() throws Exception { + @DisplayName("Test Model Guess File") + void testModelGuessFile() throws Exception { File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"); assertTrue(f.exists()); Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath()); @@ -72,76 +73,62 @@ public class ModelGuesserTest extends BaseDL4JTest { assertTrue(f.exists()); Model guess2 = ModelGuesser.loadModelGuess(f.getAbsolutePath()); assumeNotNull(guess2); - } @Test - public void testModelGuessInputStream() throws Exception { + @DisplayName("Test Model Guess Input Stream") + void testModelGuessInputStream() throws Exception { File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"); assertTrue(f.exists()); - try (InputStream inputStream = new FileInputStream(f)) { Model guess1 = ModelGuesser.loadModelGuess(inputStream); assumeNotNull(guess1); } - f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"); assertTrue(f.exists()); - try (InputStream inputStream = new FileInputStream(f)) { Model guess1 = ModelGuesser.loadModelGuess(inputStream); assumeNotNull(guess1); } } - - @Test - public void testLoadNormalizersFile() throws Exception { + @DisplayName("Test Load Normalizers File") + void testLoadNormalizersFile() throws Exception { MultiLayerNetwork net = getNetwork(); - - File tempFile = testDir.newFile("testLoadNormalizersFile.bin"); - + File tempFile = testDir.resolve("testLoadNormalizersFile.bin").toFile(); ModelSerializer.writeModel(net, tempFile, true); - NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1); - normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2}))); + normalizer.fit(new DataSet(Nd4j.rand(new int[] { 2, 2 }), Nd4j.rand(new int[] { 2, 2 }))); ModelSerializer.addNormalizerToModel(tempFile, normalizer); Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); Normalizer normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath()); assertEquals(model, net); assertEquals(normalizer, normalizer1); - } - @Test - public void testNormalizerInPlace() throws Exception { + @DisplayName("Test Normalizer In Place") + void testNormalizerInPlace() throws Exception { 
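ModelGuesserTest keeps the JUnit 4 timeout rule (@Rule public Timeout timeout = Timeout.seconds(300); from org.junit.rules) while its test methods move to org.junit.jupiter.api.Test; JUnit 4 @Rule fields are not applied by the Jupiter engine, so that timeout would no longer take effect. The Jupiter-native equivalent is the org.junit.jupiter.api.Timeout annotation; a minimal sketch:

    import org.junit.jupiter.api.Timeout;

    @Disabled
    @DisplayName("Model Guesser Test")
    @Timeout(300) // seconds by default; replaces the JUnit 4 Timeout.seconds(300) rule
    class ModelGuesserTest extends BaseDL4JTest {
        // ...
    }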
MultiLayerNetwork net = getNetwork(); - - File tempFile = testDir.newFile("testNormalizerInPlace.bin"); - + File tempFile = testDir.resolve("testNormalizerInPlace.bin").toFile(); NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1); - normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2}))); - ModelSerializer.writeModel(net, tempFile, true,normalizer); - + normalizer.fit(new DataSet(Nd4j.rand(new int[] { 2, 2 }), Nd4j.rand(new int[] { 2, 2 }))); + ModelSerializer.writeModel(net, tempFile, true, normalizer); Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); Normalizer normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath()); assertEquals(model, net); assertEquals(normalizer, normalizer1); - } @Test - public void testLoadNormalizersInputStream() throws Exception { + @DisplayName("Test Load Normalizers Input Stream") + void testLoadNormalizersInputStream() throws Exception { MultiLayerNetwork net = getNetwork(); - - File tempFile = testDir.newFile("testLoadNormalizersInputStream.bin"); - + File tempFile = testDir.resolve("testLoadNormalizersInputStream.bin").toFile(); ModelSerializer.writeModel(net, tempFile, true); - NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1); - normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2}))); + normalizer.fit(new DataSet(Nd4j.rand(new int[] { 2, 2 }), Nd4j.rand(new int[] { 2, 2 }))); ModelSerializer.addNormalizerToModel(tempFile, normalizer); Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); try (InputStream inputStream = new FileInputStream(tempFile)) { @@ -149,33 +136,26 @@ public class ModelGuesserTest extends BaseDL4JTest { assertEquals(model, net); assertEquals(normalizer, normalizer1); } - } - @Test - public void testModelGuesserDl4jModelFile() throws Exception { + @DisplayName("Test Model Guesser Dl 4 j Model File") + void testModelGuesserDl4jModelFile() throws Exception { MultiLayerNetwork net = getNetwork(); - - File tempFile = testDir.newFile("testModelGuesserDl4jModelFile.bin"); - + File tempFile = testDir.resolve("testModelGuesserDl4jModelFile.bin").toFile(); ModelSerializer.writeModel(net, tempFile, true); - MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); - } @Test - public void testModelGuesserDl4jModelInputStream() throws Exception { + @DisplayName("Test Model Guesser Dl 4 j Model Input Stream") + void testModelGuesserDl4jModelInputStream() throws Exception { MultiLayerNetwork net = getNetwork(); - - File tempFile = testDir.newFile("testModelGuesserDl4jModelInputStream.bin"); - + File tempFile = testDir.resolve("testModelGuesserDl4jModelInputStream.bin").toFile(); ModelSerializer.writeModel(net, tempFile, true); - try (InputStream inputStream = new FileInputStream(tempFile)) { MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream); assumeNotNull(network); @@ -185,65 +165,51 @@ public class ModelGuesserTest extends BaseDL4JTest { } } - @Test - public void testModelGuessConfigFile() throws Exception { - ClassPathResource resource = new ClassPathResource("modelimport/keras/configs/cnn_tf_config.json", - ModelGuesserTest.class.getClassLoader()); + @DisplayName("Test Model 
Guess Config File") + void testModelGuessConfigFile() throws Exception { + ClassPathResource resource = new ClassPathResource("modelimport/keras/configs/cnn_tf_config.json", ModelGuesserTest.class.getClassLoader()); File f = getTempFile(resource); String configFilename = f.getAbsolutePath(); Object conf = ModelGuesser.loadConfigGuess(configFilename); assertTrue(conf instanceof MultiLayerConfiguration); - ClassPathResource sequenceResource = new ClassPathResource("/keras/simple/mlp_fapi_multiloss_config.json"); File f2 = getTempFile(sequenceResource); Object sequenceConf = ModelGuesser.loadConfigGuess(f2.getAbsolutePath()); assertTrue(sequenceConf instanceof ComputationGraphConfiguration); - - - ClassPathResource resourceDl4j = new ClassPathResource("model.json"); File fDl4j = getTempFile(resourceDl4j); String configFilenameDl4j = fDl4j.getAbsolutePath(); Object confDl4j = ModelGuesser.loadConfigGuess(configFilenameDl4j); assertTrue(confDl4j instanceof ComputationGraphConfiguration); - } @Test - public void testModelGuessConfigInputStream() throws Exception { - ClassPathResource resource = new ClassPathResource("modelimport/keras/configs/cnn_tf_config.json", - ModelGuesserTest.class.getClassLoader()); + @DisplayName("Test Model Guess Config Input Stream") + void testModelGuessConfigInputStream() throws Exception { + ClassPathResource resource = new ClassPathResource("modelimport/keras/configs/cnn_tf_config.json", ModelGuesserTest.class.getClassLoader()); File f = getTempFile(resource); - try (InputStream inputStream = new FileInputStream(f)) { Object conf = ModelGuesser.loadConfigGuess(inputStream); assertTrue(conf instanceof MultiLayerConfiguration); } - ClassPathResource sequenceResource = new ClassPathResource("/keras/simple/mlp_fapi_multiloss_config.json"); File f2 = getTempFile(sequenceResource); - try (InputStream inputStream = new FileInputStream(f2)) { Object sequenceConf = ModelGuesser.loadConfigGuess(inputStream); assertTrue(sequenceConf instanceof ComputationGraphConfiguration); } - - ClassPathResource resourceDl4j = new ClassPathResource("model.json"); File fDl4j = getTempFile(resourceDl4j); - try (InputStream inputStream = new FileInputStream(fDl4j)) { Object confDl4j = ModelGuesser.loadConfigGuess(inputStream); assertTrue(confDl4j instanceof ComputationGraphConfiguration); } - } - private File getTempFile(ClassPathResource classPathResource) throws Exception { InputStream is = classPathResource.getInputStream(); - File f = testDir.newFile(); + File f = testDir.toFile(); BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f)); IOUtils.copy(is, bos); bos.flush(); @@ -254,18 +220,9 @@ public class ModelGuesserTest extends BaseDL4JTest { private MultiLayerNetwork getNetwork() { int nIn = 5; int nOut = 6; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01) - .updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()).layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new 
OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - return net; } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java index 0c98afcba..e2d128bb8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.util; import lombok.val; @@ -34,8 +33,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -47,456 +46,308 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.primitives.Pair; - import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.InputStream; import java.util.*; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; +@DisplayName("Model Serializer Test") +class ModelSerializerTest extends BaseDL4JTest { -public class ModelSerializerTest extends BaseDL4JTest { - - @Rule - public TemporaryFolder tempDir = new TemporaryFolder(); + @TempDir + public Path tempDir; @Test - public void testWriteMLNModel() throws Exception { + @DisplayName("Test Write MLN Model") + void testWriteMLNModel() throws Exception { int nIn = 5; int nOut = 6; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()).layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - File tempFile = tempDir.newFile(); - + File tempFile = tempDir.toFile(); ModelSerializer.writeModel(net, tempFile, true); - MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile); - assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); assertEquals(net.params(), 
network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @Test - public void testWriteMlnModelInputStream() throws Exception { + @DisplayName("Test Write Mln Model Input Stream") + void testWriteMlnModelInputStream() throws Exception { int nIn = 5; int nOut = 6; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()).layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - File tempFile = tempDir.newFile(); + File tempFile = tempDir.toFile(); FileOutputStream fos = new FileOutputStream(tempFile); - ModelSerializer.writeModel(net, fos, true); - - // checking adding of DataNormalization to the model file - NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler(); DataSetIterator iter = new IrisDataSetIterator(150, 150); scaler.fit(iter); - ModelSerializer.addNormalizerToModel(tempFile, scaler); - NormalizerMinMaxScaler restoredScaler = ModelSerializer.restoreNormalizerFromFile(tempFile); - assertNotEquals(null, scaler.getMax()); assertEquals(scaler.getMax(), restoredScaler.getMax()); assertEquals(scaler.getMin(), restoredScaler.getMin()); - FileInputStream fis = new FileInputStream(tempFile); - MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis); - assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } - @Test - public void testWriteCGModel() throws Exception { - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)) - .graphBuilder().addInputs("in") - .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", - new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3) - .activation(Activation.SOFTMAX).build(), - "dense") - .setOutputs("out").build(); - + @DisplayName("Test Write CG Model") + void testWriteCGModel() throws Exception { + ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).graphBuilder().addInputs("in").addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).activation(Activation.SOFTMAX).build(), "dense").setOutputs("out").build(); ComputationGraph cg = new ComputationGraph(config); cg.init(); - - File tempFile = tempDir.newFile(); - + File tempFile = tempDir.toFile(); ModelSerializer.writeModel(cg, tempFile, true); - 
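On the tempDir.newFile() to tempDir.toFile() rewrites in ModelSerializerTest: TemporaryFolder.newFile() created and returned a fresh empty file inside the temporary folder, while Path.toFile() on an injected @TempDir yields the temporary directory itself, so the writeModel calls here now receive a directory rather than a file path. The pattern this same patch uses elsewhere (for example in testInvalidLoading2) keeps the original semantics; the file name below is illustrative:

    // inside a @Test method, with an injected "@TempDir Path tempDir" field
    File tempFile = tempDir.resolve("testWriteMLNModel.bin").toFile();
    ModelSerializer.writeModel(net, tempFile, true);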
ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile); - assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson()); assertEquals(cg.params(), network.params()); assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @Test - public void testWriteCGModelInputStream() throws Exception { - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)) - .graphBuilder().addInputs("in") - .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", - new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3) - .activation(Activation.SOFTMAX).build(), - "dense") - .setOutputs("out").build(); - + @DisplayName("Test Write CG Model Input Stream") + void testWriteCGModelInputStream() throws Exception { + ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).graphBuilder().addInputs("in").addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).activation(Activation.SOFTMAX).build(), "dense").setOutputs("out").build(); ComputationGraph cg = new ComputationGraph(config); cg.init(); - - File tempFile = tempDir.newFile(); - + File tempFile = tempDir.toFile(); ModelSerializer.writeModel(cg, tempFile, true); FileInputStream fis = new FileInputStream(tempFile); - ComputationGraph network = ModelSerializer.restoreComputationGraph(fis); - assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson()); assertEquals(cg.params(), network.params()); assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } private DataSet trivialDataSet() { - INDArray inputs = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f}, new int[]{1,3}); - INDArray labels = Nd4j.create(new float[] {4.0f, 5.0f, 6.0f}, new int[]{1,3}); + INDArray inputs = Nd4j.create(new float[] { 1.0f, 2.0f, 3.0f }, new int[] { 1, 3 }); + INDArray labels = Nd4j.create(new float[] { 4.0f, 5.0f, 6.0f }, new int[] { 1, 3 }); return new DataSet(inputs, labels); } private ComputationGraph simpleComputationGraph() { - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)) - .graphBuilder().addInputs("in") - .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", - new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3) - .activation(Activation.SOFTMAX).build(), - "dense") - .setOutputs("out").build(); - + ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).graphBuilder().addInputs("in").addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).activation(Activation.SOFTMAX).build(), "dense").setOutputs("out").build(); return new ComputationGraph(config); } @Test - public void testSaveRestoreNormalizerFromInputStream() throws Exception { + @DisplayName("Test Save Restore Normalizer From Input Stream") + void testSaveRestoreNormalizerFromInputStream() throws Exception { DataSet dataSet = 
trivialDataSet(); NormalizerStandardize norm = new NormalizerStandardize(); norm.fit(dataSet); - ComputationGraph cg = simpleComputationGraph(); cg.init(); - - File tempFile = tempDir.newFile(); - + File tempFile = tempDir.toFile(); ModelSerializer.writeModel(cg, tempFile, true); - ModelSerializer.addNormalizerToModel(tempFile, norm); FileInputStream fis = new FileInputStream(tempFile); - - NormalizerStandardize restored = ModelSerializer.restoreNormalizerFromInputStream(fis); - assertNotEquals(null, restored); - DataSet dataSet2 = dataSet.copy(); - norm.preProcess(dataSet2); assertNotEquals(dataSet.getFeatures(), dataSet2.getFeatures()); - restored.revert(dataSet2); assertEquals(dataSet.getFeatures(), dataSet2.getFeatures()); } @Test - public void testRestoreUnsavedNormalizerFromInputStream() throws Exception { + @DisplayName("Test Restore Unsaved Normalizer From Input Stream") + void testRestoreUnsavedNormalizerFromInputStream() throws Exception { DataSet dataSet = trivialDataSet(); - NormalizerStandardize norm = new NormalizerStandardize(); norm.fit(dataSet); - ComputationGraph cg = simpleComputationGraph(); cg.init(); - - File tempFile = tempDir.newFile(); + File tempFile = tempDir.toFile(); ModelSerializer.writeModel(cg, tempFile, true); - FileInputStream fis = new FileInputStream(tempFile); - NormalizerStandardize restored = ModelSerializer.restoreNormalizerFromInputStream(fis); - assertEquals(null, restored); } @Test - public void testInvalidLoading1() throws Exception { - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() - .graphBuilder().addInputs("in") - .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in") - .addLayer("out",new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(2).nOut(3).build(), - "dense") - .setOutputs("out").build(); - + @DisplayName("Test Invalid Loading 1") + void testInvalidLoading1() throws Exception { + ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(2).nOut(3).build(), "dense").setOutputs("out").build(); ComputationGraph cg = new ComputationGraph(config); cg.init(); - - File tempFile = tempDir.newFile(); - + File tempFile = tempDir.toFile(); ModelSerializer.writeModel(cg, tempFile, true); - try { ModelSerializer.restoreMultiLayerNetwork(tempFile); fail(); - } catch (Exception e){ + } catch (Exception e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("JSON") && msg.contains("restoreComputationGraph")); + assertTrue(msg.contains("JSON") && msg.contains("restoreComputationGraph"),msg); } } @Test - public void testInvalidLoading2() throws Exception { + @DisplayName("Test Invalid Loading 2") + void testInvalidLoading2() throws Exception { int nIn = 5; int nOut = 6; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01).updater(new 
Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()).layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - File tempFile = tempDir.newFile("testInvalidLoading2.bin"); - + File tempFile = tempDir.resolve("testInvalidLoading2.bin").toFile(); ModelSerializer.writeModel(net, tempFile, true); - try { ModelSerializer.restoreComputationGraph(tempFile); fail(); - } catch (Exception e){ + } catch (Exception e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("JSON") && msg.contains("restoreMultiLayerNetwork")); + assertTrue(msg.contains("JSON") && msg.contains("restoreMultiLayerNetwork"),msg); } } @Test - public void testInvalidStreamReuse() throws Exception { + @DisplayName("Test Invalid Stream Reuse") + void testInvalidStreamReuse() throws Exception { int nIn = 5; int nOut = 6; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .list() - .layer(new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).list().layer(new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSet dataSet = trivialDataSet(); NormalizerStandardize norm = new NormalizerStandardize(); norm.fit(dataSet); - - File tempFile = tempDir.newFile(); + File tempFile = tempDir.toFile(); ModelSerializer.writeModel(net, tempFile, true); ModelSerializer.addNormalizerToModel(tempFile, norm); - InputStream is = new FileInputStream(tempFile); ModelSerializer.restoreMultiLayerNetwork(is); - - try{ + try { ModelSerializer.restoreNormalizerFromInputStream(is); fail("Expected exception"); - } catch (Exception e){ + } catch (Exception e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("may have been closed")); + assertTrue(msg.contains("may have been closed"),msg); } - - try{ + try { ModelSerializer.restoreMultiLayerNetwork(is); fail("Expected exception"); - } catch (Exception e){ + } catch (Exception e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("may have been closed")); + assertTrue(msg.contains("may have been closed"),msg); } - - //Also test reading both model and normalizer from stream (correctly) - Pair pair = ModelSerializer.restoreMultiLayerNetworkAndNormalizer(new FileInputStream(tempFile), true); + // Also test reading both model and normalizer from stream (correctly) + Pair pair = ModelSerializer.restoreMultiLayerNetworkAndNormalizer(new FileInputStream(tempFile), true); assertEquals(net.params(), pair.getFirst().params()); assertNotNull(pair.getSecond()); } - @Test - public void testInvalidStreamReuseCG() throws Exception { + @DisplayName("Test Invalid Stream Reuse CG") + void testInvalidStreamReuseCG() throws Exception { int nIn = 5; int nOut = 6; - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .graphBuilder() - .addInputs("in") - .layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in") - .setOutputs("0") - .build(); - + ComputationGraphConfiguration conf = new 
NeuralNetConfiguration.Builder().seed(12345).l1(0.01).graphBuilder().addInputs("in").layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in").setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); net.init(); - DataSet dataSet = trivialDataSet(); NormalizerStandardize norm = new NormalizerStandardize(); norm.fit(dataSet); - - File tempFile = tempDir.newFile(); + File tempFile = tempDir.toFile(); ModelSerializer.writeModel(net, tempFile, true); ModelSerializer.addNormalizerToModel(tempFile, norm); - InputStream is = new FileInputStream(tempFile); ModelSerializer.restoreComputationGraph(is); - - try{ + try { ModelSerializer.restoreNormalizerFromInputStream(is); fail("Expected exception"); - } catch (Exception e){ + } catch (Exception e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("may have been closed")); + assertTrue(msg.contains("may have been closed"),msg); } - - try{ + try { ModelSerializer.restoreComputationGraph(is); fail("Expected exception"); - } catch (Exception e){ + } catch (Exception e) { String msg = e.getMessage(); - assertTrue(msg, msg.contains("may have been closed")); + assertTrue(msg.contains("may have been closed"),msg); } - - //Also test reading both model and normalizer from stream (correctly) - Pair pair = ModelSerializer.restoreComputationGraphAndNormalizer(new FileInputStream(tempFile), true); + // Also test reading both model and normalizer from stream (correctly) + Pair pair = ModelSerializer.restoreComputationGraphAndNormalizer(new FileInputStream(tempFile), true); assertEquals(net.params(), pair.getFirst().params()); assertNotNull(pair.getSecond()); } - @Test - public void testJavaSerde_1() throws Exception { + @DisplayName("Test Java Serde _ 1") + void testJavaSerde_1() throws Exception { int nIn = 5; int nOut = 6; - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .graphBuilder() - .addInputs("in") - .layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).build(), "in") - .setOutputs("0") - .validateOutputLayerConfig(false) - .build(); - + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).graphBuilder().addInputs("in").layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).build(), "in").setOutputs("0").validateOutputLayerConfig(false).build(); ComputationGraph net = new ComputationGraph(conf); net.init(); - DataSet dataSet = trivialDataSet(); NormalizerStandardize norm = new NormalizerStandardize(); norm.fit(dataSet); - val b = SerializationUtils.serialize(net); - ComputationGraph restored = SerializationUtils.deserialize(b); - assertEquals(net, restored); } @Test - public void testJavaSerde_2() throws Exception { + @DisplayName("Test Java Serde _ 2") + void testJavaSerde_2() throws Exception { int nIn = 5; int nOut = 6; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .list() - .layer(0, new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).list().layer(0, new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - DataSet dataSet = trivialDataSet(); NormalizerStandardize norm = new NormalizerStandardize(); norm.fit(dataSet); - val b = SerializationUtils.serialize(net); - MultiLayerNetwork restored = 
SerializationUtils.deserialize(b); - assertEquals(net, restored); } @Test - public void testPutGetObject() throws Exception { - + @DisplayName("Test Put Get Object") + void testPutGetObject() throws Exception { int nIn = 5; int nOut = 6; - - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .graphBuilder() - .addInputs("in") - .layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in") - .setOutputs("0") - .build(); - + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).graphBuilder().addInputs("in").layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in").setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); net.init(); - - File tempFile = tempDir.newFile(); + File tempFile = tempDir.toFile(); ModelSerializer.writeModel(net, tempFile, true); - - List toWrite = Arrays.asList("zero", "one", "two"); ModelSerializer.addObjectToFile(tempFile, "myLabels", toWrite); List restored = ModelSerializer.getObjectFromFile(tempFile, "myLabels"); assertEquals(toWrite, restored); - - - Map someOtherData = new HashMap<>(); - someOtherData.put("x", new float[]{0,1,2}); - someOtherData.put("y",Nd4j.linspace(1,10,10, Nd4j.dataType())); - + Map someOtherData = new HashMap<>(); + someOtherData.put("x", new float[] { 0, 1, 2 }); + someOtherData.put("y", Nd4j.linspace(1, 10, 10, Nd4j.dataType())); ModelSerializer.addObjectToFile(tempFile, "otherData.bin", someOtherData); - - Map dataRestored = ModelSerializer.getObjectFromFile(tempFile, "otherData.bin"); + Map dataRestored = ModelSerializer.getObjectFromFile(tempFile, "otherData.bin"); assertEquals(someOtherData.keySet(), dataRestored.keySet()); - assertArrayEquals((float[])someOtherData.get("x"), (float[])dataRestored.get("x"), 0f); + assertArrayEquals((float[]) someOtherData.get("x"), (float[]) dataRestored.get("x"), 0f); assertEquals(someOtherData.get("y"), dataRestored.get("y")); - - List entries = ModelSerializer.listObjectsInFile(tempFile); assertEquals(2, entries.size()); System.out.println(entries); assertTrue(entries.contains("myLabels")); assertTrue(entries.contains("otherData.bin")); - ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(tempFile); assertEquals(net.params(), restoredNet.params()); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/MovingWindowMatrixTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/MovingWindowMatrixTest.java index 47e05b772..6c0557619 100755 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/MovingWindowMatrixTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/MovingWindowMatrixTest.java @@ -17,23 +17,24 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.util; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.util.MovingWindowMatrix; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; - import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class MovingWindowMatrixTest extends 
BaseDL4JTest { +@DisplayName("Moving Window Matrix Test") +class MovingWindowMatrixTest extends BaseDL4JTest { @Test - public void testMovingWindow() { + @DisplayName("Test Moving Window") + void testMovingWindow() { INDArray ones = Nd4j.ones(4, 4); org.deeplearning4j.core.util.MovingWindowMatrix m = new org.deeplearning4j.core.util.MovingWindowMatrix(ones, 2, 2); List windows = m.windows(); @@ -41,10 +42,5 @@ public class MovingWindowMatrixTest extends BaseDL4JTest { org.deeplearning4j.core.util.MovingWindowMatrix m2 = new MovingWindowMatrix(ones, 2, 2, true); List windowsRotate = m2.windows(); assertEquals(16, windowsRotate.size()); - - } - - - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/SerializationUtilsTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/SerializationUtilsTest.java index 2bfd6c536..cabbdf369 100755 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/SerializationUtilsTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/SerializationUtilsTest.java @@ -17,41 +17,38 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.util; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.util.SerializationUtils; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; - import java.io.File; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; +@DisplayName("Serialization Utils Test") +class SerializationUtilsTest extends BaseDL4JTest { -public class SerializationUtilsTest extends BaseDL4JTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; @Test - public void testWriteRead() throws Exception { + @DisplayName("Test Write Read") + void testWriteRead() throws Exception { DataSetIterator iter = new IrisDataSetIterator(150, 150); String irisData = "irisData.dat"; - DataSet freshDataSet = iter.next(150); - - File f = testDir.newFile(irisData); + File f = testDir.resolve(irisData).toFile(); SerializationUtils.saveObject(freshDataSet, f); - DataSet readDataSet = SerializationUtils.readObject(f); - assertEquals(freshDataSet.getFeatures(), readDataSet.getFeatures()); assertEquals(freshDataSet.getLabels(), readDataSet.getLabels()); } - } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/TimeSeriesUtilsTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/TimeSeriesUtilsTest.java index bb652f670..2c8e1dfb7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/TimeSeriesUtilsTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/TimeSeriesUtilsTest.java @@ -17,27 +17,26 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.util; import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; +import 
org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.assertEquals; - -public class TimeSeriesUtilsTest extends BaseDL4JTest { +@DisplayName("Time Series Utils Test") +class TimeSeriesUtilsTest extends BaseDL4JTest { @Test - public void testMovingAverage() { + @DisplayName("Test Moving Average") + void testMovingAverage() { INDArray a = Nd4j.arange(0, 20).castTo(DataType.DOUBLE); - INDArray result = Nd4j.create(new double[] {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f, - 12.5f, 13.5f, 14.5f, 15.5f, 16.5f, 17.5f}); - + INDArray result = Nd4j.create(new double[] { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f, 12.5f, 13.5f, 14.5f, 15.5f, 16.5f, 17.5f }); INDArray movingAvg = TimeSeriesUtils.movingAverage(a, 4); assertEquals(result, movingAvg); } - } diff --git a/deeplearning4j/deeplearning4j-cuda/pom.xml b/deeplearning4j/deeplearning4j-cuda/pom.xml index e0b0b04fc..3c12fbbc3 100644 --- a/deeplearning4j/deeplearning4j-cuda/pom.xml +++ b/deeplearning4j/deeplearning4j-cuda/pom.xml @@ -76,10 +76,18 @@ org.bytedeco cuda-platform ${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version} + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test - junit - junit + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.deeplearning4j diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/gradientcheck/CNNGradientCheckTest.java index cb8311be6..6bedc0389 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/gradientcheck/CNNGradientCheckTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.cuda.gradientcheck; import lombok.val; @@ -36,8 +35,8 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -45,21 +44,27 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; - import java.util.Arrays; - import static org.deeplearning4j.nn.conf.ConvolutionMode.Same; import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * Created by nyghtowl on 9/1/15. 
*/ -public class CNNGradientCheckTest extends BaseDL4JTest { +@DisplayName("Cnn Gradient Check Test") +class CNNGradientCheckTest extends BaseDL4JTest { + private static final boolean PRINT_RESULTS = true; + private static final boolean RETURN_ON_FIRST_FAILURE = false; + private static final double DEFAULT_EPS = 1e-6; + private static final double DEFAULT_MAX_REL_ERROR = 1e-3; + private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; static { @@ -72,72 +77,50 @@ public class CNNGradientCheckTest extends BaseDL4JTest { } @Test - public void testGradientCNNMLN() { - //Parameterized test, testing combinations of: + @DisplayName("Test Gradient CNNMLN") + void testGradientCNNMLN() { + // Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (c) Loss function (with specified output activations) - Activation[] activFns = {Activation.SIGMOID, Activation.TANH}; - boolean[] characteristic = {false, true}; //If true: run some backprop steps first - - LossFunctions.LossFunction[] lossFunctions = - {LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE}; - Activation[] outputActivations = {Activation.SOFTMAX, Activation.TANH}; //i.e., lossFunctions[i] used with outputActivations[i] here - + Activation[] activFns = { Activation.SIGMOID, Activation.TANH }; + // If true: run some backprop steps first + boolean[] characteristic = { false, true }; + LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE }; + // i.e., lossFunctions[i] used with outputActivations[i] here + Activation[] outputActivations = { Activation.SOFTMAX, Activation.TANH }; DataSet ds = new IrisDataSetIterator(150, 150).next(); ds.normalizeZeroMeanZeroUnitVariance(); INDArray input = ds.getFeatures(); INDArray labels = ds.getLabels(); - for (Activation afn : activFns) { for (boolean doLearningFirst : characteristic) { for (int i = 0; i < lossFunctions.length; i++) { LossFunctions.LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) - .weightInit(WeightInit.XAVIER).seed(12345L).list() - .layer(0, new ConvolutionLayer.Builder(1, 1).nOut(6).activation(afn) - .cudnnAllowFallback(false) - .build()) - .layer(1, new OutputLayer.Builder(lf).activation(outputActivation).nOut(3).build()) - .setInputType(InputType.convolutionalFlat(1, 4, 1)); - + MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()).weightInit(WeightInit.XAVIER).seed(12345L).list().layer(0, new ConvolutionLayer.Builder(1, 1).nOut(6).activation(afn).cudnnAllowFallback(false).build()).layer(1, new OutputLayer.Builder(lf).activation(outputActivation).nOut(3).build()).setInputType(InputType.convolutionalFlat(1, 4, 1)); MultiLayerConfiguration conf = builder.build(); - MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); String name = new Object() { }.getClass().getEnclosingMethod().getName(); - if (doLearningFirst) { - //Run a number of iterations of learning + // Run a number of iterations of learning mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); 
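Two details of this CNNGradientCheckTest migration: the new static import is written as import static org.junit.jupiter.api.Assertions; (a static import needs a member name or a wildcard), and assertion calls in this file such as assertTrue(msg, scoreAfter < 0.8 * scoreBefore) keep the JUnit 4 (message, condition) order, whereas Jupiter's Assertions takes the message last, as this patch already does in ModelSerializerTest. A minimal sketch of the Jupiter form:

    import static org.junit.jupiter.api.Assertions.assertTrue;

    // condition first, message second under Jupiter
    assertTrue(scoreAfter < 0.8 * scoreBefore, msg);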
mln.computeGradientAndScore(); double scoreBefore = mln.score(); - for (int j = 0; j < 10; j++) - mln.fit(ds); + for (int j = 0; j < 10; j++) mln.fit(ds); mln.computeGradientAndScore(); double scoreAfter = mln.score(); - //Can't test in 'characteristic mode of operation' if not learning - String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" - + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation - + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore - + ", scoreAfter=" + scoreAfter + ")"; + // Can't test in 'characteristic mode of operation' if not learning + String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")"; assertTrue(msg, scoreAfter < 0.8 * scoreBefore); } - if (PRINT_RESULTS) { - System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" - + outputActivation + ", doLearningFirst=" + doLearningFirst); + System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); } - - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - + boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(gradOK); TestUtils.testModelSerialization(mln); } @@ -145,346 +128,207 @@ public class CNNGradientCheckTest extends BaseDL4JTest { } } - @Test - public void testGradientCNNL1L2MLN() { - //Parameterized test, testing combinations of: + @DisplayName("Test Gradient CNNL 1 L 2 MLN") + void testGradientCNNL1L2MLN() { + // Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (c) Loss function (with specified output activations) - DataSet ds = new IrisDataSetIterator(150, 150).next(); ds.normalizeZeroMeanZeroUnitVariance(); INDArray input = ds.getFeatures(); INDArray labels = ds.getLabels(); - - //use l2vals[i] with l1vals[i] - double[] l2vals = {0.4, 0.0, 0.4, 0.4}; - double[] l1vals = {0.0, 0.0, 0.5, 0.0}; - double[] biasL2 = {0.0, 0.0, 0.0, 0.2}; - double[] biasL1 = {0.0, 0.0, 0.6, 0.0}; - Activation[] activFns = {Activation.SIGMOID, Activation.TANH, Activation.ELU, Activation.SOFTPLUS}; - boolean[] characteristic = {false, true, false, true}; //If true: run some backprop steps first - - LossFunctions.LossFunction[] lossFunctions = - {LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE, LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE}; - Activation[] outputActivations = {Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.IDENTITY}; //i.e., lossFunctions[i] used with outputActivations[i] here - - for( int i=0; i (mb,4,2,2) - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(2 * 2 * 4) - .nOut(nOut).build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new 
NoOp()).dist(new NormalDistribution(0, 1)).list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth).hasBias(false).cudnnAllowFallback(false).nOut(1).build()).layer(new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(2 * 2 * 4).nOut(nOut).build()).setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn; - + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); - TestUtils.testModelSerialization(net); } } } @Test - public void testCnnWithSpaceToBatch() { + @DisplayName("Test Cnn With Space To Batch") + void testCnnWithSpaceToBatch() { Nd4j.getRandom().setSeed(12345); int nOut = 4; - - int[] minibatchSizes = {2, 4}; + int[] minibatchSizes = { 2, 4 }; int width = 5; int height = 5; int inputDepth = 1; - - int[] kernel = {2, 2}; - int[] blocks = {1, 1}; - - String[] activations = {"sigmoid", "tanh"}; - SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, - SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; - + int[] kernel = { 2, 2 }; + int[] blocks = { 1, 1 }; + String[] activations = { "sigmoid", "tanh" }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM }; for (String afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { - labels.putScalar(new int[]{i, i % nOut}, 1.0); + labels.putScalar(new int[] { i, i % nOut }, 1.0); } - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .dist(new NormalDistribution(0, 1)) - .list().layer(new ConvolutionLayer.Builder(kernel).nIn(inputDepth) - .cudnnAllowFallback(false) - .nOut(3).build())//output: (5-2+0)/1+1 = 4 - .layer(new SpaceToBatchLayer.Builder(blocks).build()) //trivial space to batch - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(4 * 4 * 3) - .nOut(nOut).build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).list().layer(new 
ConvolutionLayer.Builder(kernel).nIn(inputDepth).cudnnAllowFallback(false).nOut(3).build()).layer(// trivial space to batch + new SpaceToBatchLayer.Builder(blocks).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4 * 4 * 3).nOut(nOut).build()).setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn; - + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); - TestUtils.testModelSerialization(net); } } } } - @Test - public void testCnnWithUpsampling() { + @DisplayName("Test Cnn With Upsampling") + void testCnnWithUpsampling() { Nd4j.getRandom().setSeed(12345); int nOut = 4; - - int[] minibatchSizes = {1, 3}; + int[] minibatchSizes = { 1, 3 }; int width = 5; int height = 5; int inputDepth = 1; - - int[] kernel = {2, 2}; - int[] stride = {1, 1}; - int[] padding = {0, 0}; + int[] kernel = { 2, 2 }; + int[] stride = { 1, 1 }; + int[] padding = { 0, 0 }; int size = 2; - for (int minibatchSize : minibatchSizes) { INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .updater(new NoOp()) - .dist(new NormalDistribution(0, 1)) - .list().layer(new ConvolutionLayer.Builder(kernel, - stride, padding).nIn(inputDepth) - .nOut(3).build())//output: (5-2+0)/1+1 = 4 - .layer(new Upsampling2D.Builder().size(size).build()) //output: 4*2 =8 -> 8x8x3 - .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(8 * 8 * 3) - .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).list().layer(new ConvolutionLayer.Builder(kernel, stride, padding).nIn(inputDepth).nOut(3).build()).layer(// output: 4*2 =8 -> 8x8x3 + new Upsampling2D.Builder().size(size).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(8 * 8 * 3).nOut(4).build()).setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - String msg = "Upsampling - minibatch=" + minibatchSize; - if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // 
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); - TestUtils.testModelSerialization(net); } } - @Test - public void testCnnWithSubsampling() { + @DisplayName("Test Cnn With Subsampling") + void testCnnWithSubsampling() { Nd4j.getRandom().setSeed(12345); int nOut = 4; - - int[] minibatchSizes = {1, 3}; + int[] minibatchSizes = { 1, 3 }; int width = 5; int height = 5; int inputDepth = 1; - - int[] kernel = {2, 2}; - int[] stride = {1, 1}; - int[] padding = {0, 0}; + int[] kernel = { 2, 2 }; + int[] stride = { 1, 1 }; + int[] padding = { 0, 0 }; int pnorm = 2; - - Activation[] activations = {Activation.SIGMOID, Activation.TANH}; - SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, - SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; - + Activation[] activations = { Activation.SIGMOID, Activation.TANH }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM }; for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { - labels.putScalar(new int[]{i, i % nOut}, 1.0); + labels.putScalar(new int[] { i, i % nOut }, 1.0); } - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dataType(DataType.DOUBLE) - .dist(new NormalDistribution(0, 1)) - .list().layer(0, - new ConvolutionLayer.Builder(kernel, - stride, padding).nIn(inputDepth) - .cudnnAllowFallback(false) - .nOut(3).build())//output: (5-2+0)/1+1 = 4 - .layer(1, new SubsamplingLayer.Builder(poolingType) - .cudnnAllowFallback(false) - .kernelSize(kernel).stride(stride).padding(padding) - .pnorm(pnorm).build()) //output: (4-2+0)/1+1 =3 -> 3x3x3 - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3 * 3 * 3) - .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).dist(new NormalDistribution(0, 1)).list().layer(0, new ConvolutionLayer.Builder(kernel, stride, padding).nIn(inputDepth).cudnnAllowFallback(false).nOut(3).build()).layer(1, new SubsamplingLayer.Builder(poolingType).cudnnAllowFallback(false).kernelSize(kernel).stride(stride).padding(padding).pnorm(pnorm).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3 * 3 * 3).nOut(4).build()).setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn; - + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + 
", activationFn=" + afn; if (PRINT_RESULTS) { System.out.println(msg); -// for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); + // for (int j = 0; j < net.getnLayers(); j++) + // System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); } - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); - TestUtils.testModelSerialization(net); } } @@ -492,69 +336,35 @@ public class CNNGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnnWithSubsamplingV2() { + @DisplayName("Test Cnn With Subsampling V 2") + void testCnnWithSubsamplingV2() { Nd4j.getRandom().setSeed(12345); int nOut = 4; - - int[] minibatchSizes = {1, 3}; + int[] minibatchSizes = { 1, 3 }; int width = 5; int height = 5; int inputDepth = 1; - - int[] kernel = {2, 2}; - int[] stride = {1, 1}; - int[] padding = {0, 0}; + int[] kernel = { 2, 2 }; + int[] stride = { 1, 1 }; + int[] padding = { 0, 0 }; int pNorm = 3; - - Activation[] activations = {Activation.SIGMOID, Activation.TANH}; - SubsamplingLayer.PoolingType[] poolingTypes = - new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, - SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; - + Activation[] activations = { Activation.SIGMOID, Activation.TANH }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM }; for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { - labels.putScalar(new int[]{i, i % nOut}, 1.0); + labels.putScalar(new int[] { i, i % nOut }, 1.0); } - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dataType(DataType.DOUBLE) - .dist(new NormalDistribution(0, 1)) - .list().layer(0, - new ConvolutionLayer.Builder(kernel, - stride, padding).nIn(inputDepth) - .cudnnAllowFallback(false) - .nOut(3).build())//output: (5-2+0)/1+1 = 4 - .layer(1, new SubsamplingLayer.Builder(poolingType) - .kernelSize(kernel).stride(stride).padding(padding) - .cudnnAllowFallback(false) - .pnorm(pNorm).build()) //output: (4-2+0)/1+1 =3 -> 3x3x3 - .layer(2, new ConvolutionLayer.Builder(kernel, stride, padding) - .cudnnAllowFallback(false) - .nIn(3).nOut(2).build()) //Output: (3-2+0)/1+1 = 2 - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(2 * 2 * 2) - .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) - .build(); - + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).dist(new NormalDistribution(0, 1)).list().layer(0, new ConvolutionLayer.Builder(kernel, stride, padding).nIn(inputDepth).cudnnAllowFallback(false).nOut(3).build()).layer(1, new 
SubsamplingLayer.Builder(poolingType).kernelSize(kernel).stride(stride).padding(padding).cudnnAllowFallback(false).pnorm(pNorm).build()).layer(2, new ConvolutionLayer.Builder(kernel, stride, padding).cudnnAllowFallback(false).nIn(3).nOut(2).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(4).build()).setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn; + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; System.out.println(msg); - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); - TestUtils.testModelSerialization(net); } } @@ -562,20 +372,16 @@ public class CNNGradientCheckTest extends BaseDL4JTest { } @Test - public void testCnnMultiLayer() { + @DisplayName("Test Cnn Multi Layer") + void testCnnMultiLayer() { int nOut = 2; - - int[] minibatchSizes = {1, 2, 5}; + int[] minibatchSizes = { 1, 2, 5 }; int width = 5; int height = 5; - int[] inputDepths = {1, 2, 4}; - - Activation[] activations = {Activation.SIGMOID, Activation.TANH}; - SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[]{ - SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG}; - + int[] inputDepths = { 1, 2, 4 }; + Activation[] activations = { Activation.SIGMOID, Activation.TANH }; + SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG }; Nd4j.getRandom().setSeed(12345); - for (int inputDepth : inputDepths) { for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { @@ -583,46 +389,19 @@ public class CNNGradientCheckTest extends BaseDL4JTest { INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { - labels.putScalar(new int[]{i, i % nOut}, 1.0); + labels.putScalar(new int[] { i, i % nOut }, 1.0); } - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()) - .dataType(DataType.DOUBLE) - .activation(afn) - .list() - .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1) - .cudnnAllowFallback(false) - .padding(0, 0).nIn(inputDepth).nOut(2).build())//output: (5-2+0)/1+1 = 4 - .layer(1, new ConvolutionLayer.Builder().nIn(2).nOut(2).kernelSize(2, 2) - .cudnnAllowFallback(false) - .stride(1, 1).padding(0, 0).build()) //(4-2+0)/1+1 = 3 - .layer(2, new ConvolutionLayer.Builder().nIn(2).nOut(2).kernelSize(2, 2) - .cudnnAllowFallback(false) - .stride(1, 1).padding(0, 0).build()) //(3-2+0)/1+1 = 2 - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut) - .build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); - - assertEquals(ConvolutionMode.Truncate, - ((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode()); - + 
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()).dataType(DataType.DOUBLE).activation(afn).list().layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).cudnnAllowFallback(false).padding(0, 0).nIn(inputDepth).nOut(2).build()).layer(1, new ConvolutionLayer.Builder().nIn(2).nOut(2).kernelSize(2, 2).cudnnAllowFallback(false).stride(1, 1).padding(0, 0).build()).layer(2, new ConvolutionLayer.Builder().nIn(2).nOut(2).kernelSize(2, 2).cudnnAllowFallback(false).stride(1, 1).padding(0, 0).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut).build()).setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); + assertEquals(ConvolutionMode.Truncate, ((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode()); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - -// for (int i = 0; i < 4; i++) { -// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); -// } - - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" - + afn; + // for (int i = 0; i < 4; i++) { + // System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams()); + // } + String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; System.out.println(msg); - - boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); - + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); - TestUtils.testModelSerialization(net); } } @@ -630,126 +409,71 @@ public class CNNGradientCheckTest extends BaseDL4JTest { } } - @Test - public void testCnnSamePaddingMode() { + @DisplayName("Test Cnn Same Padding Mode") + void testCnnSamePaddingMode() { int nOut = 2; - - int[] minibatchSizes = {1, 3, 3, 2, 1, 2}; - int[] heights = new int[]{4, 5, 6, 5, 4, 4}; //Same padding mode: insensitive to exact input size... - int[] kernelSizes = new int[]{2, 3, 2, 3, 2, 3}; - int[] inputDepths = {1, 2, 4, 3, 2, 3}; - + int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 }; + // Same padding mode: insensitive to exact input size... + int[] heights = new int[] { 4, 5, 6, 5, 4, 4 }; + int[] kernelSizes = new int[] { 2, 3, 2, 3, 2, 3 }; + int[] inputDepths = { 1, 2, 4, 3, 2, 3 }; int width = 5; - Nd4j.getRandom().setSeed(12345); - - for( int i=0; i docIds = new ArrayList(); - for (int phase = 1; phase <= 2; ++phase) { - int docIdsIdx = 0; - - if (phase == 2) { - Collections.shuffle(docIds); - } - - final int increment = 32; - - for (int b = 0; b <= 256; b += increment) { - if (256 == b) b--; - for (int g = 0; g <= 256; g += increment) { - if (256 == g) g--; - for (int r = 0; r <= 256; r += increment) { - if (256 == r) r--; - - if (phase == 1) { - docIds.add(docIds.size()+1); - continue; + /** + * Reject deallocator threads over whose cleanup this test has no control. + */ + @Override + public boolean reject(Thread thread) { + final ThreadGroup threadGroup = thread.getThreadGroup(); + final String threadGroupName = (threadGroup == null ? 
null : threadGroup.getName()); + if (threadGroupName != null && threadGroupName.endsWith(TupleStreamDataSetIteratorTest.class.getSimpleName())) { + final String threadName = thread.getName(); + if (threadName.startsWith(NativeRandomDeallocator.DeallocatorThreadNamePrefix) || threadName.toLowerCase().contains("deallocator") || threadName.equals(BasicWorkspaceManager.WorkspaceDeallocatorThreadName)) { + return true; + } } - - final float luminance = (b*0.0722f + g*0.7152f + r*0.2126f)/(255*3.0f); // https://en.wikipedia.org/wiki/Luma_(video) - - final SolrInputDocument doc = sdoc("id", Integer.toString(docIds.get(docIdsIdx++)), - "channel_b_f", Float.toString(b/255f), - "channel_g_f", Float.toString(g/255f), - "channel_r_f", Float.toString(r/255f), - "luminance_f", Float.toString(luminance)); - - updateRequest.add(doc); - ++numDocs; - - } + return false; } - } } - // make the update request - updateRequest.commit(cluster.getSolrClient(), "mySolrCollection"); - } + private static int numDocs = 0; - private static class CountingIterationListener extends ScoreIterationListener { - - private int numIterationsDone = 0; - - public CountingIterationListener() { - super(1); + @BeforeAll + static void setupCluster() throws Exception { + final int numShards = 2; + final int numReplicas = 2; + final int maxShardsPerNode = 1; + final int nodeCount = (numShards * numReplicas + (maxShardsPerNode - 1)) / maxShardsPerNode; + // create and configure cluster + configureCluster(nodeCount).addConfig("conf", configset("mini")).configure(); + // create an empty collection + CollectionAdminRequest.createCollection("mySolrCollection", "conf", numShards, numReplicas).setMaxShardsPerNode(maxShardsPerNode).process(cluster.getSolrClient()); + // compose an update request + final UpdateRequest updateRequest = new UpdateRequest(); + final List docIds = new ArrayList(); + for (int phase = 1; phase <= 2; ++phase) { + int docIdsIdx = 0; + if (phase == 2) { + Collections.shuffle(docIds); + } + final int increment = 32; + for (int b = 0; b <= 256; b += increment) { + if (256 == b) + b--; + for (int g = 0; g <= 256; g += increment) { + if (256 == g) + g--; + for (int r = 0; r <= 256; r += increment) { + if (256 == r) + r--; + if (phase == 1) { + docIds.add(docIds.size() + 1); + continue; + } + // https://en.wikipedia.org/wiki/Luma_(video) + final float luminance = (b * 0.0722f + g * 0.7152f + r * 0.2126f) / (255 * 3.0f); + final SolrInputDocument doc = sdoc("id", Integer.toString(docIds.get(docIdsIdx++)), "channel_b_f", Float.toString(b / 255f), "channel_g_f", Float.toString(g / 255f), "channel_r_f", Float.toString(r / 255f), "luminance_f", Float.toString(luminance)); + updateRequest.add(doc); + ++numDocs; + } + } + } + } + // make the update request + updateRequest.commit(cluster.getSolrClient(), "mySolrCollection"); } - public int numIterationsDone() { - return numIterationsDone; + @DisplayName("Counting Iteration Listener") + private static class CountingIterationListener extends ScoreIterationListener { + + private int numIterationsDone = 0; + + public CountingIterationListener() { + super(1); + } + + public int numIterationsDone() { + return numIterationsDone; + } + + @Override + public void iterationDone(Model model, int iteration, int epoch) { + super.iterationDone(model, iteration, epoch); + ++numIterationsDone; + } } - @Override - public void iterationDone(Model model, int iteration, int epoch) { - super.iterationDone(model, iteration, epoch); - ++numIterationsDone; + @Test + @DisplayName("Iterate Test") + void 
iterateTest() throws Exception { + doIterateTest(true); + doIterateTest(false); } - } - - @Test - public void iterateTest() throws Exception { - doIterateTest(true); - doIterateTest(false); - } - - private void doIterateTest(boolean withIdKey) throws Exception { - - try (final TupleStreamDataSetIterator - tsdsi = new TupleStreamDataSetIterator( - 123 /* batch */, - (withIdKey ? "greeting" : null) /* idKey */, - new String[] { "pie" }, - new String[] { "answer" }, - "tuple(greeting=\"hello world\",pie=3.14,answer=42)", - null)) { - - assertTrue(tsdsi.hasNext()); - final DataSet ds = tsdsi.next(); - - assertEquals(1, ds.getFeatures().length()); - assertEquals(3.14f, ds.getFeatures().getFloat(0), 0.0f); - - assertEquals(1, ds.getLabels().length()); - assertEquals(42f, ds.getLabels().getFloat(0), 0.0f); - - assertFalse(tsdsi.hasNext()); + private void doIterateTest(boolean withIdKey) throws Exception { + try (final TupleStreamDataSetIterator tsdsi = new TupleStreamDataSetIterator(123, /* batch */ + (withIdKey ? "greeting" : null), /* idKey */ + new String[] { "pie" }, new String[] { "answer" }, "tuple(greeting=\"hello world\",pie=3.14,answer=42)", null)) { + assertTrue(tsdsi.hasNext()); + final DataSet ds = tsdsi.next(); + assertEquals(1, ds.getFeatures().length()); + assertEquals(3.14f, ds.getFeatures().getFloat(0), 0.0f); + assertEquals(1, ds.getLabels().length()); + assertEquals(42f, ds.getLabels().getFloat(0), 0.0f); + assertFalse(tsdsi.hasNext()); + } } - } - @Test - public void modelFitTest() throws Exception { - - final MultiLayerNetwork model = new MultiLayerNetwork( - new NeuralNetConfiguration.Builder() - .list( - new OutputLayer.Builder(LossFunction.MSE) - .nIn(3) - .nOut(1) - .weightInit(WeightInit.ONES) - .activation(Activation.IDENTITY) - .build() - ) - - - .build() - ); - model.init(); - - int batch = 1; - for (int ii=1; ii<=5; ++ii) { - final CountingIterationListener listener = new CountingIterationListener(); - model.setListeners(listener); - batch *= 2; - - try (final TupleStreamDataSetIterator tsdsi = - new TupleStreamDataSetIterator( - batch, - "id" /* idKey */, - new String[] { "channel_b_f", "channel_g_f", "channel_r_f" }, - new String[] { "luminance_f" }, - "search(mySolrCollection," + - "q=\"id:*\"," + - "fl=\"id,channel_b_f,channel_g_f,channel_r_f,luminance_f\"," + - "sort=\"id asc\"," + - "qt=\"/export\")", - cluster.getZkClient().getZkServerAddress())) { - - model.fit(tsdsi); - } - - assertEquals("numIterationsDone="+listener.numIterationsDone()+" numDocs="+numDocs+" batch="+batch, - (numDocs+(batch-1))/batch, listener.numIterationsDone()); + @Test + @DisplayName("Model Fit Test") + void modelFitTest() throws Exception { + final MultiLayerNetwork model = new MultiLayerNetwork(new NeuralNetConfiguration.Builder().list(new OutputLayer.Builder(LossFunction.MSE).nIn(3).nOut(1).weightInit(WeightInit.ONES).activation(Activation.IDENTITY).build()).build()); + model.init(); + int batch = 1; + for (int ii = 1; ii <= 5; ++ii) { + final CountingIterationListener listener = new CountingIterationListener(); + model.setListeners(listener); + batch *= 2; + try (final TupleStreamDataSetIterator tsdsi = new TupleStreamDataSetIterator(batch, "id", /* idKey */ + new String[] { "channel_b_f", "channel_g_f", "channel_r_f" }, new String[] { "luminance_f" }, "search(mySolrCollection," + "q=\"id:*\"," + "fl=\"id,channel_b_f,channel_g_f,channel_r_f,luminance_f\"," + "sort=\"id asc\"," + "qt=\"/export\")", cluster.getZkClient().getZkServerAddress())) { + model.fit(tsdsi); + } + 
assertEquals("numIterationsDone=" + listener.numIterationsDone() + " numDocs=" + numDocs + " batch=" + batch, (numDocs + (batch - 1)) / batch, listener.numIterationsDone()); + } } - } - } diff --git a/deeplearning4j/deeplearning4j-graph/pom.xml b/deeplearning4j/deeplearning4j-graph/pom.xml index fe0a366b0..164219a58 100644 --- a/deeplearning4j/deeplearning4j-graph/pom.xml +++ b/deeplearning4j/deeplearning4j-graph/pom.xml @@ -44,10 +44,18 @@ org.threadly threadly ${threadly.version} + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test - junit - junit + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test ch.qos.logback diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml index 97421ce2b..3f430ab04 100644 --- a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml +++ b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml @@ -285,8 +285,16 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.apache.solr diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java index 3aa714f02..e9d98b205 100644 --- a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java +++ b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java @@ -17,13 +17,11 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelexport.solr.handler; import java.io.File; import java.nio.file.Path; import java.security.SecureRandom; - import com.carrotsearch.randomizedtesting.ThreadFilter; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; import org.apache.solr.client.solrj.io.Tuple; @@ -40,224 +38,152 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.ModelSerializer; -import org.junit.BeforeClass; -import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.api.memory.provider.BasicWorkspaceManager; import org.nd4j.rng.deallocator.NativeRandomDeallocator; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -@ThreadLeakFilters(defaultFilters = true, filters = { - ModelTupleStreamIntegrationTest.PrivateDeallocatorThreadsFilter.class -}) -public class ModelTupleStreamIntegrationTest extends SolrCloudTestCase { +@ThreadLeakFilters(defaultFilters = true, filters = { ModelTupleStreamIntegrationTest.PrivateDeallocatorThreadsFilter.class }) +@DisplayName("Model Tuple Stream Integration Test") +class ModelTupleStreamIntegrationTest extends SolrCloudTestCase { - static { - /* + static { + /* This is a hack around the backend-dependent nature of secure 
random implementations though we can set the secure random algorithm in our pom.xml files (via maven surefire and test.solr.allowed.securerandom) there isn't a mechanism that is completely platform independent. By setting it there (for example, to NativePRNG) that makes it pass on some platforms like Linux but fails on some JVMs on Windows For testing purposes, we don't need strict guarantees around RNG, hence we don't want to enforce the RNG algorithm */ - String algorithm = new SecureRandom().getAlgorithm(); - System.setProperty("test.solr.allowed.securerandom", algorithm); - } + String algorithm = new SecureRandom().getAlgorithm(); + System.setProperty("test.solr.allowed.securerandom", algorithm); + } + @DisplayName("Private Deallocator Threads Filter") + static class PrivateDeallocatorThreadsFilter implements ThreadFilter { - public static class PrivateDeallocatorThreadsFilter implements ThreadFilter { - /** - * Reject deallocator threads over whose cleanup this test has no control. - */ - @Override - public boolean reject(Thread thread) { - final ThreadGroup threadGroup = thread.getThreadGroup(); - final String threadGroupName = (threadGroup == null ? null : threadGroup.getName()); - - if (threadGroupName != null && - threadGroupName.endsWith(ModelTupleStreamIntegrationTest.class.getSimpleName())) { - - final String threadName = thread.getName(); - if (threadName.startsWith(NativeRandomDeallocator.DeallocatorThreadNamePrefix) || - threadName.toLowerCase().contains("deallocator") || - threadName.equals(BasicWorkspaceManager.WorkspaceDeallocatorThreadName)) { - return true; + /** + * Reject deallocator threads over whose cleanup this test has no control. + */ + @Override + public boolean reject(Thread thread) { + final ThreadGroup threadGroup = thread.getThreadGroup(); + final String threadGroupName = (threadGroup == null ? 
null : threadGroup.getName()); + if (threadGroupName != null && threadGroupName.endsWith(ModelTupleStreamIntegrationTest.class.getSimpleName())) { + final String threadName = thread.getName(); + if (threadName.startsWith(NativeRandomDeallocator.DeallocatorThreadNamePrefix) || threadName.toLowerCase().contains("deallocator") || threadName.equals(BasicWorkspaceManager.WorkspaceDeallocatorThreadName)) { + return true; + } + } + return false; } - } - - return false; - } - } - - final private static String MY_COLLECTION_NAME = "mySolrCollection"; - final private static String MY_SERIALIZED_MODEL_FILENAME = "mySerializedModel"; - - @BeforeClass - public static void setupCluster() throws Exception { - - final Path configsetPath = configset("mini-expressible"); - - // create and serialize model - { - final Model model = buildModel(); - final File serializedModelFile = configsetPath - .resolve(MY_SERIALIZED_MODEL_FILENAME) - .toFile(); - ModelSerializer.writeModel(model, serializedModelFile.getPath(), false); } - final String configName = "conf"; - final int numShards = 2; - final int numReplicas = 2; - final int maxShardsPerNode = 1; - final int nodeCount = (numShards*numReplicas + (maxShardsPerNode-1))/maxShardsPerNode; + final private static String MY_COLLECTION_NAME = "mySolrCollection"; - // create and configure cluster - configureCluster(nodeCount) - .addConfig(configName, configsetPath) - .configure(); + final private static String MY_SERIALIZED_MODEL_FILENAME = "mySerializedModel"; - // create an empty collection - CollectionAdminRequest.createCollection(MY_COLLECTION_NAME, configName, numShards, numReplicas) - .setMaxShardsPerNode(maxShardsPerNode) - .process(cluster.getSolrClient()); - - // compose an update request - final UpdateRequest updateRequest = new UpdateRequest(); - - // add some documents - updateRequest.add( - sdoc("id", "green", - "channel_b_f", "0", - "channel_g_f", "255", - "channel_r_f", "0")); - updateRequest.add( - sdoc("id", "black", - "channel_b_f", "0", - "channel_g_f", "0", - "channel_r_f", "0")); - updateRequest.add( - sdoc("id", "yellow", - "channel_b_f", "0", - "channel_g_f", "255", - "channel_r_f", "255")); - - // make the update request - updateRequest.commit(cluster.getSolrClient(), MY_COLLECTION_NAME); - } - - private static Model buildModel() throws Exception { - - final int numInputs = 3; - final int numOutputs = 2; - - final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list( - new OutputLayer.Builder() - .nIn(numInputs) - .nOut(numOutputs) - .activation(Activation.IDENTITY) - .lossFunction(LossFunctions.LossFunction.MSE) - .build() - ) - .build(); - - final MultiLayerNetwork model = new MultiLayerNetwork(conf); - model.init(); - - final float[] floats = new float[]{ +1, +1, +1, -1, -1, -1, 0, 0 }; - // positive weight for first output, negative weight for second output, no biases - assertEquals((numInputs+1)*numOutputs, floats.length); - - final INDArray params = Nd4j.create(floats); - model.setParams(params); - - return model; - } - - private void doTest(String expr, String[] expectedIds, Object[] expectedLefts, Object[] expectedRights) throws Exception { - ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); - paramsLoc.set("expr", expr); - paramsLoc.set("qt", "/stream"); - - String url = cluster.getRandomJetty(random()).getBaseUrl().toString()+"/"+MY_COLLECTION_NAME; - - - TupleStream tupleStream = new SolrStream(url, paramsLoc); - - StreamContext context = new StreamContext(); - tupleStream.setStreamContext(context); - 
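The setupCluster changes in these Solr test classes follow the standard Jupiter lifecycle mapping: @BeforeClass becomes @BeforeAll, the method stays static but no longer needs to be public, and @DisplayName labels the class and its methods for reporting. A minimal, self-contained sketch of that mapping, with purely illustrative class, field, and method names:

import static org.junit.jupiter.api.Assertions.assertNotNull;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Lifecycle Example")
class LifecycleExampleTest {

    private static String sharedResource;

    // JUnit 4 required: @BeforeClass public static void setUpOnce()
    // JUnit 5 uses @BeforeAll; the method must still be static, but package-private visibility is enough
    @BeforeAll
    static void setUpOnce() {
        sharedResource = "initialised once for the whole class";
    }

    @Test
    @DisplayName("Shared resource is available to every test")
    void sharedResourceIsAvailable() {
        assertNotNull(sharedResource);
    }
}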
- try { - tupleStream.open(); - - for (int ii=0; ii floatsList(int numFloats) { - final List floatsList = new ArrayList(); - final float[] floats0 = new float[numFloats]; - final float[] floats1 = new float[numFloats]; - for (int ii=0; ii floatsList(int numFloats) { + final List floatsList = new ArrayList(); + final float[] floats0 = new float[numFloats]; + final float[] floats1 = new float[numFloats]; + for (int ii = 0; ii < numFloats; ++ii) { + floats0[ii] = 0f; + floats1[ii] = 1f; } - } + floatsList.add(floats0); + floatsList.add(floats1); + return floatsList; } - assertEquals(50, testsCount); - } - private void doTest(Model originalModel, int numInputs, int numOutputs) throws Exception { + @Test + @DisplayName("Test") + void test() throws Exception { + int testsCount = 0; + for (int numInputs = 1; numInputs <= 5; ++numInputs) { + for (int numOutputs = 1; numOutputs <= 5; ++numOutputs) { + for (Model model : new Model[] { buildMultiLayerNetworkModel(numInputs, numOutputs), buildComputationGraphModel(numInputs, numOutputs) }) { + doTest(model, numInputs, numOutputs); + ++testsCount; + } + } + } + assertEquals(50, testsCount); + } - final Path tempDirPath = Files.createTempDirectory(null); - final File tempDirFile = tempDirPath.toFile(); - tempDirFile.deleteOnExit(); + private void doTest(Model originalModel, int numInputs, int numOutputs) throws Exception { + final Path tempDirPath = Files.createTempDirectory(null); + final File tempDirFile = tempDirPath.toFile(); + tempDirFile.deleteOnExit(); + final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(tempDirPath); + final File tempFile = File.createTempFile("prefix", "suffix", tempDirFile); + tempFile.deleteOnExit(); + final String serializedModelFileName = tempFile.getPath(); + ModelSerializer.writeModel(originalModel, serializedModelFileName, false); + final Model restoredModel = ModelGuesser.loadModelGuess(serializedModelFileName); + final StreamContext streamContext = new StreamContext(); + final SolrClientCache solrClientCache = new SolrClientCache(); + streamContext.setSolrClientCache(solrClientCache); + final String[] inputKeys = new String[numInputs]; + final String inputKeysList = fillArray(inputKeys, "input", ","); + final String[] outputKeys = new String[numOutputs]; + final String outputKeysList = fillArray(outputKeys, "output", ","); + for (final float[] floats : floatsList(numInputs)) { + final String inputValuesList; + { + final StringBuilder sb = new StringBuilder(); + for (int ii = 0; ii < inputKeys.length; ++ii) { + if (0 < ii) + sb.append(','); + sb.append(inputKeys[ii]).append('=').append(floats[ii]); + } + inputValuesList = sb.toString(); + } + final StreamFactory streamFactory = new SolrDefaultStreamFactory().withSolrResourceLoader(solrResourceLoader).withFunctionName("model", ModelTupleStream.class); + final StreamExpression streamExpression = StreamExpressionParser.parse("model(" + "tuple(" + inputValuesList + ")" + ",serializedModelFileName=\"" + serializedModelFileName + "\"" + ",inputKeys=\"" + inputKeysList + "\"" + ",outputKeys=\"" + outputKeysList + "\"" + ")"); + final TupleStream tupleStream = streamFactory.constructStream(streamExpression); + tupleStream.setStreamContext(streamContext); + assertTrue(tupleStream instanceof ModelTupleStream); + final ModelTupleStream modelTupleStream = (ModelTupleStream) tupleStream; + modelTupleStream.open(); + { + final Tuple tuple1 = modelTupleStream.read(); + assertNotNull(tuple1); + assertFalse(tuple1.EOF); + for (int ii = 0; ii < outputKeys.length; ++ii) 
{ + final INDArray inputs = Nd4j.create(new float[][] { floats }); + final double originalScore = NetworkUtils.output((Model) originalModel, inputs).getDouble(ii); + final double restoredScore = NetworkUtils.output((Model) restoredModel, inputs).getDouble(ii); + assertEquals(originalScore, restoredScore, 1e-5,originalModel.getClass().getSimpleName() + " (originalScore-restoredScore)=" + (originalScore - restoredScore)); + final Double outputValue = tuple1.getDouble(outputKeys[ii]); + assertNotNull(outputValue); + final double tupleScore = outputValue.doubleValue(); + assertEquals(originalScore, tupleScore, 1e-5,originalModel.getClass().getSimpleName() + " (originalScore-tupleScore[" + ii + "])=" + (originalScore - tupleScore)); + } + final Tuple tuple2 = modelTupleStream.read(); + assertNotNull(tuple2); + assertTrue(tuple2.EOF); + } + modelTupleStream.close(); + doToExpressionTest(streamExpression, modelTupleStream.toExpression(streamFactory), inputKeys.length); + doToExplanationTest(modelTupleStream.toExplanation(streamFactory)); + } + } - final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(tempDirPath); + private static void doToExpressionTest(StreamExpression streamExpression, StreamExpressionParameter streamExpressionParameter, int inputKeysLength) { + assertTrue(streamExpressionParameter instanceof StreamExpression); + // tuple(input1=1,input2=2) and tuple(input2=2,input1=1) are equivalent + // but StreamExpression equals does not consider them equal. + if (inputKeysLength == 1) { + assertEquals(streamExpression, (StreamExpression) streamExpressionParameter); + } + } - final File tempFile = File.createTempFile("prefix", "suffix", tempDirFile); - tempFile.deleteOnExit(); + private static void doToExplanationTest(Explanation explanation) { + final Map explanationMap = new TreeMap(); + explanation.toMap(explanationMap); + assertTrue(explanation instanceof StreamExplanation); + assertNotNull(explanationMap.remove("children")); + assertNotNull(explanationMap.remove("expression")); + assertNotNull(explanationMap.remove("expressionNodeId")); + assertEquals(ExpressionType.STREAM_DECORATOR, explanationMap.remove("expressionType")); + assertEquals(explanationMap.remove("functionName"), "model"); + assertEquals(ModelTupleStream.class.getName(), explanationMap.remove("implementingClass")); + assertTrue(explanationMap.isEmpty(),explanationMap.toString()); + } - final String serializedModelFileName = tempFile.getPath(); - - ModelSerializer.writeModel(originalModel, serializedModelFileName, false); - - final Model restoredModel = ModelGuesser.loadModelGuess(serializedModelFileName); - - final StreamContext streamContext = new StreamContext(); - final SolrClientCache solrClientCache = new SolrClientCache(); - streamContext.setSolrClientCache(solrClientCache); - - final String[] inputKeys = new String[numInputs]; - final String inputKeysList = fillArray(inputKeys, "input", ","); - - final String[] outputKeys = new String[numOutputs]; - final String outputKeysList = fillArray(outputKeys, "output", ","); - - for (final float[] floats : floatsList(numInputs)) { - - final String inputValuesList; - { + /** + * Fills an existing array using prefix and delimiter, e.g. 
+ * input: arr = [ "", "", "" ] prefix="value" delimiter="," + * output: arr = [ "value1", "value2", "value3" ] + * return: "value1,value2,value3" + */ + private static String fillArray(String[] arr, final String prefix, final String delimiter) { final StringBuilder sb = new StringBuilder(); - for (int ii=0; ii { + String modelPath = "modelimport/keras/examples/foo/bar.h5"; + importEndModelTest(tempDir,modelPath, null, true, true, false, false); + }); } /** * MNIST MLP tests */ @Test - public void importMnistMlpTfKeras1() throws Exception { + @DisplayName("Import Mnist Mlp Tf Keras 1") + void importMnistMlpTfKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false); } @Test - public void importMnistMlpThKeras1() throws Exception { + @DisplayName("Import Mnist Mlp Th Keras 1") + void importMnistMlpThKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, false, true, false, false); } @Test - public void importMnistMlpTfKeras2() throws Exception { + @DisplayName("Import Mnist Mlp Tf Keras 2") + void importMnistMlpTfKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false); } @Test - public void importMnistMlpReshapeTfKeras1() throws Exception { + @DisplayName("Import Mnist Mlp Reshape Tf Keras 1") + void importMnistMlpReshapeTfKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_mlp_reshape/mnist_mlp_reshape_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, true, false); } /** * MNIST CNN tests */ @Test - public void importMnistCnnTfKeras1() throws Exception { + @DisplayName("Import Mnist Cnn Tf Keras 1") + void importMnistCnnTfKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, false, false, false); } @Test - public void importMnistCnnThKeras1() throws Exception { + @DisplayName("Import Mnist Cnn Th Keras 1") + void importMnistCnnThKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_model.h5"; 
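The fileNotFoundEndToEnd change in this hunk shows the Jupiter replacement for JUnit 4's @Test(expected = ...): the failing call is wrapped in Assertions.assertThrows, which fails the test if nothing is thrown and returns the caught exception for further inspection. A small sketch of the same idiom, using only the JUnit 5 API and an illustrative exception message:

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;

class ExpectedExceptionExampleTest {

    @Test
    void reportsMissingModelFile() {
        // JUnit 4 style: @Test(expected = IllegalStateException.class)
        // JUnit 5 style: wrap the code under test in assertThrows
        IllegalStateException ex = assertThrows(IllegalStateException.class, () -> {
            throw new IllegalStateException("model file not found");
        });
        // assertThrows returns the exception, so its message can be verified as well
        assertEquals("model file not found", ex.getMessage());
    }
}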
String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, true, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, false, true, true, false); } @Test - public void importMnistCnnTfKeras2() throws Exception { + @DisplayName("Import Mnist Cnn Tf Keras 2") + void importMnistCnnTfKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, true, false); } /** * IMDB Embedding and LSTM test */ @Test - public void importImdbLstmTfKeras1() throws Exception { + @DisplayName("Import Imdb Lstm Tf Keras 1") + void importImdbLstmTfKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false, true, null, null); } @Test - public void importImdbLstmThKeras1() throws Exception { + @DisplayName("Import Imdb Lstm Th Keras 1") + void importImdbLstmThKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false, true, null, null); } @Test - public void importImdbLstmTfKeras2() throws Exception { + @DisplayName("Import Imdb Lstm Tf Keras 2") + void importImdbLstmTfKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false, true, null, null); } @Test - public void importImdbLstmThKeras2() throws Exception { + @DisplayName("Import Imdb Lstm Th Keras 2") + void importImdbLstmThKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, null); + importEndModelTest(tempDir,modelPath, inputsOutputPath, false, true, false, false, true, null, null); } /** @@ -194,99 +218,106 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { */ // TODO: prediction checks fail due to globalpooling for fasttext, very few grads fail as well @Test - public void importImdbFasttextTfKeras1() throws Exception { + @DisplayName("Import Imdb Fasttext Tf Keras 1") + void importImdbFasttextTfKeras1(@TempDir Path tempDir) throws Exception { 
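Each of these Keras import tests now receives its working directory through Jupiter's @TempDir parameter injection rather than a class-level temporary-folder rule; the extension creates a fresh directory before the test runs and deletes it afterwards. A minimal sketch of that mechanism, assuming nothing beyond the JUnit 5 API and java.nio (the file name below is illustrative):

import static org.junit.jupiter.api.Assertions.assertTrue;

import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirExampleTest {

    @Test
    void writesIntoInjectedTempDir(@TempDir Path tempDir) throws Exception {
        // The extension supplies a fresh, writable directory and cleans it up after the test
        Path file = Files.write(tempDir.resolve("model.h5"),
                "placeholder".getBytes(StandardCharsets.UTF_8));
        assertTrue(Files.exists(file));
    }
}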
String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, false, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, false, false, false, false); } @Test - public void importImdbFasttextThKeras1() throws Exception { + @DisplayName("Import Imdb Fasttext Th Keras 1") + void importImdbFasttextThKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, false, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, false, false, false, false); } @Test - public void importImdbFasttextTfKeras2() throws Exception { + @DisplayName("Import Imdb Fasttext Tf Keras 2") + void importImdbFasttextTfKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_fasttext/imdb_fasttext_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, false, false, false); } /** * Simple LSTM (return sequences = false) into Dense layer test */ @Test - public void importSimpleLstmTfKeras1() throws Exception { + @DisplayName("Import Simple Lstm Tf Keras 1") + void importSimpleLstmTfKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false); } @Test - public void importSimpleLstmThKeras1() throws Exception { + @DisplayName("Import Simple Lstm Th Keras 1") + void importSimpleLstmThKeras1(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false); } @Test - public void importSimpleLstmTfKeras2() throws Exception { + @DisplayName("Import Simple Lstm Tf Keras 2") + void importSimpleLstmTfKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_lstm/simple_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, false, false, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, false, false, false); } - /** * Simple LSTM (return sequences = true) into flatten into Dense layer test */ @Test - public void importSimpleFlattenLstmTfKeras2() throws Exception { + @DisplayName("Import Simple Flatten Lstm Tf Keras 2") + void importSimpleFlattenLstmTfKeras2(@TempDir Path tempDir) 
throws Exception { String modelPath = "modelimport/keras/examples/simple_flatten_lstm/simple_flatten_lstm_tf_keras_2_model.h5"; - String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" + - "simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + String inputsOutputPath = "modelimport/keras/examples/simple_flatten_lstm/" + "simple_flatten_lstm_tf_keras_2_inputs_and_outputs.h5"; + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false); } /** * Simple RNN (return sequences = true) into flatten into Dense layer test */ @Test - public void importSimpleFlattenRnnTfKeras2() throws Exception { + @DisplayName("Import Simple Flatten Rnn Tf Keras 2") + void importSimpleFlattenRnnTfKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; - String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + - "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); + String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false, true, null, null); } /** * Simple RNN (return sequences = false) into Dense layer test */ @Test - public void importSimpleRnnTfKeras2() throws Exception { + @DisplayName("Import Simple Rnn Tf Keras 2") + void importSimpleRnnTfKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/simple_rnn/simple_rnn_tf_keras_2_model.h5"; - String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" + - "simple_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + String inputsOutputPath = "modelimport/keras/examples/simple_rnn/" + "simple_rnn_tf_keras_2_inputs_and_outputs.h5"; + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, false, false); } /** * CNN without bias test */ @Test - public void importCnnNoBiasTfKeras2() throws Exception { + @DisplayName("Import Cnn No Bias Tf Keras 2") + void importCnnNoBiasTfKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/cnn_no_bias/mnist_cnn_no_bias_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, true, false); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, true, false); } @Test - public void importSparseXent() throws Exception { + @DisplayName("Import Sparse Xent") + void importSparseXent(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_sparse_xent/simple_sparse_xent_mlp_keras_2_inputs_and_outputs.h5"; - MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, true, true); + MultiLayerNetwork net = importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, true, true); Layer outLayer = net.getOutputLayer(); assertTrue(outLayer instanceof org.deeplearning4j.nn.layers.LossLayer); LossLayer llConf = (LossLayer) outLayer.getConfig(); @@ -297,38 
+328,45 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { * GAN import tests */ @Test - public void importDcganMnistDiscriminator() throws Exception { - importSequentialModelH5Test("modelimport/keras/examples/mnist_dcgan/dcgan_discriminator_epoch_50.h5"); + @DisplayName("Import Dcgan Mnist Discriminator") + void importDcganMnistDiscriminator(@TempDir Path tempDir) throws Exception { + importSequentialModelH5Test(tempDir,"modelimport/keras/examples/mnist_dcgan/dcgan_discriminator_epoch_50.h5"); } @Test - @Ignore("Neither keras or tfkeras can load this.") - public void importDcganMnistGenerator() throws Exception { - importSequentialModelH5Test("modelimport/keras/examples/mnist_dcgan/dcgan_generator_epoch_50.h5"); + @Disabled("Neither keras or tfkeras can load this.") + @DisplayName("Import Dcgan Mnist Generator") + void importDcganMnistGenerator(@TempDir Path tempDir) throws Exception { + importSequentialModelH5Test(tempDir,"modelimport/keras/examples/mnist_dcgan/dcgan_generator_epoch_50.h5"); } /** * Auxillary classifier GAN import test */ @Test - public void importAcganDiscriminator() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_discriminator_1_epochs.h5"); - INDArray input = Nd4j.create(10, 28, 28, 1); //NHWC + @DisplayName("Import Acgan Discriminator") + void importAcganDiscriminator(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/acgan/acgan_discriminator_1_epochs.h5"); + // NHWC + INDArray input = Nd4j.create(10, 28, 28, 1); INDArray[] output = model.output(input); } - @Test //AB 2020/04/22 Ignored until Keras model import updated to use NHWC support - public void importAcganGenerator() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_generator_1_epochs.h5"); - //System.out.println(model.summary()) ; + // AB 2020/04/22 Ignored until Keras model import updated to use NHWC support + @Test + @DisplayName("Import Acgan Generator") + void importAcganGenerator(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/acgan/acgan_generator_1_epochs.h5"); + // System.out.println(model.summary()) ; INDArray latent = Nd4j.create(10, 100); INDArray label = Nd4j.create(10, 1); INDArray[] output = model.output(latent, label); } @Test - public void importAcganCombined() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_combined_1_epochs.h5"); + @DisplayName("Import Acgan Combined") + void importAcganCombined(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/acgan/acgan_combined_1_epochs.h5"); // TODO: imports, but incorrectly. Has only one input, should have two. 
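// The GAN import tests above build their dummy inputs in different channel layouts, which the
// inline comments flag (NHWC vs NCHW). A minimal illustration of the two orderings for a batch
// of 10 single-channel 28x28 images (shapes here are illustrative, not part of this patch):
//   INDArray nhwc = Nd4j.create(10, 28, 28, 1);  // channels last  - TF/Keras 2 default
//   INDArray nchw = Nd4j.create(10, 1, 28, 28);  // channels first - Theano/Keras 1 ordering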
} @@ -336,117 +374,124 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { * Deep convolutional GAN import test */ @Test - public void importDcganDiscriminator() throws Exception { - importSequentialModelH5Test("modelimport/keras/examples/gans/dcgan_discriminator.h5"); + @DisplayName("Import Dcgan Discriminator") + void importDcganDiscriminator(@TempDir Path tempDir) throws Exception { + importSequentialModelH5Test(tempDir,"modelimport/keras/examples/gans/dcgan_discriminator.h5"); } @Test - public void importDcganGenerator() throws Exception { - importSequentialModelH5Test("modelimport/keras/examples/gans/dcgan_generator.h5"); + @DisplayName("Import Dcgan Generator") + void importDcganGenerator(@TempDir Path tempDir) throws Exception { + importSequentialModelH5Test(tempDir,"modelimport/keras/examples/gans/dcgan_generator.h5"); } /** * Wasserstein GAN import test */ @Test - public void importWganDiscriminator() throws Exception { + @DisplayName("Import Wgan Discriminator") + void importWganDiscriminator(@TempDir Path tempDir) throws Exception { for (int i = 0; i < 100; i++) { // run a few times to make sure HDF5 doesn't crash - importSequentialModelH5Test("modelimport/keras/examples/gans/wgan_discriminator.h5"); + importSequentialModelH5Test(tempDir,"modelimport/keras/examples/gans/wgan_discriminator.h5"); } } @Test - public void importWganGenerator() throws Exception { - importSequentialModelH5Test("modelimport/keras/examples/gans/wgan_generator.h5"); + @DisplayName("Import Wgan Generator") + void importWganGenerator(@TempDir Path tempDir) throws Exception { + importSequentialModelH5Test(tempDir,"modelimport/keras/examples/gans/wgan_generator.h5"); } @Test - public void importCnn1d() throws Exception { - importSequentialModelH5Test("modelimport/keras/examples/cnn1d/cnn1d_flatten_tf_keras2.h5"); + @DisplayName("Import Cnn 1 d") + void importCnn1d(@TempDir Path tempDir) throws Exception { + importSequentialModelH5Test(tempDir,"modelimport/keras/examples/cnn1d/cnn1d_flatten_tf_keras2.h5"); } /** * DGA classifier test */ @Test - public void importDgaClassifier() throws Exception { - importSequentialModelH5Test("modelimport/keras/examples/dga_classifier/keras2_dga_classifier_tf_model.h5"); + @DisplayName("Import Dga Classifier") + void importDgaClassifier(@TempDir Path tempDir) throws Exception { + importSequentialModelH5Test(tempDir,"modelimport/keras/examples/dga_classifier/keras2_dga_classifier_tf_model.h5"); } /** * Reshape flat input into 3D to fit into an LSTM model */ @Test - public void importFlatIntoLSTM() throws Exception { - importFunctionalModelH5Test("modelimport/keras/examples/reshape_to_rnn/reshape_model.h5"); + @DisplayName("Import Flat Into LSTM") + void importFlatIntoLSTM(@TempDir Path tempDir) throws Exception { + importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/reshape_to_rnn/reshape_model.h5"); } - /** * Functional LSTM test */ @Test - public void importFunctionalLstmTfKeras2() throws Exception { + @DisplayName("Import Functional Lstm Tf Keras 2") + void importFunctionalLstmTfKeras2(@TempDir Path tempDir) throws Exception { String modelPath = "modelimport/keras/examples/functional_lstm/lstm_functional_tf_keras_2.h5"; - // No training enabled - ComputationGraph graphNoTrain = importFunctionalModelH5Test(modelPath, null, false); + ComputationGraph graphNoTrain = importFunctionalModelH5Test(tempDir,modelPath, null, false); System.out.println(graphNoTrain.summary()); - // Training enabled - ComputationGraph graph = importFunctionalModelH5Test(modelPath, 
null, true); + ComputationGraph graph = importFunctionalModelH5Test(tempDir,modelPath, null, true); System.out.println(graph.summary()); - // Make predictions int miniBatch = 32; - INDArray input = Nd4j.ones(miniBatch, 10, 4); //NWC format - with nIn=4, seqLength = 10 + // NWC format - with nIn=4, seqLength = 10 + INDArray input = Nd4j.ones(miniBatch, 10, 4); INDArray[] out = graph.output(input); - // Fit model - graph.fit(new INDArray[]{input}, out); + graph.fit(new INDArray[] { input }, out); } /** * U-Net */ @Test - public void importUnetTfKeras2() throws Exception { - importFunctionalModelH5Test( - "modelimport/keras/examples/unet/unet_keras_2_tf.h5", null, true); + @DisplayName("Import Unet Tf Keras 2") + void importUnetTfKeras2(@TempDir Path tempDir) throws Exception { + importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/unet/unet_keras_2_tf.h5", null, true); } /** * ResNet50 */ @Test - public void importResnet50() throws Exception { - importFunctionalModelH5Test("modelimport/keras/examples/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5"); + @DisplayName("Import Resnet 50") + void importResnet50(@TempDir Path tempDir) throws Exception { + importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5"); } /** * DenseNet */ @Test - public void importDenseNet() throws Exception { - importFunctionalModelH5Test("modelimport/keras/examples/densenet/densenet121_tf_keras_2.h5"); + @DisplayName("Import Dense Net") + void importDenseNet(@TempDir Path tempDir) throws Exception { + importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/densenet/densenet121_tf_keras_2.h5"); } /** * SqueezeNet */ @Test - public void importSqueezeNet() throws Exception { - importFunctionalModelH5Test("modelimport/keras/examples/squeezenet/squeezenet.h5"); + @DisplayName("Import Squeeze Net") + void importSqueezeNet(@TempDir Path tempDir) throws Exception { + importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/squeezenet/squeezenet.h5"); } - /** * MobileNet */ @Test - public void importMobileNet() throws Exception { - ComputationGraph graph = importFunctionalModelH5Test("modelimport/keras/examples/mobilenet/alternative.hdf5"); + @DisplayName("Import Mobile Net") + void importMobileNet(@TempDir Path tempDir) throws Exception { + ComputationGraph graph = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/mobilenet/alternative.hdf5"); INDArray input = Nd4j.ones(10, 299, 299, 3); graph.output(input); } @@ -455,11 +500,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { * InceptionV3 Keras 2 no top */ @Test - public void importInceptionKeras2() throws Exception { - int[] inputShape = new int[]{299, 299, 3}; - ComputationGraph graph = importFunctionalModelH5Test( - "modelimport/keras/examples/inception/inception_tf_keras_2.h5", inputShape, false); - INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC + @DisplayName("Import Inception Keras 2") + void importInceptionKeras2(@TempDir Path tempDir) throws Exception { + int[] inputShape = new int[] { 299, 299, 3 }; + ComputationGraph graph = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/inception/inception_tf_keras_2.h5", inputShape, false); + // TF = channels last = NHWC + INDArray input = Nd4j.ones(10, 299, 299, 3); graph.output(input); System.out.println(graph.summary()); } @@ -468,12 +514,13 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { * InceptionV3 */ @Test - //note this is actually keras 1 and 
its input dimension ordering is channels first + @DisplayName("Import Inception") + // note this is actually keras 1 and its input dimension ordering is channels first // Takes unreasonably long, but works - public void importInception() throws Exception { - ComputationGraph graph = importFunctionalModelH5Test( - "modelimport/keras/examples/inception/inception_v3_complete.h5"); - INDArray input = Nd4j.ones(10, 3,299, 299); //TH = channels first = NCHW + void importInception(@TempDir Path tempDir) throws Exception { + ComputationGraph graph = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/inception/inception_v3_complete.h5"); + // TH = channels first = NCHW + INDArray input = Nd4j.ones(10, 3, 299, 299); graph.output(input); System.out.println(graph.summary()); } @@ -482,47 +529,41 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { * Inception V4 */ @Test - @Ignore + @Disabled + @DisplayName("Import Inception V 4") // Model and weights have about 170mb, too large for test resources and also too excessive to enable as unit test - public void importInceptionV4() throws Exception { - String modelUrl = DL4JResources.getURLString( - "models/inceptionv4_keras_imagenet_weightsandconfig.h5"); - File kerasFile = testDir.newFile("inceptionv4_keras_imagenet_weightsandconfig.h5"); - + void importInceptionV4(@TempDir Path testDir) throws Exception { + String modelUrl = DL4JResources.getURLString("models/inceptionv4_keras_imagenet_weightsandconfig.h5"); + File kerasFile = testDir.resolve("inceptionv4_keras_imagenet_weightsandconfig.h5").toFile(); if (!kerasFile.exists()) { FileUtils.copyURLToFile(new URL(modelUrl), kerasFile); kerasFile.deleteOnExit(); } - - int[] inputShape = new int[]{299, 299, 3}; - ComputationGraph graph = importFunctionalModelH5Test( - kerasFile.getAbsolutePath(), inputShape, false); - + int[] inputShape = new int[] { 299, 299, 3 }; + ComputationGraph graph = importFunctionalModelH5Test(testDir,kerasFile.getAbsolutePath(), inputShape, false); // System.out.println(graph.summary()); - } /** * Xception */ @Test - public void importXception() throws Exception { - int[] inputShape = new int[]{299, 299, 3}; - ComputationGraph graph = importFunctionalModelH5Test( - "modelimport/keras/examples/xception/xception_tf_keras_2.h5", inputShape, false); + @DisplayName("Import Xception") + void importXception(@TempDir Path tempDir) throws Exception { + int[] inputShape = new int[] { 299, 299, 3 }; + ComputationGraph graph = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/xception/xception_tf_keras_2.h5", inputShape, false); } /** * Seq2seq model */ @Test - // does not work yet, needs DL4J enhancements - public void importSeq2Seq() throws Exception { - importFunctionalModelH5Test("modelimport/keras/examples/seq2seq/full_model_seq2seq_5549.h5"); - + @DisplayName("Import Seq 2 Seq") + // does not work yet, needs DL4J enhancements + void importSeq2Seq(@TempDir Path tempDir) throws Exception { + importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/seq2seq/full_model_seq2seq_5549.h5"); } - /** * Import all AlphaGo Zero model variants, i.e. * - Dual residual architecture @@ -530,57 +571,64 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { * - Separate (policy and value) residual architecture * - Separate (policy and value) convolutional architecture */ - @Test //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last - @Ignore("Data and channel layout mismatch. 
We don't support permuting the weights yet.") - public void importSepConvPolicy() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_policy.h5"); + // AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last + @Test + @Disabled("Data and channel layout mismatch. We don't support permuting the weights yet.") + @DisplayName("Import Sep Conv Policy") + void importSepConvPolicy(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/agz/sep_conv_policy.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last - @Ignore("Data and channel layout mismatch. We don't support permuting the weights yet.") - public void importSepResPolicy() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_policy.h5"); + // AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last + @Test + @Disabled("Data and channel layout mismatch. We don't support permuting the weights yet.") + @DisplayName("Import Sep Res Policy") + void importSepResPolicy(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/agz/sep_res_policy.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - - @Test //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last - @Ignore("Data and channel layout mismatch. We don't support permuting the weights yet.") - public void importSepConvValue() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_value.h5"); + // AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last + @Test + @Disabled("Data and channel layout mismatch. We don't support permuting the weights yet.") + @DisplayName("Import Sep Conv Value") + void importSepConvValue(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/agz/sep_conv_value.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test() //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last - @Ignore("Data and channel layout mismatch. We don't support permuting the weights yet.") - public void importSepResValue() throws Exception { + @Test + @Disabled("Data and channel layout mismatch. 
We don't support permuting the weights yet.") + @DisplayName("Import Sep Res Value") + void importSepResValue(@TempDir Path tempDir) throws Exception { String filePath = "C:\\Users\\agibs\\Documents\\GitHub\\keras1-import-test\\sep_res_value.h5"; - KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(filePath) - .enforceTrainingConfig(false); - + KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(filePath).enforceTrainingConfig(false); KerasModel model = builder.buildModel(); ComputationGraph compGraph = model.getComputationGraph(); - //ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_value.h5"); + // ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_value.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); compGraph.output(input); } - @Test //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last - @Ignore("Data and channel layout mismatch. We don't support permuting the weights yet.") - public void importDualRes() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_res.h5"); + // AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last + @Test + @Disabled("Data and channel layout mismatch. We don't support permuting the weights yet.") + @DisplayName("Import Dual Res") + void importDualRes(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/agz/dual_res.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test() //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last - @Ignore("Data and channel layout mismatch. We don't support permuting the weights yet.") - public void importDualConv() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_conv.h5"); + @Test + @Disabled("Data and channel layout mismatch. We don't support permuting the weights yet.") + @DisplayName("Import Dual Conv") + void importDualConv(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/agz/dual_conv.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } @@ -589,74 +637,60 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { * MTCNN */ @Test - public void importMTCNN() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/48net_complete.h5"); + @DisplayName("Import MTCNN") + void importMTCNN(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/48net_complete.h5"); } - @Test() - @Ignore("Data and channel layout mismatch. We don't support permuting the weights yet.") - public void testNCHWNWHCChangeImportModel() throws Exception { - ComputationGraph computationGraph = importFunctionalModelH5Test("modelimport/keras/weights/simpleconv2d_model.hdf5"); - computationGraph.output(Nd4j.zeros(1,1,28,28)); - - } - - @Test + @Disabled("Data and channel layout mismatch. 
We don't support permuting the weights yet.") + @DisplayName("Test NCHWNWHC Change Import Model") + void testNCHWNWHCChangeImportModel(@TempDir Path tempDir) throws Exception { + ComputationGraph computationGraph = importFunctionalModelH5Test(tempDir,"modelimport/keras/weights/simpleconv2d_model.hdf5"); + computationGraph.output(Nd4j.zeros(1, 1, 28, 28)); + } + + @Test + @DisplayName("Import MTCNN 2 D") // TODO: fails, since we can't use OldSoftMax on >2D data (here: convolution layer) // TODO: also related to #6339, fix this together - public void importMTCNN2D() throws Exception { - ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/12net.h5", - new int[] {24, 24, 3}, false); - INDArray input = Nd4j.create(10, 24, 24,3); + void importMTCNN2D(@TempDir Path tempDir) throws Exception { + ComputationGraph model = importFunctionalModelH5Test(tempDir,"modelimport/keras/examples/12net.h5", new int[] { 24, 24, 3 }, false); + INDArray input = Nd4j.create(10, 24, 24, 3); model.output(input); -// System.out.println(model.summary()); + // System.out.println(model.summary()); } /** * Masking layers (simple Masking into LSTM) */ @Test - public void testMaskingZeroValue() throws Exception { - MultiLayerNetwork model = importSequentialModelH5Test( - "modelimport/keras/examples/masking/masking_zero_lstm.h5"); + @DisplayName("Test Masking Zero Value") + void testMaskingZeroValue(@TempDir Path tempDir) throws Exception { + MultiLayerNetwork model = importSequentialModelH5Test(tempDir,"modelimport/keras/examples/masking/masking_zero_lstm.h5"); model.summary(); } @Test - public void testMaskingTwoValue() throws Exception { - MultiLayerNetwork model = importSequentialModelH5Test( - "modelimport/keras/examples/masking/masking_two_lstm.h5"); + @DisplayName("Test Masking Two Value") + void testMaskingTwoValue(@TempDir Path tempDir) throws Exception { + MultiLayerNetwork model = importSequentialModelH5Test(tempDir,"modelimport/keras/examples/masking/masking_two_lstm.h5"); model.summary(); } @Test - public void testCausalConv1D() throws Exception { - String[] names = new String[]{ - "causal_conv1d_k2_s1_d1_cl_model.h5", - "causal_conv1d_k2_s1_d2_cl_model.h5", - "causal_conv1d_k2_s2_d1_cl_model.h5", - "causal_conv1d_k2_s3_d1_cl_model.h5", - "causal_conv1d_k3_s1_d1_cl_model.h5", - "causal_conv1d_k3_s1_d2_cl_model.h5", - "causal_conv1d_k3_s2_d1_cl_model.h5", - "causal_conv1d_k3_s3_d1_cl_model.h5", - "causal_conv1d_k4_s1_d1_cl_model.h5", - "causal_conv1d_k4_s1_d2_cl_model.h5", - "causal_conv1d_k4_s2_d1_cl_model.h5", - "causal_conv1d_k4_s3_d1_cl_model.h5" - }; - - for(String name : names) { + @DisplayName("Test Causal Conv 1 D") + void testCausalConv1D(@TempDir Path tempDir) throws Exception { + String[] names = new String[] { "causal_conv1d_k2_s1_d1_cl_model.h5", "causal_conv1d_k2_s1_d2_cl_model.h5", "causal_conv1d_k2_s2_d1_cl_model.h5", "causal_conv1d_k2_s3_d1_cl_model.h5", "causal_conv1d_k3_s1_d1_cl_model.h5", "causal_conv1d_k3_s1_d2_cl_model.h5", "causal_conv1d_k3_s2_d1_cl_model.h5", "causal_conv1d_k3_s3_d1_cl_model.h5", "causal_conv1d_k4_s1_d1_cl_model.h5", "causal_conv1d_k4_s1_d2_cl_model.h5", "causal_conv1d_k4_s2_d1_cl_model.h5", "causal_conv1d_k4_s3_d1_cl_model.h5" }; + for (String name : names) { System.out.println("Starting test: " + name); String modelPath = "modelimport/keras/examples/causal_conv1d/" + name; - String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length() - "model.h5".length()) + "inputs_and_outputs.h5"); - //TODO: + 
String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0, name.length() - "model.h5".length()) + "inputs_and_outputs.h5"); + // TODO: /** * Difference in weights. Same elements, but loaded differently. Likely acceptable difference. Need to confirm though. */ - MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, - true, true, false, null, null); + MultiLayerNetwork net = importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, true, true, false, null, null); Layer l = net.getLayer(0); Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig(); assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode()); @@ -664,106 +698,41 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } @Test - public void testConv1D() throws Exception { - String[] names = new String[]{ - "conv1d_k2_s1_d1_cf_same_model.h5", - "conv1d_k2_s1_d1_cf_valid_model.h5", - "conv1d_k2_s1_d1_cl_same_model.h5", - "conv1d_k2_s1_d1_cl_valid_model.h5", - "conv1d_k2_s1_d2_cf_same_model.h5", - "conv1d_k2_s1_d2_cf_valid_model.h5", - "conv1d_k2_s1_d2_cl_same_model.h5", - "conv1d_k2_s1_d2_cl_valid_model.h5", - "conv1d_k2_s2_d1_cf_same_model.h5", - "conv1d_k2_s2_d1_cf_valid_model.h5", - "conv1d_k2_s2_d1_cl_same_model.h5", - "conv1d_k2_s2_d1_cl_valid_model.h5", - "conv1d_k2_s3_d1_cf_same_model.h5", - "conv1d_k2_s3_d1_cf_valid_model.h5", - "conv1d_k2_s3_d1_cl_same_model.h5", - "conv1d_k2_s3_d1_cl_valid_model.h5", - "conv1d_k3_s1_d1_cf_same_model.h5", - "conv1d_k3_s1_d1_cf_valid_model.h5", - "conv1d_k3_s1_d1_cl_same_model.h5", - "conv1d_k3_s1_d1_cl_valid_model.h5", - "conv1d_k3_s1_d2_cf_same_model.h5", - "conv1d_k3_s1_d2_cf_valid_model.h5", - "conv1d_k3_s1_d2_cl_same_model.h5", - "conv1d_k3_s1_d2_cl_valid_model.h5", - "conv1d_k3_s2_d1_cf_same_model.h5", - "conv1d_k3_s2_d1_cf_valid_model.h5", - "conv1d_k3_s2_d1_cl_same_model.h5", - "conv1d_k3_s2_d1_cl_valid_model.h5", - "conv1d_k3_s3_d1_cf_same_model.h5", - "conv1d_k3_s3_d1_cf_valid_model.h5", - "conv1d_k3_s3_d1_cl_same_model.h5", - "conv1d_k3_s3_d1_cl_valid_model.h5", - "conv1d_k4_s1_d1_cf_same_model.h5", - "conv1d_k4_s1_d1_cf_valid_model.h5", - "conv1d_k4_s1_d1_cl_same_model.h5", - "conv1d_k4_s1_d1_cl_valid_model.h5", - "conv1d_k4_s1_d2_cf_same_model.h5", - "conv1d_k4_s1_d2_cf_valid_model.h5", - "conv1d_k4_s1_d2_cl_same_model.h5", - "conv1d_k4_s1_d2_cl_valid_model.h5", - "conv1d_k4_s2_d1_cf_same_model.h5", - "conv1d_k4_s2_d1_cf_valid_model.h5", - "conv1d_k4_s2_d1_cl_same_model.h5", - "conv1d_k4_s2_d1_cl_valid_model.h5", - "conv1d_k4_s3_d1_cf_same_model.h5", - "conv1d_k4_s3_d1_cf_valid_model.h5", - "conv1d_k4_s3_d1_cl_same_model.h5", - "conv1d_k4_s3_d1_cl_valid_model.h5", - }; - - for(String name : names) { + @DisplayName("Test Conv 1 D") + void testConv1D(@TempDir Path tempDir) throws Exception { + String[] names = new String[] { "conv1d_k2_s1_d1_cf_same_model.h5", "conv1d_k2_s1_d1_cf_valid_model.h5", "conv1d_k2_s1_d1_cl_same_model.h5", "conv1d_k2_s1_d1_cl_valid_model.h5", "conv1d_k2_s1_d2_cf_same_model.h5", "conv1d_k2_s1_d2_cf_valid_model.h5", "conv1d_k2_s1_d2_cl_same_model.h5", "conv1d_k2_s1_d2_cl_valid_model.h5", "conv1d_k2_s2_d1_cf_same_model.h5", "conv1d_k2_s2_d1_cf_valid_model.h5", "conv1d_k2_s2_d1_cl_same_model.h5", "conv1d_k2_s2_d1_cl_valid_model.h5", "conv1d_k2_s3_d1_cf_same_model.h5", "conv1d_k2_s3_d1_cf_valid_model.h5", "conv1d_k2_s3_d1_cl_same_model.h5", "conv1d_k2_s3_d1_cl_valid_model.h5", "conv1d_k3_s1_d1_cf_same_model.h5", "conv1d_k3_s1_d1_cf_valid_model.h5", 
"conv1d_k3_s1_d1_cl_same_model.h5", "conv1d_k3_s1_d1_cl_valid_model.h5", "conv1d_k3_s1_d2_cf_same_model.h5", "conv1d_k3_s1_d2_cf_valid_model.h5", "conv1d_k3_s1_d2_cl_same_model.h5", "conv1d_k3_s1_d2_cl_valid_model.h5", "conv1d_k3_s2_d1_cf_same_model.h5", "conv1d_k3_s2_d1_cf_valid_model.h5", "conv1d_k3_s2_d1_cl_same_model.h5", "conv1d_k3_s2_d1_cl_valid_model.h5", "conv1d_k3_s3_d1_cf_same_model.h5", "conv1d_k3_s3_d1_cf_valid_model.h5", "conv1d_k3_s3_d1_cl_same_model.h5", "conv1d_k3_s3_d1_cl_valid_model.h5", "conv1d_k4_s1_d1_cf_same_model.h5", "conv1d_k4_s1_d1_cf_valid_model.h5", "conv1d_k4_s1_d1_cl_same_model.h5", "conv1d_k4_s1_d1_cl_valid_model.h5", "conv1d_k4_s1_d2_cf_same_model.h5", "conv1d_k4_s1_d2_cf_valid_model.h5", "conv1d_k4_s1_d2_cl_same_model.h5", "conv1d_k4_s1_d2_cl_valid_model.h5", "conv1d_k4_s2_d1_cf_same_model.h5", "conv1d_k4_s2_d1_cf_valid_model.h5", "conv1d_k4_s2_d1_cl_same_model.h5", "conv1d_k4_s2_d1_cl_valid_model.h5", "conv1d_k4_s3_d1_cf_same_model.h5", "conv1d_k4_s3_d1_cf_valid_model.h5", "conv1d_k4_s3_d1_cl_same_model.h5", "conv1d_k4_s3_d1_cl_valid_model.h5" }; + for (String name : names) { System.out.println("Starting test: " + name); String modelPath = "modelimport/keras/examples/conv1d/" + name; - String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); - - importEndModelTest(modelPath, inputsOutputPath, true, true, - true, true, false, null, null); //f, f2); + String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0, name.length() - "model.h5".length()) + "inputs_and_outputs.h5"); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, true, true, false, null, // f, f2); + null); } } - @Test - public void testActivationLayers() throws Exception { - String[] names = new String[]{ - "ELU_0_model.h5", - "LeakyReLU_0_model.h5", - "ReLU_0_model.h5", - "ReLU_1_model.h5", - "ReLU_2_model.h5", - "ReLU_3_model.h5", - "Softmax_0_model.h5", - "ThresholdReLU_0_model.h5", - }; - - for(String name : names ){ + @DisplayName("Test Activation Layers") + void testActivationLayers(@TempDir Path tempDir) throws Exception { + String[] names = new String[] { "ELU_0_model.h5", "LeakyReLU_0_model.h5", "ReLU_0_model.h5", "ReLU_1_model.h5", "ReLU_2_model.h5", "ReLU_3_model.h5", "Softmax_0_model.h5", "ThresholdReLU_0_model.h5" }; + for (String name : names) { System.out.println("Starting test: " + name); String modelPath = "modelimport/keras/examples/activations/" + name; - String inputsOutputPath = "modelimport/keras/examples/activations/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); - - importEndModelTest(modelPath, inputsOutputPath, true, true, - true, true, false, null, null); + String inputsOutputPath = "modelimport/keras/examples/activations/" + (name.substring(0, name.length() - "model.h5".length()) + "inputs_and_outputs.h5"); + importEndModelTest(tempDir,modelPath, inputsOutputPath, true, true, true, true, false, null, null); } } - private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception { - return importFunctionalModelH5Test(modelPath, null, false); + private ComputationGraph importFunctionalModelH5Test(Path tempDir,String modelPath) throws Exception { + return importFunctionalModelH5Test(tempDir,modelPath, null, false); } - - private ComputationGraph importFunctionalModelH5Test(String modelPath, int[] inputShape, boolean train) - throws Exception { + private ComputationGraph 
importFunctionalModelH5Test(Path tempDir,String modelPath, int[] inputShape, boolean train) throws Exception { File modelFile; - try(InputStream is = Resources.asStream(modelPath)) { - modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); + try (InputStream is = Resources.asStream(modelPath)) { + modelFile = createTempFile(tempDir,TEMP_MODEL_FILENAME, H5_EXTENSION); Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); } - KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) - .enforceTrainingConfig(train); + KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()).enforceTrainingConfig(train); if (inputShape != null) { builder.inputShape(inputShape); } @@ -771,17 +740,15 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { return model.getComputationGraph(); } - private MultiLayerNetwork importSequentialModelH5Test(String modelPath) throws Exception { - return importSequentialModelH5Test(modelPath, null); + private MultiLayerNetwork importSequentialModelH5Test(Path tempDir,String modelPath) throws Exception { + return importSequentialModelH5Test(tempDir,modelPath, null); } - - private MultiLayerNetwork importSequentialModelH5Test(String modelPath, int[] inputShape) throws Exception { - try(InputStream is = Resources.asStream(modelPath)) { - File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); + private MultiLayerNetwork importSequentialModelH5Test(Path tempDir,String modelPath, int[] inputShape) throws Exception { + try (InputStream is = Resources.asStream(modelPath)) { + File modelFile = createTempFile(tempDir,TEMP_MODEL_FILENAME, H5_EXTENSION); Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); - KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) - .enforceTrainingConfig(false); + KerasModelBuilder builder = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()).enforceTrainingConfig(false); if (inputShape != null) { builder.inputShape(inputShape); } @@ -790,35 +757,27 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } } - public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, - boolean checkGradients, boolean enforceTrainingConfig) throws Exception { - return importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, true, enforceTrainingConfig, null, null); + public MultiLayerNetwork importEndModelTest(Path tempDir,String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, boolean checkGradients, boolean enforceTrainingConfig) throws Exception { + return importEndModelTest(tempDir,modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, true, enforceTrainingConfig, null, null); } - public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, - boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function inputPreProc, - BiFunction expectedPreProc) throws Exception { + public MultiLayerNetwork importEndModelTest(Path tempDir,String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function inputPreProc, BiFunction expectedPreProc) throws Exception { 
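// Note on the jupiter assertions used by the comparison helpers further below: JUnit 5 keeps
// assertEquals(expected, actual[, message]) but moves the failure message of assertTrue to the
// last argument. Illustrative sketch of the two call shapes (variable names are placeholders,
// not part of this patch):
//   assertTrue(eq, "Output differs: " + label);               // condition first, message last
//   assertEquals(expectedArray, actualArray, "predictions");  // expected, actual, message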
MultiLayerNetwork model; - try(InputStream is = Resources.asStream(modelPath)) { - File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); + try (InputStream is = Resources.asStream(modelPath)) { + File modelFile = createTempFile(tempDir,TEMP_MODEL_FILENAME, H5_EXTENSION); Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); - KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) - .enforceTrainingConfig(enforceTrainingConfig).buildSequential(); - + KerasSequentialModel kerasModel = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()).enforceTrainingConfig(enforceTrainingConfig).buildSequential(); model = kerasModel.getMultiLayerNetwork(); } - - File outputsFile = createTempFile(TEMP_OUTPUTS_FILENAME, H5_EXTENSION); - try(InputStream is = Resources.asStream(inputsOutputsPath)) { + File outputsFile = createTempFile(tempDir,TEMP_OUTPUTS_FILENAME, H5_EXTENSION); + try (InputStream is = Resources.asStream(inputsOutputsPath)) { Files.copy(is, outputsFile.toPath(), StandardCopyOption.REPLACE_EXISTING); } try (Hdf5Archive outputsArchive = new Hdf5Archive(outputsFile.getAbsolutePath())) { - if (checkPredictions) { INDArray input = getInputs(outputsArchive, tfOrdering)[0]; - if(inputPreProc != null) + if (inputPreProc != null) input = inputPreProc.apply(input); - Map activationsKeras = getActivations(outputsArchive, tfOrdering); for (int i = 0; i < model.getLayers().length; i++) { String layerName = model.getLayerNames().get(i); @@ -828,34 +787,29 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { INDArray exp = activationsKeras.get(layerName); Nd4j.getExecutioner().enableDebugMode(true); Nd4j.getExecutioner().enableVerboseMode(true); - if(expectedPreProc != null) + if (expectedPreProc != null) exp = expectedPreProc.apply(layerName, exp); compareINDArrays(layerName, exp, activationsDl4j, EPS); } } - INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0]; INDArray predictionsDl4j = model.output(input, false); - if(expectedPreProc != null) + if (expectedPreProc != null) predictionsKeras = expectedPreProc.apply("output", predictionsKeras); compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS); INDArray outputs = getOutputs(outputsArchive, true)[0]; - - if(outputs.rank() == 1) { + if (outputs.rank() == 1) { outputs = outputs.reshape(outputs.length(), 1); } val nOut = (int) outputs.size(-1); - - if(checkAuc) + if (checkAuc) compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS); } - - if (checkGradients && ! SKIP_GRAD_CHECKS) { + if (checkGradients && !SKIP_GRAD_CHECKS) { Random r = new Random(12345); INDArray input = getInputs(outputsArchive, tfOrdering)[0]; INDArray predictionsDl4j = model.output(input, false); - - //Infer one-hot labels... this probably won't work for all + // Infer one-hot labels... 
this probably won't work for all INDArray testLabels = Nd4j.create(predictionsDl4j.shape()); if (testLabels.rank() == 2) { for (int i = 0; i < testLabels.size(0); i++) { @@ -873,13 +827,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { checkGradients(model, input, testLabels); } } - return model; } private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception { - List inputNames = (List) KerasModelUtils - .parseJsonString(archive.readAttributeAsJson(GROUP_ATTR_INPUTS)).get(GROUP_ATTR_INPUTS); + List inputNames = (List) KerasModelUtils.parseJsonString(archive.readAttributeAsJson(GROUP_ATTR_INPUTS)).get(GROUP_ATTR_INPUTS); INDArray[] inputs = new INDArray[inputNames.size()]; for (int i = 0; i < inputNames.size(); i++) { inputs[i] = archive.readDataSet(inputNames.get(i), GROUP_ATTR_INPUTS); @@ -887,8 +839,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { return inputs; } - private static Map getActivations(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) - throws Exception { + private static Map getActivations(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception { Map activations = new HashMap<>(); for (String layerName : archive.getDataSets(GROUP_ACTIVATIONS)) { INDArray activation = archive.readDataSet(layerName, GROUP_ACTIVATIONS); @@ -897,10 +848,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { return activations; } - private static INDArray[] getOutputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws - Exception { - List outputNames = (List) KerasModelUtils - .parseJsonString(archive.readAttributeAsJson(GROUP_ATTR_OUTPUTS)).get(GROUP_ATTR_OUTPUTS); + private static INDArray[] getOutputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception { + List outputNames = (List) KerasModelUtils.parseJsonString(archive.readAttributeAsJson(GROUP_ATTR_OUTPUTS)).get(GROUP_ATTR_OUTPUTS); INDArray[] outputs = new INDArray[outputNames.size()]; for (int i = 0; i < outputNames.size(); i++) { outputs[i] = archive.readDataSet(outputNames.get(i), GROUP_ATTR_OUTPUTS); @@ -908,10 +857,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { return outputs; } - private static INDArray[] getPredictions(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) - throws Exception { - List outputNames = (List) KerasModelUtils - .parseJsonString(archive.readAttributeAsJson(GROUP_ATTR_OUTPUTS)).get(GROUP_ATTR_OUTPUTS); + private static INDArray[] getPredictions(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception { + List outputNames = (List) KerasModelUtils.parseJsonString(archive.readAttributeAsJson(GROUP_ATTR_OUTPUTS)).get(GROUP_ATTR_OUTPUTS); INDArray[] predictions = new INDArray[outputNames.size()]; for (int i = 0; i < outputNames.size(); i++) { predictions[i] = archive.readDataSet(outputNames.get(i), GROUP_PREDICTIONS); @@ -920,7 +867,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } private static void compareINDArrays(String label, INDArray expected, INDArray actual, double eps) { - if(!expected.equalShapes(actual)){ + if (!expected.equalShapes(actual)) { throw new IllegalStateException("Shapes do not match for \"" + label + "\": got " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape())); } INDArray diff = expected.sub(actual.castTo(expected.dataType())); @@ -930,21 +877,19 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { double threshold = 1e-7; double aAbsMax = 
Math.max(Math.abs(expected.minNumber().doubleValue()), Math.abs(expected.maxNumber().doubleValue())); double bAbsMax = Math.max(Math.abs(actual.minNumber().doubleValue()), Math.abs(actual.maxNumber().doubleValue())); - // skip too small absolute inputs if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) { boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps); - if(!eq){ + if (!eq) { System.out.println("Expected: " + Arrays.toString(expected.shape()) + ", actual: " + Arrays.toString(actual.shape())); System.out.println("Expected:\n" + expected); System.out.println("Actual: \n" + actual); } - assertTrue("Output differs: " + label, eq); + assertTrue(eq,"Output differs: " + label); } } - private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses, - double eps) { + private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses, double eps) { ROCMultiClass evalA = new ROCMultiClass(100); evalA.eval(target, a); double avgAucA = evalA.calculateAverageAUC(); @@ -952,7 +897,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { evalB.eval(target, b); double avgAucB = evalB.calculateAverageAUC(); assertEquals(avgAucA, avgAucB, EPS); - double[] aucA = new double[nbClasses]; double[] aucB = new double[nbClasses]; if (nbClasses > 1) { @@ -968,43 +912,25 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { double eps = 1e-6; double max_rel_error = 1e-3; double min_abs_error = 1e-8; - MultiLayerNetwork netToTest; if (net.getOutputLayer() instanceof IOutputLayer) { netToTest = net; } else { org.deeplearning4j.nn.conf.layers.Layer l; if (labels.rank() == 2) { - l = new LossLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE) - .activation(Activation.IDENTITY) - .build(); + l = new LossLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).build(); } else { - //Rank 3 - l = new RnnOutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE) - .activation(Activation.IDENTITY) - .nIn(labels.size(1)) - .nOut(labels.size(1)) - .build(); + // Rank 3 + l = new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(labels.size(1)).nOut(labels.size(1)).build(); } - netToTest = new TransferLearning.Builder(net) - .fineTuneConfiguration(new FineTuneConfiguration.Builder() - .updater(new NoOp()) - .dropOut(0.0) - .build()) - .addLayer(l) - .build(); + netToTest = new TransferLearning.Builder(net).fineTuneConfiguration(new FineTuneConfiguration.Builder().updater(new NoOp()).dropOut(0.0).build()).addLayer(l).build(); } - log.info("Num params: " + net.numParams()); - for (Layer l : netToTest.getLayers()) { // Remove any dropout manually - until this is fixed: // https://github.com/eclipse/deeplearning4j/issues/4368 - l.conf().getLayer().setIDropout(null); - - //Also swap out activation functions... this is a bit of a hack, but should make the net gradient checkable... + l.conf().getLayer().setIDropout(null); + // Also swap out activation functions... this is a bit of a hack, but should make the net gradient checkable... 
if (l.conf().getLayer() instanceof FeedForwardLayer) { FeedForwardLayer ffl = (FeedForwardLayer) l.conf().getLayer(); IActivation activation = ffl.getActivationFn(); @@ -1015,14 +941,15 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } } } - Nd4j.setDataType(DataType.DOUBLE); - boolean passed = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(netToTest).input(input) - .labels(labels).subset(true).maxPerParam(9)); - assertTrue("Gradient check failed", passed); + boolean passed = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(netToTest).input(input).labels(labels).subset(true).maxPerParam(9)); + assertTrue(passed, "Gradient check failed"); } - private File createTempFile(String prefix, String suffix) throws IOException { - return testDir.newFile(prefix + "-" + System.nanoTime() + suffix); + private File createTempFile(Path testDir,String prefix, String suffix) throws IOException { + File ret = new File(testDir.toFile(),prefix + "-" + System.nanoTime() + suffix); + ret.createNewFile(); + ret.deleteOnExit(); + return ret; } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java index 3144bdb8f..14403a067 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.e2e; import lombok.extern.slf4j.Slf4j; @@ -29,57 +28,40 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.util.ModelSerializer; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.factory.Nd4j; - import java.io.File; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class KerasYolo9000PredictTest extends BaseDL4JTest { +@DisplayName("Keras Yolo 9000 Predict Test") +class KerasYolo9000PredictTest extends BaseDL4JTest { private static final String DL4J_MODEL_FILE_NAME = "."; + private static ImagePreProcessingScaler IMAGE_PREPROCESSING_SCALER = new ImagePreProcessingScaler(0, 1); @Test - @Ignore("Need to manually download file for ylo.") - public void testYoloPredictionImport() throws Exception { - - + @Disabled("Need to manually download file for ylo.") + @DisplayName("Test Yolo Prediction Import") + void testYoloPredictionImport() throws Exception { int HEIGHT = 416; int WIDTH = 416; INDArray indArray = Nd4j.create(HEIGHT, WIDTH, 3); IMAGE_PREPROCESSING_SCALER.transform(indArray); - KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class); - String h5_FILENAME = "modelimport/keras/examples/yolo/yolo-voc.h5"; ComputationGraph graph = 
KerasModelImport.importKerasModelAndWeights(h5_FILENAME, false); - - double[][] priorBoxes = {{1.3221, 1.73145}, {3.19275, 4.00944}, {5.05587, 8.09892}, {9.47112, 4.84053}, {11.2364, 10.0071}}; + double[][] priorBoxes = { { 1.3221, 1.73145 }, { 3.19275, 4.00944 }, { 5.05587, 8.09892 }, { 9.47112, 4.84053 }, { 11.2364, 10.0071 } }; INDArray priors = Nd4j.create(priorBoxes); - - ComputationGraph model = new TransferLearning.GraphBuilder(graph) - .addLayer("outputs", - new org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.Builder() - .boundingBoxPriors(priors) - .build(), - "conv2d_23") - .setOutputs("outputs") - .build(); - + ComputationGraph model = new TransferLearning.GraphBuilder(graph).addLayer("outputs", new org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.Builder().boundingBoxPriors(priors).build(), "conv2d_23").setOutputs("outputs").build(); ModelSerializer.writeModel(model, DL4J_MODEL_FILE_NAME, false); - ComputationGraph computationGraph = ModelSerializer.restoreComputationGraph(new File(DL4J_MODEL_FILE_NAME)); - System.out.println(computationGraph.summary(InputType.convolutional(416, 416, 3))); - INDArray results = computationGraph.outputSingle(indArray); - - } - } - diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java index 34981cbfd..29617de12 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000Test.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.e2e; import lombok.extern.slf4j.Slf4j; @@ -26,43 +25,42 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; -import org.junit.Ignore; +import org.junit.jupiter.api.Disabled; import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.resources.Resources; - import java.io.File; import java.io.InputStream; import java.nio.file.Files; import java.nio.file.StandardCopyOption; +import org.junit.jupiter.api.DisplayName; +import java.nio.file.Path; +import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -public class KerasYolo9000Test extends BaseDL4JTest { +@DisplayName("Keras Yolo 9000 Test") +class KerasYolo9000Test extends BaseDL4JTest { private static final String TEMP_MODEL_FILENAME = "tempModel"; + private static final String H5_EXTENSION = ".h5"; - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @TempDir + public Path testDir; - @Ignore + @Disabled @Test + @DisplayName("Test Custom Layer Yolo Import") // TODO: yolo and yolo-voc output are too large for github, find smaller equivalents - public void testCustomLayerYoloImport() throws Exception { + void testCustomLayerYoloImport() throws Exception { KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class); - String modelPath = "modelimport/keras/examples/yolo/yolo.h5"; 
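// The try-with-resources block below copies the classpath model into the class-level @TempDir
// before importing it, the same pattern used throughout these migrated tests. A self-contained
// sketch of that JUnit 5 pattern, assuming the usual jupiter/nd4j-common imports (test name and
// parameter injection are illustrative only):
//   @Test
//   void copiesModelIntoTempDir(@TempDir Path dir) throws Exception {
//       try (InputStream is = Resources.asStream("modelimport/keras/examples/yolo/yolo.h5")) {
//           File modelFile = dir.resolve("tempModel" + System.currentTimeMillis() + ".h5").toFile();
//           Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
//       }
//   }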
- - try(InputStream is = Resources.asStream(modelPath)) { - File modelFile = testDir.newFile(TEMP_MODEL_FILENAME + System.currentTimeMillis() + H5_EXTENSION); + try (InputStream is = Resources.asStream(modelPath)) { + File modelFile = testDir.resolve(TEMP_MODEL_FILENAME + System.currentTimeMillis() + H5_EXTENSION).toFile(); Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); - ComputationGraph model = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) - .enforceTrainingConfig(false).buildModel().getComputationGraph(); - + ComputationGraph model = new KerasModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()).enforceTrainingConfig(false).buildModel().getComputationGraph(); System.out.println(model.summary()); } - - } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java index 5d4e3e97b..ccb2be9df 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activation; import org.deeplearning4j.nn.conf.layers.ActivationLayer; @@ -26,23 +25,26 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasLeakyReLU; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasLeakyReLUTest extends BaseDL4JTest { +@DisplayName("Keras Leaky Re LU Test") +class KerasLeakyReLUTest extends BaseDL4JTest { private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testLeakyReLULayer() throws Exception { + @DisplayName("Test Leaky Re LU Layer") + void testLeakyReLULayer() throws Exception { Integer keras1 = 1; buildLeakyReLULayer(conf1, keras1); Integer keras2 = 2; @@ -51,7 +53,6 @@ public class KerasLeakyReLUTest extends BaseDL4JTest { private void buildLeakyReLULayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { double alpha = 0.3; - Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_LEAKY_RELU()); Map config = new HashMap<>(); @@ -61,9 +62,8 @@ public class KerasLeakyReLUTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NAME(), layerName); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - ActivationLayer 
layer = new KerasLeakyReLU(layerConfig).getActivationLayer(); - assertEquals("leakyrelu(a=0.3)", layer.getActivationFn().toString()); + assertEquals(layer.getActivationFn().toString(), "leakyrelu(a=0.3)"); assertEquals(layerName, layer.getLayerName()); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java index eb52d30ec..f20465f0e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activation; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -29,27 +28,31 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasPReLU; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasPReLUTest extends BaseDL4JTest { +@DisplayName("Keras P Re LU Test") +class KerasPReLUTest extends BaseDL4JTest { private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); @Test - public void testPReLULayer() throws Exception { + @DisplayName("Test P Re LU Layer") + void testPReLULayer() throws Exception { Integer keras1 = 1; buildPReLULayer(conf1, keras1); Integer keras2 = 2; @@ -57,7 +60,6 @@ public class KerasPReLUTest extends BaseDL4JTest { } private void buildPReLULayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { - Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_LEAKY_RELU()); Map config = new HashMap<>(); @@ -72,15 +74,11 @@ public class KerasPReLUTest extends BaseDL4JTest { init.put("class_name", conf.getINIT_GLOROT_NORMAL()); config.put("alpha_initializer", init); } - KerasPReLU kerasPReLU = new KerasPReLU(layerConfig); - - kerasPReLU.getOutputType(InputType.convolutional(5,4,3)); - + kerasPReLU.getOutputType(InputType.convolutional(5, 4, 3)); PReLULayer layer = kerasPReLU.getPReLULayer(); - assertArrayEquals(layer.getInputShape(), new long[] {3, 5, 4}); + assertArrayEquals(layer.getInputShape(), new long[] { 3, 5, 4 }); assertEquals(INIT_DL4J, layer.getWeightInitFn()); - assertEquals(layerName, layer.getLayerName()); } } diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java index d26f5d746..a0027ffdd 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activation; import org.deeplearning4j.nn.conf.layers.ActivationLayer; @@ -26,23 +25,26 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasThresholdedReLU; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasThresholdedReLUTest extends BaseDL4JTest { +@DisplayName("Keras Thresholded Re LU Test") +class KerasThresholdedReLUTest extends BaseDL4JTest { private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testThresholdedReLULayer() throws Exception { + @DisplayName("Test Thresholded Re LU Layer") + void testThresholdedReLULayer() throws Exception { Integer keras1 = 1; buildThresholdedReLULayer(conf1, keras1); Integer keras2 = 2; @@ -50,9 +52,7 @@ public class KerasThresholdedReLUTest extends BaseDL4JTest { } private void buildThresholdedReLULayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { - double theta = 0.5; - Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_THRESHOLDED_RELU()); Map config = new HashMap<>(); @@ -62,9 +62,8 @@ public class KerasThresholdedReLUTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NAME(), layerName); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - ActivationLayer layer = new KerasThresholdedReLU(layerConfig).getActivationLayer(); - assertEquals("thresholdedrelu(theta=0.5)", layer.getActivationFn().toString()); + assertEquals(layer.getActivationFn().toString(), "thresholdedrelu(theta=0.5)"); assertEquals(layerName, layer.getLayerName()); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java index 95f137b3d..ea5bcfddf 100644 --- 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -30,44 +29,60 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousConvolution1D; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasAtrousConvolution1DTest extends BaseDL4JTest { +@DisplayName("Keras Atrous Convolution 1 D Test") +class KerasAtrousConvolution1DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "atrous_conv_1d"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private final int[] KERNEL_SIZE = new int[]{1, 2}; - private final int[] DILATION = new int[]{2}; - private final int[] STRIDE = new int[]{3, 4}; + + private final int[] KERNEL_SIZE = new int[] { 1, 2 }; + + private final int[] DILATION = new int[] { 2 }; + + private final int[] STRIDE = new int[] { 3, 4 }; + private final int N_OUT = 13; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Integer keras1 = 1; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); @Test - public void testAtrousConvolution1DLayer() throws Exception { + @DisplayName("Test Atrous Convolution 1 D Layer") + void testAtrousConvolution1DLayer() throws Exception { buildAtrousConvolution1DLayer(conf1, keras1); } - private void buildAtrousConvolution1DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildAtrousConvolution1DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_CONVOLUTION_1D()); Map config = new HashMap<>(); @@ -96,7 +111,6 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NB_FILTER(), N_OUT); config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); - Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, 
layer.getLayerName()); @@ -115,4 +129,3 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest { assertEquals(DILATION, layer.getDilation()); } } - diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java index e43769c4a..eec7412ff 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -30,47 +29,62 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousConvolution2D; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasAtrousConvolution2DTest extends BaseDL4JTest { +@DisplayName("Keras Atrous Convolution 2 D Test") +class KerasAtrousConvolution2DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "atrous_conv_2d"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private final int[] KERNEL_SIZE = new int[]{1, 2}; - private final int[] DILATION = new int[]{2, 2}; - private final int[] STRIDE = new int[]{3, 4}; + + private final int[] KERNEL_SIZE = new int[] { 1, 2 }; + + private final int[] DILATION = new int[] { 2, 2 }; + + private final int[] STRIDE = new int[] { 3, 4 }; + private final int N_OUT = 13; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); @Test - public void testAtrousConvolution2DLayer() throws Exception { + @DisplayName("Test Atrous Convolution 2 D Layer") + void testAtrousConvolution2DLayer() throws Exception { Integer keras1 = 1; buildAtrousConvolution2DLayer(conf1, keras1); } - private void buildAtrousConvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void 
buildAtrousConvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_CONVOLUTION_2D()); Map config = new HashMap<>(); @@ -92,14 +106,20 @@ public class KerasAtrousConvolution2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NB_ROW(), KERNEL_SIZE[0]); config.put(conf.getLAYER_FIELD_NB_COL(), KERNEL_SIZE[1]); } else { - ArrayList kernel = new ArrayList() {{ - for (int i : KERNEL_SIZE) add(i); - }}; + ArrayList kernel = new ArrayList() { + + { + for (int i : KERNEL_SIZE) add(i); + } + }; config.put(conf.getLAYER_FIELD_KERNEL_SIZE(), kernel); } - ArrayList dilation = new ArrayList() {{ - for (int i : DILATION) add(i); - }}; + ArrayList dilation = new ArrayList() { + + { + for (int i : DILATION) add(i); + } + }; config.put(conf.getLAYER_FIELD_DILATION_RATE(), dilation); List subsampleList = new ArrayList<>(); subsampleList.add(STRIDE[0]); @@ -109,8 +129,6 @@ public class KerasAtrousConvolution2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - - ConvolutionLayer layer = new KerasAtrousConvolution2D(layerConfig).getAtrousConvolution2D(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java index f08249c22..5bdb7a013 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -31,49 +30,67 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution1D; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasConvolution1DTest extends BaseDL4JTest { +@DisplayName("Keras Convolution 1 D Test") +class KerasConvolution1DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "test_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double 
L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private final int[] KERNEL_SIZE = new int[]{2}; - private final int[] DILATION = new int[]{2}; - private final int[] STRIDE = new int[]{4}; + + private final int[] KERNEL_SIZE = new int[] { 2 }; + + private final int[] DILATION = new int[] { 2 }; + + private final int[] STRIDE = new int[] { 4 }; + private final int N_OUT = 13; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testConvolution1DLayer() throws Exception { + @DisplayName("Test Convolution 1 D Layer") + void testConvolution1DLayer() throws Exception { buildConvolution1DLayer(conf1, keras1, false); buildConvolution1DLayer(conf2, keras2, false); buildConvolution1DLayer(conf2, keras2, true); } - private void buildConvolution1DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) - throws Exception { + private void buildConvolution1DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_CONVOLUTION_1D()); Map config = new HashMap<>(); @@ -88,9 +105,12 @@ public class KerasConvolution1DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_INIT(), init); } if (withDilation) { - ArrayList dilation = new ArrayList() {{ - for (int i : DILATION) add(i); - }}; + ArrayList dilation = new ArrayList() { + + { + for (int i : DILATION) add(i); + } + }; config.put(conf.getLAYER_FIELD_DILATION_RATE(), dilation); } Map W_reg = new HashMap(); @@ -99,18 +119,23 @@ public class KerasConvolution1DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_W_REGULARIZER(), W_reg); config.put(conf.getLAYER_FIELD_DROPOUT(), DROPOUT_KERAS); if (kerasVersion == 2) { - ArrayList kernel = new ArrayList() {{ - for (int i : KERNEL_SIZE) add(i); - }}; + ArrayList kernel = new ArrayList() { + + { + for (int i : KERNEL_SIZE) add(i); + } + }; config.put(conf.getLAYER_FIELD_FILTER_LENGTH(), kernel); } else { config.put(conf.getLAYER_FIELD_FILTER_LENGTH(), KERNEL_SIZE[0]); } - if (kerasVersion == 2) { - ArrayList stride = new ArrayList() {{ - for (int i : STRIDE) add(i); - }}; + ArrayList stride = new ArrayList() { + + { + for (int i : STRIDE) add(i); + } + }; config.put(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH(), stride); } else { config.put(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH(), STRIDE[0]); @@ -118,7 +143,6 @@ public class KerasConvolution1DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NB_FILTER(), N_OUT); config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); - Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java 
b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java index 072da9f28..32fef216e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -31,53 +30,69 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution2D; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasConvolution2DTest extends BaseDL4JTest { +@DisplayName("Keras Convolution 2 D Test") +class KerasConvolution2DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "test_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private final int[] KERNEL_SIZE = new int[]{1, 2}; - private final int[] DILATION = new int[]{2, 2}; - private final int[] STRIDE = new int[]{3, 4}; + + private final int[] KERNEL_SIZE = new int[] { 1, 2 }; + + private final int[] DILATION = new int[] { 2, 2 }; + + private final int[] STRIDE = new int[] { 3, 4 }; + private final int N_OUT = 13; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testConvolution2DLayer() throws Exception { + @DisplayName("Test Convolution 2 D Layer") + void testConvolution2DLayer() throws Exception { buildConvolution2DLayer(conf1, keras1, false); buildConvolution2DLayer(conf2, keras2, false); buildConvolution2DLayer(conf2, keras2, true); } - - private void buildConvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) - throws Exception { + private void buildConvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), 
conf.getLAYER_CLASS_NAME_CONVOLUTION_2D()); Map config = new HashMap<>(); @@ -99,15 +114,21 @@ public class KerasConvolution2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NB_ROW(), KERNEL_SIZE[0]); config.put(conf.getLAYER_FIELD_NB_COL(), KERNEL_SIZE[1]); } else { - ArrayList kernel = new ArrayList() {{ - for (int i : KERNEL_SIZE) add(i); - }}; + ArrayList kernel = new ArrayList() { + + { + for (int i : KERNEL_SIZE) add(i); + } + }; config.put(conf.getLAYER_FIELD_KERNEL_SIZE(), kernel); } if (withDilation) { - ArrayList dilation = new ArrayList() {{ - for (int i : DILATION) add(i); - }}; + ArrayList dilation = new ArrayList() { + + { + for (int i : DILATION) add(i); + } + }; config.put(conf.getLAYER_FIELD_DILATION_RATE(), dilation); } List subsampleList = new ArrayList<>(); @@ -118,8 +139,6 @@ public class KerasConvolution2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - - ConvolutionLayer layer = new KerasConvolution2D(layerConfig).getConvolution2DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java index 69b94bdda..e61242e51 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -31,51 +30,66 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution3D; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasConvolution3DTest extends BaseDL4JTest { +@DisplayName("Keras Convolution 3 D Test") +class KerasConvolution3DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "test_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private 
final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private final int[] KERNEL_SIZE = new int[]{1, 2, 3}; - private final int[] STRIDE = new int[]{3, 4, 5}; + + private final int[] KERNEL_SIZE = new int[] { 1, 2, 3 }; + + private final int[] STRIDE = new int[] { 3, 4, 5 }; + private final int N_OUT = 13; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0, 0 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testConvolution3DLayer() throws Exception { + @DisplayName("Test Convolution 3 D Layer") + void testConvolution3DLayer() throws Exception { buildConvolution3DLayer(conf1, keras1); buildConvolution3DLayer(conf2, keras2); } - - private void buildConvolution3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildConvolution3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_CONVOLUTION_3D()); Map config = new HashMap<>(); @@ -97,14 +111,15 @@ public class KerasConvolution3DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_3D_KERNEL_1(), KERNEL_SIZE[0]); config.put(conf.getLAYER_FIELD_3D_KERNEL_2(), KERNEL_SIZE[1]); config.put(conf.getLAYER_FIELD_3D_KERNEL_3(), KERNEL_SIZE[2]); - } else { - ArrayList kernel = new ArrayList() {{ - for (int i : KERNEL_SIZE) add(i); - }}; + ArrayList kernel = new ArrayList() { + + { + for (int i : KERNEL_SIZE) add(i); + } + }; config.put(conf.getLAYER_FIELD_KERNEL_SIZE(), kernel); } - List subsampleList = new ArrayList<>(); subsampleList.add(STRIDE[0]); subsampleList.add(STRIDE[1]); @@ -114,8 +129,6 @@ public class KerasConvolution3DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - - ConvolutionLayer layer = new KerasConvolution3D(layerConfig).getConvolution3DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); @@ -128,6 +141,5 @@ public class KerasConvolution3DTest extends BaseDL4JTest { assertEquals(N_OUT, layer.getNOut()); assertEquals(ConvolutionMode.Truncate, layer.getConvolutionMode()); assertArrayEquals(VALID_PADDING, layer.getPadding()); - } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java index b45a7e041..25389fc6b 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; @@ -26,36 
+25,37 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping1D; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasCropping1DTest extends BaseDL4JTest { +@DisplayName("Keras Cropping 1 D Test") +class KerasCropping1DTest extends BaseDL4JTest { private final String LAYER_NAME = "cropping_1D_layer"; + private final int CROPPING = 2; private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testCropping1DLayer() throws Exception { + @DisplayName("Test Cropping 1 D Layer") + void testCropping1DLayer() throws Exception { Integer keras1 = 1; Integer keras2 = 2; buildCroppingSingleDim1DLayer(conf1, keras1); buildCroppingSingleDim1DLayer(conf2, keras2); } - - - private void buildCroppingSingleDim1DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildCroppingSingleDim1DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_CROPPING_1D()); Map config = new HashMap<>(); @@ -63,7 +63,6 @@ public class KerasCropping1DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_CROPPING(), CROPPING); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Cropping1D layer = new KerasCropping1D(layerConfig).getCropping1DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(CROPPING, layer.getCropping()[0]); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java index e05af2469..1d7a94f11 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; @@ -26,27 +25,31 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping2D; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; - 
-import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasCropping2DTest extends BaseDL4JTest { +@DisplayName("Keras Cropping 2 D Test") +class KerasCropping2DTest extends BaseDL4JTest { private final String LAYER_NAME = "cropping_2D_layer"; - private final int[] CROPPING = new int[]{2, 3}; + + private final int[] CROPPING = new int[] { 2, 3 }; private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testCropping2DLayer() throws Exception { + @DisplayName("Test Cropping 2 D Layer") + void testCropping2DLayer() throws Exception { Integer keras1 = 1; buildCropping2DLayer(conf1, keras1); Integer keras2 = 2; @@ -55,31 +58,29 @@ public class KerasCropping2DTest extends BaseDL4JTest { buildCroppingSingleDim2DLayer(conf2, keras2); } - - private void buildCropping2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildCropping2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_CROPPING_2D()); Map config = new HashMap<>(); config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); - ArrayList padding = new ArrayList() {{ - for (int i : CROPPING) add(i); - }}; + ArrayList padding = new ArrayList() { + + { + for (int i : CROPPING) add(i); + } + }; config.put(conf.getLAYER_FIELD_CROPPING(), padding); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Cropping2D layer = new KerasCropping2D(layerConfig).getCropping2DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(CROPPING[0], layer.getCropping()[0]); assertEquals(CROPPING[0], layer.getCropping()[1]); assertEquals(CROPPING[1], layer.getCropping()[2]); assertEquals(CROPPING[1], layer.getCropping()[3]); - } - private void buildCroppingSingleDim2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildCroppingSingleDim2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_CROPPING_2D()); Map config = new HashMap<>(); @@ -87,7 +88,6 @@ public class KerasCropping2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_CROPPING(), CROPPING[0]); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Cropping2D layer = new KerasCropping2D(layerConfig).getCropping2DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(CROPPING[0], layer.getCropping()[0]); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java index fbc3b4f8b..cd91873f2 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java +++ 
b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D; @@ -26,27 +25,31 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping3D; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasCropping3DTest extends BaseDL4JTest { +@DisplayName("Keras Cropping 3 D Test") +class KerasCropping3DTest extends BaseDL4JTest { private final String LAYER_NAME = "cropping_3D_layer"; - private final int[] CROPPING = new int[]{2, 3, 5}; + + private final int[] CROPPING = new int[] { 2, 3, 5 }; private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testCropping3DLayer() throws Exception { + @DisplayName("Test Cropping 3 D Layer") + void testCropping3DLayer() throws Exception { Integer keras1 = 1; buildCropping3DLayer(conf1, keras1); Integer keras2 = 2; @@ -55,20 +58,20 @@ public class KerasCropping3DTest extends BaseDL4JTest { buildCroppingSingleDim3DLayer(conf2, keras2); } - - private void buildCropping3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildCropping3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_CROPPING_3D()); Map config = new HashMap<>(); config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); - ArrayList padding = new ArrayList() {{ - for (int i : CROPPING) add(i); - }}; + ArrayList padding = new ArrayList() { + + { + for (int i : CROPPING) add(i); + } + }; config.put(conf.getLAYER_FIELD_CROPPING(), padding); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Cropping3D layer = new KerasCropping3D(layerConfig).getCropping3DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(CROPPING[0], layer.getCropping()[0]); @@ -77,11 +80,9 @@ public class KerasCropping3DTest extends BaseDL4JTest { assertEquals(CROPPING[1], layer.getCropping()[3]); assertEquals(CROPPING[2], layer.getCropping()[4]); assertEquals(CROPPING[2], layer.getCropping()[5]); - } - private void buildCroppingSingleDim3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildCroppingSingleDim3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_CROPPING_3D()); Map config = new HashMap<>(); @@ -89,7 +90,6 @@ public class 
KerasCropping3DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_CROPPING(), CROPPING[0]); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Cropping3D layer = new KerasCropping3D(layerConfig).getCropping3DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(CROPPING[0], layer.getCropping()[0]); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java index 87940e400..74ca5f03d 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -31,53 +30,69 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasDeconvolution2D; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasDeconvolution2DTest extends BaseDL4JTest { +@DisplayName("Keras Deconvolution 2 D Test") +class KerasDeconvolution2DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "deconvolution_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private final int[] KERNEL_SIZE = new int[]{1, 2}; - private final int[] DILATION = new int[]{2, 2}; - private final int[] STRIDE = new int[]{3, 4}; + + private final int[] KERNEL_SIZE = new int[] { 1, 2 }; + + private final int[] DILATION = new int[] { 2, 2 }; + + private final int[] STRIDE = new int[] { 3, 4 }; + private final int N_OUT = 13; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testDeconvolution2DLayer() throws Exception { + 
@DisplayName("Test Deconvolution 2 D Layer") + void testDeconvolution2DLayer() throws Exception { buildDeconvolution2DLayer(conf1, keras1, false); buildDeconvolution2DLayer(conf2, keras2, false); buildDeconvolution2DLayer(conf2, keras2, true); } - - private void buildDeconvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) - throws Exception { + private void buildDeconvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_DECONVOLUTION_2D()); Map config = new HashMap<>(); @@ -99,15 +114,21 @@ public class KerasDeconvolution2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NB_ROW(), KERNEL_SIZE[0]); config.put(conf.getLAYER_FIELD_NB_COL(), KERNEL_SIZE[1]); } else { - ArrayList kernel = new ArrayList() {{ - for (int i : KERNEL_SIZE) add(i); - }}; + ArrayList kernel = new ArrayList() { + + { + for (int i : KERNEL_SIZE) add(i); + } + }; config.put(conf.getLAYER_FIELD_KERNEL_SIZE(), kernel); } if (withDilation) { - ArrayList dilation = new ArrayList() {{ - for (int i : DILATION) add(i); - }}; + ArrayList dilation = new ArrayList() { + + { + for (int i : DILATION) add(i); + } + }; config.put(conf.getLAYER_FIELD_DILATION_RATE(), dilation); } List subsampleList = new ArrayList<>(); @@ -118,8 +139,6 @@ public class KerasDeconvolution2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - - Deconvolution2D layer = new KerasDeconvolution2D(layerConfig).getDeconvolution2DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java index 50e8d4ca9..1b6a7c8c4 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -32,49 +31,64 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolu import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasDepthwiseConvolution2D; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.base.Preconditions; - import java.util.*; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import 
org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasDepthwiseConvolution2DTest extends BaseDL4JTest { +@DisplayName("Keras Depthwise Convolution 2 D Test") +class KerasDepthwiseConvolution2DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "test_layer"; + private final String INIT_KERAS = "depthwise_conv_2d"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private final int[] KERNEL_SIZE = new int[]{1, 2}; - private final int[] DILATION = new int[]{2, 2}; - private final int[] STRIDE = new int[]{3, 4}; + + private final int[] KERNEL_SIZE = new int[] { 1, 2 }; + + private final int[] DILATION = new int[] { 2, 2 }; + + private final int[] STRIDE = new int[] { 3, 4 }; + private final int DEPTH_MULTIPLIER = 4; + private final int N_IN = 3; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Integer keras2 = 2; + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testDepthwiseConvolution2DLayer() throws Exception { + @DisplayName("Test Depthwise Convolution 2 D Layer") + void testDepthwiseConvolution2DLayer() throws Exception { buildDepthwiseConvolution2DLayer(conf2, keras2, false); buildDepthwiseConvolution2DLayer(conf2, keras2, true); } - - private void buildDepthwiseConvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) - throws Exception { + private void buildDepthwiseConvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_DEPTHWISE_CONVOLUTION_2D()); Map config = new HashMap<>(); @@ -95,16 +109,20 @@ public class KerasDepthwiseConvolution2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_DEPTH_WISE_REGULARIZER(), W_reg); config.put(conf.getLAYER_FIELD_DROPOUT(), DROPOUT_KERAS); config.put(conf.getLAYER_FIELD_DEPTH_MULTIPLIER(), DEPTH_MULTIPLIER); + ArrayList kernel = new ArrayList() { - ArrayList kernel = new ArrayList() {{ - for (int i : KERNEL_SIZE) add(i); - }}; + { + for (int i : KERNEL_SIZE) add(i); + } + }; config.put(conf.getLAYER_FIELD_KERNEL_SIZE(), kernel); - if (withDilation) { - ArrayList dilation = new ArrayList() {{ - for (int i : DILATION) add(i); - }}; + ArrayList dilation = new ArrayList() { + + { + for (int i : DILATION) add(i); + } + }; config.put(conf.getLAYER_FIELD_DILATION_RATE(), dilation); } List subsampleList = new ArrayList<>(); @@ -115,16 +133,12 @@ public class KerasDepthwiseConvolution2DTest extends BaseDL4JTest { layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); config.put(conf.getLAYER_FIELD_NB_FILTER(), N_IN); - KerasConvolution2D previousLayer = new KerasConvolution2D(layerConfig); Map previousLayers = new HashMap<>(); previousLayers.put("conv", previousLayer); List layerNames = Collections.singletonList("conv"); - - KerasDepthwiseConvolution2D kerasLayer = new KerasDepthwiseConvolution2D( - layerConfig, previousLayers, layerNames, false); + 
KerasDepthwiseConvolution2D kerasLayer = new KerasDepthwiseConvolution2D(layerConfig, previousLayers, layerNames, false); Preconditions.checkState(kerasLayer.getInboundLayerNames().get(0).equalsIgnoreCase("conv"), "Expected inbound name to be \"conv\" - was \"%s\"", kerasLayer.getInboundLayerNames().get(0)); - DepthwiseConvolution2D layer = kerasLayer.getDepthwiseConvolution2DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java index 4b8cc6da5..9d203a3d0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -31,54 +30,71 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSeparableConvolution2D; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasSeparableConvolution2DTest extends BaseDL4JTest { +@DisplayName("Keras Separable Convolution 2 D Test") +class KerasSeparableConvolution2DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "test_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private final int[] KERNEL_SIZE = new int[]{1, 2}; - private final int[] DILATION = new int[]{2, 2}; + + private final int[] KERNEL_SIZE = new int[] { 1, 2 }; + + private final int[] DILATION = new int[] { 2, 2 }; + private final int DEPTH_MULTIPLIER = 4; - private final int[] STRIDE = new int[]{3, 4}; + + private final int[] STRIDE = new int[] { 3, 4 }; + private final int N_OUT = 13; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration 
conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testSeparableConvolution2DLayer() throws Exception { + @DisplayName("Test Separable Convolution 2 D Layer") + void testSeparableConvolution2DLayer() throws Exception { buildSeparableConvolution2DLayer(conf1, keras1, false); buildSeparableConvolution2DLayer(conf2, keras2, false); buildSeparableConvolution2DLayer(conf2, keras2, true); } - - private void buildSeparableConvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) - throws Exception { + private void buildSeparableConvolution2DLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean withDilation) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_SEPARABLE_CONVOLUTION_2D()); Map config = new HashMap<>(); @@ -87,13 +103,11 @@ public class KerasSeparableConvolution2DTest extends BaseDL4JTest { if (kerasVersion == 1) { config.put(conf.getLAYER_FIELD_DEPTH_WISE_INIT(), INIT_KERAS); config.put(conf.getLAYER_FIELD_POINT_WISE_INIT(), INIT_KERAS); - } else { Map init = new HashMap<>(); init.put("class_name", conf.getINIT_GLOROT_NORMAL()); config.put(conf.getLAYER_FIELD_DEPTH_WISE_INIT(), init); config.put(conf.getLAYER_FIELD_POINT_WISE_INIT(), init); - } Map W_reg = new HashMap<>(); W_reg.put(conf.getREGULARIZATION_TYPE_L1(), L1_REGULARIZATION); @@ -101,20 +115,25 @@ public class KerasSeparableConvolution2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_DEPTH_WISE_REGULARIZER(), W_reg); config.put(conf.getLAYER_FIELD_DROPOUT(), DROPOUT_KERAS); config.put(conf.getLAYER_FIELD_DEPTH_MULTIPLIER(), DEPTH_MULTIPLIER); - if (kerasVersion == 1) { config.put(conf.getLAYER_FIELD_NB_ROW(), KERNEL_SIZE[0]); config.put(conf.getLAYER_FIELD_NB_COL(), KERNEL_SIZE[1]); } else { - ArrayList kernel = new ArrayList() {{ - for (int i : KERNEL_SIZE) add(i); - }}; + ArrayList kernel = new ArrayList() { + + { + for (int i : KERNEL_SIZE) add(i); + } + }; config.put(conf.getLAYER_FIELD_KERNEL_SIZE(), kernel); } if (withDilation) { - ArrayList dilation = new ArrayList() {{ - for (int i : DILATION) add(i); - }}; + ArrayList dilation = new ArrayList() { + + { + for (int i : DILATION) add(i); + } + }; config.put(conf.getLAYER_FIELD_DILATION_RATE(), dilation); } List subsampleList = new ArrayList<>(); @@ -125,8 +144,6 @@ public class KerasSeparableConvolution2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - - SeparableConvolution2D layer = new KerasSeparableConvolution2D(layerConfig).getSeparableConvolution2DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java index 6c2c2b6ea..75b5a2b54 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java +++ 
b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.Upsampling1D; @@ -26,28 +25,34 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasUpsampling1DTest extends BaseDL4JTest { +@DisplayName("Keras Upsampling 1 D Test") +class KerasUpsampling1DTest extends BaseDL4JTest { private final String LAYER_NAME = "upsampling_1D_layer"; + private int size = 4; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testUpsampling1DLayer() throws Exception { + @DisplayName("Test Upsampling 1 D Layer") + void testUpsampling1DLayer() throws Exception { buildUpsampling1DLayer(conf1, keras1); buildUpsampling1DLayer(conf2, keras2); } @@ -60,10 +65,8 @@ public class KerasUpsampling1DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Upsampling1D layer = new KerasUpsampling1D(layerConfig).getUpsampling1DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(size, layer.getSize()[0]); } - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java index 35ac1f5f8..908ed449f 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.Upsampling2D; @@ -26,35 +25,40 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import 
java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasUpsampling2DTest extends BaseDL4JTest { +@DisplayName("Keras Upsampling 2 D Test") +class KerasUpsampling2DTest extends BaseDL4JTest { private final String LAYER_NAME = "upsampling_2D_layer"; - private int[] size = new int[]{2, 2}; + + private int[] size = new int[] { 2, 2 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testUpsampling2DLayer() throws Exception { + @DisplayName("Test Upsampling 2 D Layer") + void testUpsampling2DLayer() throws Exception { buildUpsampling2DLayer(conf1, keras1); buildUpsampling2DLayer(conf2, keras2); } - private void buildUpsampling2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_UPSAMPLING_2D()); @@ -66,12 +70,9 @@ public class KerasUpsampling2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Upsampling2D layer = new KerasUpsampling2D(layerConfig).getUpsampling2DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(size[0], layer.getSize()[0]); assertEquals(size[1], layer.getSize()[1]); - } - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java index c2304a90d..1e633a929 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.Upsampling3D; @@ -26,35 +25,40 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling3D; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasUpsampling3DTest extends BaseDL4JTest { +@DisplayName("Keras Upsampling 3 D Test") +class KerasUpsampling3DTest extends BaseDL4JTest { private 
final String LAYER_NAME = "upsampling_3D_layer"; - private int[] size = new int[]{2, 2, 2}; + + private int[] size = new int[] { 2, 2, 2 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testUpsampling3DLayer() throws Exception { + @DisplayName("Test Upsampling 3 D Layer") + void testUpsampling3DLayer() throws Exception { buildUpsampling3DLayer(conf1, keras1); buildUpsampling3DLayer(conf2, keras2); } - private void buildUpsampling3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_UPSAMPLING_3D()); @@ -67,12 +71,10 @@ public class KerasUpsampling3DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Upsampling3D layer = new KerasUpsampling3D(layerConfig).getUpsampling3DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(size[0], layer.getSize()[0]); assertEquals(size[1], layer.getSize()[1]); assertEquals(size[2], layer.getSize()[2]); } - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java index 8fde00deb..1d0607dda 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.ZeroPadding1DLayer; @@ -26,30 +25,32 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding1D; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasZeroPadding1DTest extends BaseDL4JTest { +@DisplayName("Keras Zero Padding 1 D Test") +class KerasZeroPadding1DTest extends BaseDL4JTest { private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testZeroPadding1DLayer() throws Exception { + @DisplayName("Test Zero Padding 1 D Layer") + void testZeroPadding1DLayer() throws Exception { Integer keras1 = 1; buildZeroPadding1DLayer(conf1, keras1); Integer keras2 = 2; buildZeroPadding1DLayer(conf2, keras2); } - private void 
buildZeroPadding1DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_ZERO_PADDING_1D()); @@ -60,10 +61,8 @@ public class KerasZeroPadding1DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_ZERO_PADDING(), zeroPadding); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - ZeroPadding1DLayer layer = new KerasZeroPadding1D(layerConfig).getZeroPadding1DLayer(); assertEquals(layerName, layer.getLayerName()); assertEquals(zeroPadding, layer.getPadding()[0]); } - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java index 34fc87778..31d1da354 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; @@ -26,27 +25,31 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding2D; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasZeroPadding2DTest extends BaseDL4JTest { +@DisplayName("Keras Zero Padding 2 D Test") +class KerasZeroPadding2DTest extends BaseDL4JTest { private final String LAYER_NAME = "zero_padding_2D_layer"; - private final int[] ZERO_PADDING = new int[]{2, 3}; + + private final int[] ZERO_PADDING = new int[] { 2, 3 }; private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testZeroPadding2DLayer() throws Exception { + @DisplayName("Test Zero Padding 2 D Layer") + void testZeroPadding2DLayer() throws Exception { Integer keras1 = 1; buildZeroPadding2DLayer(conf1, keras1); Integer keras2 = 2; @@ -55,31 +58,29 @@ public class KerasZeroPadding2DTest extends BaseDL4JTest { buildZeroPaddingSingleDim2DLayer(conf2, keras2); } - - private void buildZeroPadding2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildZeroPadding2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_ZERO_PADDING_2D()); Map config = new HashMap<>(); 
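
For readers unfamiliar with the double-brace initializer that recurs throughout these hunks, the sketch below (illustrative only, not part of the patch; class and method names are hypothetical) shows what the idiom does and a plainer equivalent. The reformatting applied in this patch only changes whitespace around the idiom; its behavior is unchanged.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class DoubleBraceInitSketch {

    static final int[] ZERO_PADDING = { 2, 3 };

    // The idiom used in the tests: an anonymous ArrayList subclass whose
    // instance initializer block adds each padding value.
    static List<Integer> doubleBraceInit() {
        return new ArrayList<Integer>() {
            {
                for (int i : ZERO_PADDING) add(i);
            }
        };
    }

    // A plainer equivalent producing the same contents.
    static List<Integer> plainInit() {
        return new ArrayList<>(Arrays.asList(2, 3));
    }
}
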
config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); - ArrayList padding = new ArrayList() {{ - for (int i : ZERO_PADDING) add(i); - }}; + ArrayList padding = new ArrayList() { + + { + for (int i : ZERO_PADDING) add(i); + } + }; config.put(conf.getLAYER_FIELD_ZERO_PADDING(), padding); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - ZeroPaddingLayer layer = new KerasZeroPadding2D(layerConfig).getZeroPadding2DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(ZERO_PADDING[0], layer.getPadding()[0]); assertEquals(ZERO_PADDING[0], layer.getPadding()[1]); assertEquals(ZERO_PADDING[1], layer.getPadding()[2]); assertEquals(ZERO_PADDING[1], layer.getPadding()[3]); - } - private void buildZeroPaddingSingleDim2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildZeroPaddingSingleDim2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_ZERO_PADDING_2D()); Map config = new HashMap<>(); @@ -87,7 +88,6 @@ public class KerasZeroPadding2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_ZERO_PADDING(), ZERO_PADDING[0]); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - ZeroPaddingLayer layer = new KerasZeroPadding2D(layerConfig).getZeroPadding2DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(ZERO_PADDING[0], layer.getPadding()[0]); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java index 9a0c61ec9..7a1980c2a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.convolution; import org.deeplearning4j.nn.conf.layers.ZeroPadding3DLayer; @@ -26,27 +25,31 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding3D; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasZeroPadding3DTest extends BaseDL4JTest { +@DisplayName("Keras Zero Padding 3 D Test") +class KerasZeroPadding3DTest extends BaseDL4JTest { private final String LAYER_NAME = "zero_padding_3D_layer"; - private final int[] ZERO_PADDING = new int[]{2, 3, 4}; + + private final int[] ZERO_PADDING 
= new int[] { 2, 3, 4 }; private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testZeroPadding3DLayer() throws Exception { + @DisplayName("Test Zero Padding 3 D Layer") + void testZeroPadding3DLayer() throws Exception { Integer keras1 = 1; buildZeroPadding3DLayer(conf1, keras1); Integer keras2 = 2; @@ -55,20 +58,20 @@ public class KerasZeroPadding3DTest extends BaseDL4JTest { buildZeroPaddingSingleDim3DLayer(conf2, keras2); } - - private void buildZeroPadding3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildZeroPadding3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_ZERO_PADDING_3D()); Map config = new HashMap<>(); config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); - ArrayList padding = new ArrayList() {{ - for (int i : ZERO_PADDING) add(i); - }}; + ArrayList padding = new ArrayList() { + + { + for (int i : ZERO_PADDING) add(i); + } + }; config.put(conf.getLAYER_FIELD_ZERO_PADDING(), padding); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - ZeroPadding3DLayer layer = new KerasZeroPadding3D(layerConfig).getZeroPadding3DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(ZERO_PADDING[0], layer.getPadding()[0]); @@ -77,11 +80,9 @@ public class KerasZeroPadding3DTest extends BaseDL4JTest { assertEquals(ZERO_PADDING[1], layer.getPadding()[3]); assertEquals(ZERO_PADDING[2], layer.getPadding()[4]); assertEquals(ZERO_PADDING[2], layer.getPadding()[5]); - } - private void buildZeroPaddingSingleDim3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildZeroPaddingSingleDim3DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_ZERO_PADDING_3D()); Map config = new HashMap<>(); @@ -89,7 +90,6 @@ public class KerasZeroPadding3DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_ZERO_PADDING(), ZERO_PADDING[0]); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - ZeroPadding3DLayer layer = new KerasZeroPadding3D(layerConfig).getZeroPadding3DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(ZERO_PADDING[0], layer.getPadding()[0]); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java index cecb4a087..fe4d2af67 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.dropout.Dropout; @@ -29,41 +28,54 @@ import 
org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasDenseTest extends BaseDL4JTest { +@DisplayName("Keras Dense Test") +class KerasDenseTest extends BaseDL4JTest { private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "dense"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; + private final int N_OUT = 13; @Test - public void testDenseLayer() throws Exception { + @DisplayName("Test Dense Layer") + void testDenseLayer() throws Exception { buildDenseLayer(conf1, keras1); buildDenseLayer(conf2, keras2); } - private void buildDenseLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_DENSE()); @@ -85,7 +97,6 @@ public class KerasDenseTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_OUTPUT_DIM(), N_OUT); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java index d3a395bf9..d8c9a11ca 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.dropout.Dropout; @@ -26,35 +25,40 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static 
org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasDropoutTest extends BaseDL4JTest { +@DisplayName("Keras Dropout Test") +class KerasDropoutTest extends BaseDL4JTest { String LAYER_NAME = "dropout"; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testDropoutLayer() throws Exception { + @DisplayName("Test Dropout Layer") + void testDropoutLayer() throws Exception { buildDropoutLayer(conf1, keras1); buildDropoutLayer(conf2, keras2); } - private void buildDropoutLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_DROPOUT()); @@ -63,11 +67,8 @@ public class KerasDropoutTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_DROPOUT(), DROPOUT_KERAS); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - DropoutLayer layer = new KerasDropout(layerConfig).getDropoutLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); } - - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java index 20b350171..19b087696 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; @@ -25,33 +24,32 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; - +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasMaskingTest extends BaseDL4JTest { - +@DisplayName("Keras Masking Test") +class KerasMaskingTest extends BaseDL4JTest { private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testMaskingLayer() throws Exception { + @DisplayName("Test Masking Layer") + void testMaskingLayer() throws Exception { Integer keras1 = 1; buildMaskingLayer(conf1, keras1); 
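
The recurring migration pattern in these files is: org.junit.Test becomes org.junit.jupiter.api.Test, org.junit.Assert.assertEquals becomes org.junit.jupiter.api.Assertions.assertEquals, test classes and methods drop the public modifier, and @DisplayName annotations are added. Note that @DisplayName supplements rather than replaces @Test: JUnit Jupiter only executes methods that are still annotated with @Test. A minimal sketch of the target shape, with a hypothetical class and method name:

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

@DisplayName("Example Migrated Test")
class ExampleMigratedTest {

    @Test
    @DisplayName("Test Something")
    void testSomething() {
        // Jupiter test classes and methods no longer need to be public,
        // and assertions come from org.junit.jupiter.api.Assertions.
        assertEquals(4, 2 + 2);
    }
}
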
Integer keras2 = 2; buildMaskingLayer(conf2, keras2); } - private void buildMaskingLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_MASKING()); @@ -62,10 +60,7 @@ public class KerasMaskingTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_MASK_VALUE(), MASKING_VALUE); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - MaskZeroLayer layer = new KerasMasking(layerConfig).getMaskingLayer(); assertEquals(MASKING_VALUE, layer.getMaskingValue(), 0.0); } - - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java index d3283f511..121858a7b 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -28,35 +27,38 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasPermuteTest extends BaseDL4JTest { +@DisplayName("Keras Permute Test") +class KerasPermuteTest extends BaseDL4JTest { private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testPermuteLayer() throws Exception { + @DisplayName("Test Permute Layer") + void testPermuteLayer() throws Exception { buildPermuteLayer(conf1, keras1); buildPermuteLayer(conf2, keras2); } - private void buildPermuteLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { - int[] permuteIndices = new int[]{2, 1}; + int[] permuteIndices = new int[] { 2, 1 }; List permuteList = new ArrayList<>(); permuteList.add(permuteIndices[0]); permuteList.add(permuteIndices[1]); @@ -65,9 +67,7 @@ public class KerasPermuteTest extends BaseDL4JTest { assertEquals(preProcessor.getPermutationIndices()[1], permuteIndices[1]); } - private PermutePreprocessor getPermutePreProcessor(KerasLayerConfiguration conf, Integer kerasVersion, - List permuteList) - throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + private PermutePreprocessor 
getPermutePreProcessor(KerasLayerConfiguration conf, Integer kerasVersion, List permuteList) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_RESHAPE()); Map config = new HashMap<>(); @@ -77,6 +77,5 @@ public class KerasPermuteTest extends BaseDL4JTest { layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); InputType inputType = InputType.InputTypeFeedForward.recurrent(20, 10); return (PermutePreprocessor) new KerasPermute(layerConfig).getInputPreprocessor(inputType); - } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java index 72d420252..d3e567cb9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; @@ -25,34 +24,38 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasRepeatVectorTest extends BaseDL4JTest { +@DisplayName("Keras Repeat Vector Test") +class KerasRepeatVectorTest extends BaseDL4JTest { String LAYER_NAME = "repeat"; + private int REPEAT = 4; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testRepeatVectorLayer() throws Exception { + @DisplayName("Test Repeat Vector Layer") + void testRepeatVectorLayer() throws Exception { buildRepeatVectorLayer(conf1, keras1); buildRepeatVectorLayer(conf2, keras2); } - private void buildRepeatVectorLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_REPEAT()); @@ -61,11 +64,8 @@ public class KerasRepeatVectorTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_REPEAT_MULTIPLIER(), REPEAT); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - RepeatVector layer = new KerasRepeatVector(layerConfig).getRepeatVectorLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(layer.getN(), REPEAT); } - - } diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java index 1e46c90ae..acaa7adb7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -29,40 +28,45 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; - import java.util.*; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasReshapeTest extends BaseDL4JTest { +@DisplayName("Keras Reshape Test") +class KerasReshapeTest extends BaseDL4JTest { private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testReshapeLayer() throws Exception { + @DisplayName("Test Reshape Layer") + void testReshapeLayer() throws Exception { buildReshapeLayer(conf1, keras1); buildReshapeLayer(conf2, keras2); } @Test - public void testReshapeDynamicMinibatch() throws Exception { + @DisplayName("Test Reshape Dynamic Minibatch") + void testReshapeDynamicMinibatch() throws Exception { testDynamicMinibatches(conf1, keras1); testDynamicMinibatches(conf2, keras2); } private void buildReshapeLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { - int[] targetShape = new int[]{10, 5}; + int[] targetShape = new int[] { 10, 5 }; List targetShapeList = new ArrayList<>(); targetShapeList.add(targetShape[0]); targetShapeList.add(targetShape[1]); @@ -71,9 +75,7 @@ public class KerasReshapeTest extends BaseDL4JTest { assertEquals(preProcessor.getTargetShape()[1], targetShape[1]); } - private ReshapePreprocessor getReshapePreProcessor(KerasLayerConfiguration conf, Integer kerasVersion, - List targetShapeList) - throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + private ReshapePreprocessor getReshapePreProcessor(KerasLayerConfiguration conf, Integer kerasVersion, List targetShapeList) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_RESHAPE()); Map config = new HashMap<>(); @@ -85,7 +87,6 @@ public class KerasReshapeTest extends BaseDL4JTest { 
layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); InputType inputType = InputType.InputTypeFeedForward.feedForward(20); return (ReshapePreprocessor) new KerasReshape(layerConfig).getInputPreprocessor(inputType); - } private void testDynamicMinibatches(KerasLayerConfiguration conf, Integer kerasVersion) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { @@ -93,7 +94,7 @@ public class KerasReshapeTest extends BaseDL4JTest { ReshapePreprocessor preprocessor = getReshapePreProcessor(conf, kerasVersion, targetShape); INDArray r1 = preprocessor.preProcess(Nd4j.zeros(10, 20), 10, LayerWorkspaceMgr.noWorkspaces()); INDArray r2 = preprocessor.preProcess(Nd4j.zeros(5, 20), 5, LayerWorkspaceMgr.noWorkspaces()); - Assert.assertArrayEquals(r2.shape(), new long[]{5, 20}); - Assert.assertArrayEquals(r1.shape(), new long[]{10, 20}); + Assertions.assertArrayEquals(r2.shape(), new long[] { 5, 20 }); + Assertions.assertArrayEquals(r1.shape(), new long[] { 10, 20 }); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java index ccb785882..88d6e4ace 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.core; import org.deeplearning4j.nn.conf.dropout.SpatialDropout; @@ -26,35 +25,40 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasSpatialDropout2DTest extends BaseDL4JTest { +@DisplayName("Keras Spatial Dropout 2 D Test") +class KerasSpatialDropout2DTest extends BaseDL4JTest { String LAYER_NAME = "spatial_dropout_2d"; + private final double RATE_KERAS = 0.3; + private final double RATE_DL4J = 1 - RATE_KERAS; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testSpatialDropoutLayer() throws Exception { + @DisplayName("Test Spatial Dropout Layer") + void testSpatialDropoutLayer() throws Exception { buildSpatialDropoutLayer(conf1, keras1); buildSpatialDropoutLayer(conf2, keras2); } - private void buildSpatialDropoutLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_2D()); @@ -63,10 
+67,8 @@ public class KerasSpatialDropout2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_RATE(), RATE_KERAS); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - DropoutLayer layer = new KerasSpatialDropout(layerConfig).getSpatialDropoutLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(new SpatialDropout(RATE_DL4J), layer.getIDropout()); } - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java index 0d1d09dce..eac80f459 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.embeddings; import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer; @@ -26,30 +25,39 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.params.DefaultParamInitializer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; - import java.util.*; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasEmbeddingTest extends BaseDL4JTest { +@DisplayName("Keras Embedding Test") +class KerasEmbeddingTest extends BaseDL4JTest { private final String LAYER_NAME = "embedding_sequence_layer"; + private final String INIT_KERAS = "glorot_normal"; - private final int[] INPUT_SHAPE = new int[]{100, 20}; - private static final boolean[] MASK_ZERO = new boolean[]{false, true}; + + private final int[] INPUT_SHAPE = new int[] { 100, 20 }; + + private static final boolean[] MASK_ZERO = new boolean[] { false, true }; + private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testEmbeddingLayer() throws Exception { + @DisplayName("Test Embedding Layer") + void testEmbeddingLayer() throws Exception { for (boolean mz : MASK_ZERO) { buildEmbeddingLayer(conf1, keras1, mz); buildEmbeddingLayer(conf2, keras2, mz); @@ -57,17 +65,17 @@ public class KerasEmbeddingTest extends BaseDL4JTest { } @Test - public void testEmbeddingLayerSetWeightsMaskZero() throws Exception { - //GIVEN keras embedding with mask zero true + @DisplayName("Test Embedding Layer Set Weights Mask Zero") + void testEmbeddingLayerSetWeightsMaskZero() throws Exception { + // GIVEN keras embedding with mask zero true KerasEmbedding embedding = buildEmbeddingLayer(conf1, keras1, true); - //WHEN + // WHEN 
embedding.setWeights(Collections.singletonMap(conf1.getLAYER_FIELD_EMBEDDING_WEIGHTS(), Nd4j.ones(INPUT_SHAPE))); - //THEN first row is set to zeros + // THEN first row is set to zeros INDArray weights = embedding.getWeights().get(DefaultParamInitializer.WEIGHT_KEY); - assertEquals(embedding.getWeights().get(DefaultParamInitializer.WEIGHT_KEY).columns(),INPUT_SHAPE[1]); + assertEquals(embedding.getWeights().get(DefaultParamInitializer.WEIGHT_KEY).columns(), INPUT_SHAPE[1]); } - private KerasEmbedding buildEmbeddingLayer(KerasLayerConfiguration conf, Integer kerasVersion, boolean maskZero) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_EMBEDDING()); @@ -78,7 +86,6 @@ public class KerasEmbeddingTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_INPUT_DIM(), inputDim); config.put(conf.getLAYER_FIELD_INPUT_LENGTH(), inputLength); config.put(conf.getLAYER_FIELD_OUTPUT_DIM(), outputDim); - List inputShape = new ArrayList<>(INPUT_SHAPE.length); for (int i : INPUT_SHAPE) { inputShape.add(i); @@ -98,7 +105,6 @@ public class KerasEmbeddingTest extends BaseDL4JTest { KerasEmbedding kerasEmbedding = new KerasEmbedding(layerConfig, false); assertEquals(kerasEmbedding.getNumParams(), 1); assertEquals(kerasEmbedding.isZeroMasking(), maskZero); - EmbeddingSequenceLayer layer = kerasEmbedding.getEmbeddingLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); return kerasEmbedding; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/flatten/KerasFlatten3dTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/flatten/KerasFlatten3dTest.java index 7aa7cd5a4..c355cf28b 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/flatten/KerasFlatten3dTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/flatten/KerasFlatten3dTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.flatten; import org.deeplearning4j.nn.conf.InputPreProcessor; @@ -26,23 +25,24 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.GraphVertex; import org.deeplearning4j.nn.graph.vertex.impl.PreprocessorVertex; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; - import java.io.InputStream; +import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; -import static org.junit.Assert.*; - -public class KerasFlatten3dTest { - +@DisplayName("Keras Flatten 3 d Test") +class KerasFlatten3dTest { @Test - public void testFlatten3d() throws Exception { + @DisplayName("Test Flatten 3 d") + void testFlatten3d() throws Exception { ClassPathResource classPathResource = new ClassPathResource("modelimport/keras/weights/flatten_3d.hdf5"); - try(InputStream inputStream = classPathResource.getInputStream()) { + try (InputStream inputStream = classPathResource.getInputStream()) { ComputationGraph computationGraph = KerasModelImport.importKerasModelAndWeights(inputStream); assertNotNull(computationGraph); - 
assertEquals(3,computationGraph.getVertices().length); + assertEquals(3, computationGraph.getVertices().length); GraphVertex[] vertices = computationGraph.getVertices(); assertTrue(vertices[1] instanceof PreprocessorVertex); PreprocessorVertex preprocessorVertex = (PreprocessorVertex) vertices[1]; @@ -50,12 +50,11 @@ public class KerasFlatten3dTest { assertTrue(preProcessor instanceof Cnn3DToFeedForwardPreProcessor); Cnn3DToFeedForwardPreProcessor cnn3DToFeedForwardPreProcessor = (Cnn3DToFeedForwardPreProcessor) preProcessor; assertTrue(cnn3DToFeedForwardPreProcessor.isNCDHW()); - assertEquals(10,cnn3DToFeedForwardPreProcessor.getInputDepth()); - assertEquals(10,cnn3DToFeedForwardPreProcessor.getInputHeight()); - assertEquals(1,cnn3DToFeedForwardPreProcessor.getNumChannels()); - assertEquals(10,cnn3DToFeedForwardPreProcessor.getInputWidth()); + assertEquals(10, cnn3DToFeedForwardPreProcessor.getInputDepth()); + assertEquals(10, cnn3DToFeedForwardPreProcessor.getInputHeight()); + assertEquals(1, cnn3DToFeedForwardPreProcessor.getNumChannels()); + assertEquals(10, cnn3DToFeedForwardPreProcessor.getInputWidth()); System.out.println(cnn3DToFeedForwardPreProcessor); } } - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java index defc4a8ad..8dae03fe2 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.local; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -30,49 +29,64 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasLocallyConnected1DTest extends BaseDL4JTest { +@DisplayName("Keras Locally Connected 1 D Test") +class KerasLocallyConnected1DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "test_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final WeightInit INIT_DL4J = WeightInit.XAVIER; + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; + private final int KERNEL_SIZE = 2; + private final int STRIDE = 3; + private final int N_OUT = 13; + private final String BORDER_MODE_VALID = "valid"; + private 
final int VALID_PADDING = 0; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testLocallyConnected2DLayer() throws Exception { + @DisplayName("Test Locally Connected 2 D Layer") + void testLocallyConnected2DLayer() throws Exception { buildLocallyConnected2DLayer(conf1, keras1); buildLocallyConnected2DLayer(conf2, keras2); } - - private void buildLocallyConnected2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildLocallyConnected2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_2D()); Map config = new HashMap<>(); @@ -91,34 +105,34 @@ public class KerasLocallyConnected1DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_W_REGULARIZER(), W_reg); config.put(conf.getLAYER_FIELD_DROPOUT(), DROPOUT_KERAS); if (kerasVersion == 2) { - ArrayList kernel = new ArrayList() {{ - add(KERNEL_SIZE); - }}; + ArrayList kernel = new ArrayList() { + + { + add(KERNEL_SIZE); + } + }; config.put(conf.getLAYER_FIELD_FILTER_LENGTH(), kernel); } else { config.put(conf.getLAYER_FIELD_FILTER_LENGTH(), KERNEL_SIZE); } - if (kerasVersion == 2) { - ArrayList stride = new ArrayList() {{ - add(STRIDE); - }}; + ArrayList stride = new ArrayList() { + + { + add(STRIDE); + } + }; config.put(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH(), stride); } else { config.put(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH(), STRIDE); } - config.put(conf.getLAYER_FIELD_NB_FILTER(), N_OUT); config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - - KerasLocallyConnected1D kerasLocal = new KerasLocallyConnected1D(layerConfig); - // once get output type is triggered, inputshape, output shape and input depth are updated - kerasLocal.getOutputType(InputType.recurrent(3, 4)); - + kerasLocal.getOutputType(InputType.recurrent(3, 4)); LocallyConnected1D layer = kerasLocal.getLocallyConnected1DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivation().toString().toLowerCase()); assertEquals(LAYER_NAME, layer.getLayerName()); @@ -131,9 +145,7 @@ public class KerasLocallyConnected1DTest extends BaseDL4JTest { assertEquals(N_OUT, layer.getNOut()); assertEquals(ConvolutionMode.Truncate, layer.getCm()); assertEquals(VALID_PADDING, layer.getPadding()); - assertEquals(layer.getInputSize(), 4); assertEquals(layer.getNIn(), 3); } } - diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java index 8e7a49596..b42fa9063 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.local; 
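
The locally connected tests, like the other layer tests above, assemble a two-level configuration map before handing it to the Keras import layer: an outer map holding the class name, a nested config map of layer fields, and the Keras version. A minimal sketch of that structure, assuming hypothetical literal keys in place of the conf.getLAYER_FIELD_* getters used in the tests:

import java.util.HashMap;
import java.util.Map;

class LayerConfigSketch {

    // Hypothetical literal keys standing in for conf.getLAYER_FIELD_* getters.
    static Map<String, Object> buildLayerConfig() {
        Map<String, Object> config = new HashMap<>();
        config.put("name", "test_layer");      // layer-level fields go under the nested config map
        config.put("activation", "linear");

        Map<String, Object> layerConfig = new HashMap<>();
        layerConfig.put("class_name", "LocallyConnected1D");
        layerConfig.put("config", config);
        layerConfig.put("keras_version", 2);
        return layerConfig;
    }
}
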
import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -30,52 +29,68 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasLocallyConnected2DTest extends BaseDL4JTest { +@DisplayName("Keras Locally Connected 2 D Test") +class KerasLocallyConnected2DTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "test_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final WeightInit INIT_DL4J = WeightInit.XAVIER; + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private final int[] KERNEL_SIZE = new int[]{1, 2}; - private final int[] DILATION = new int[]{2, 2}; - private final int[] STRIDE = new int[]{3, 4}; + + private final int[] KERNEL_SIZE = new int[] { 1, 2 }; + + private final int[] DILATION = new int[] { 2, 2 }; + + private final int[] STRIDE = new int[] { 3, 4 }; + private final int N_OUT = 13; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testLocallyConnected2DLayer() throws Exception { + @DisplayName("Test Locally Connected 2 D Layer") + void testLocallyConnected2DLayer() throws Exception { buildLocallyConnected2DLayer(conf1, keras1); buildLocallyConnected2DLayer(conf2, keras2); } - - private void buildLocallyConnected2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) - throws Exception { + private void buildLocallyConnected2DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_2D()); Map config = new HashMap<>(); @@ -97,12 +112,14 @@ public class KerasLocallyConnected2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_NB_ROW(), KERNEL_SIZE[0]); config.put(conf.getLAYER_FIELD_NB_COL(), KERNEL_SIZE[1]); } else { - ArrayList kernel = new ArrayList() {{ - for (int i : KERNEL_SIZE) add(i); - }}; + ArrayList kernel = new ArrayList() { + + { + for (int i : KERNEL_SIZE) add(i); + } + }; config.put(conf.getLAYER_FIELD_KERNEL_SIZE(), kernel); } - List subsampleList = new ArrayList<>(); subsampleList.add(STRIDE[0]); subsampleList.add(STRIDE[1]); @@ -111,13 +128,9 @@ public class KerasLocallyConnected2DTest extends BaseDL4JTest { 
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - - KerasLocallyConnected2D kerasLocal = new KerasLocallyConnected2D(layerConfig); - // once get output type is triggered, inputshape, output shape and input depth are updated - kerasLocal.getOutputType(InputType.convolutional(4,4,3)); - + kerasLocal.getOutputType(InputType.convolutional(4, 4, 3)); LocallyConnected2D layer = kerasLocal.getLocallyConnected2DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivation().toString().toLowerCase()); assertEquals(LAYER_NAME, layer.getLayerName()); @@ -130,9 +143,7 @@ public class KerasLocallyConnected2DTest extends BaseDL4JTest { assertEquals(N_OUT, layer.getNOut()); assertEquals(ConvolutionMode.Truncate, layer.getCm()); assertArrayEquals(VALID_PADDING, layer.getPadding()); - - assertArrayEquals(layer.getInputSize(), new int[] {4, 4}); + assertArrayEquals(layer.getInputSize(), new int[] { 4, 4 }); assertEquals(layer.getNIn(), 3); } } - diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java index 14e51a1c6..05b1d1671 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.noise; import org.deeplearning4j.nn.conf.dropout.AlphaDropout; @@ -26,35 +25,40 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasAlphaDropoutTest extends BaseDL4JTest { +@DisplayName("Keras Alpha Dropout Test") +class KerasAlphaDropoutTest extends BaseDL4JTest { String LAYER_NAME = "alpha_dropout"; + private final double RATE_KERAS = 0.3; + private final double RATE_DL4J = 1 - RATE_KERAS; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testAlphaDropoutLayer() throws Exception { + @DisplayName("Test Alpha Dropout Layer") + void testAlphaDropoutLayer() throws Exception { buildAlphaDropoutLayer(conf1, keras1); buildAlphaDropoutLayer(conf2, keras2); } - private void buildAlphaDropoutLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_DROPOUT()); @@ 
-63,10 +67,8 @@ public class KerasAlphaDropoutTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_RATE(), RATE_KERAS); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - DropoutLayer layer = new KerasAlphaDropout(layerConfig).getAlphaDropoutLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(new AlphaDropout(RATE_DL4J), layer.getIDropout()); } - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java index f55b98c2b..cfde08a52 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.noise; import org.deeplearning4j.nn.conf.dropout.GaussianDropout; @@ -26,35 +25,40 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasGaussianDropoutTest extends BaseDL4JTest { +@DisplayName("Keras Gaussian Dropout Test") +class KerasGaussianDropoutTest extends BaseDL4JTest { String LAYER_NAME = "gaussian_dropout"; + private final double RATE_KERAS = 0.3; + private final double RATE_DL4J = 1 - RATE_KERAS; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testGaussianDropoutLayer() throws Exception { + @Test + @DisplayName("Test Gaussian Dropout Layer") + void testGaussianDropoutLayer() throws Exception { buildGaussianDropoutLayer(conf1, keras1); buildGaussianDropoutLayer(conf2, keras2); } - private void buildGaussianDropoutLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_DROPOUT()); @@ -63,10 +67,8 @@ public class KerasGaussianDropoutTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_RATE(), RATE_KERAS); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - DropoutLayer layer = new KerasGaussianDropout(layerConfig).getGaussianDropoutLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(new GaussianDropout(RATE_DL4J), layer.getIDropout()); } - } diff --git
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java index c4d2d642c..50fe47d00 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.noise; import org.deeplearning4j.nn.conf.dropout.GaussianNoise; @@ -26,34 +25,38 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasGaussianNoiseTest extends BaseDL4JTest { +@DisplayName("Keras Gaussian Noise Test") +class KerasGaussianNoiseTest extends BaseDL4JTest { String LAYER_NAME = "gaussian_noise"; + private final double STDDEV = 0.3; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testGaussianNoiseLayer() throws Exception { + @Test + @DisplayName("Test Gaussian Noise Layer") + void testGaussianNoiseLayer() throws Exception { buildGaussianNoiseLayer(conf1, keras1); buildGaussianNoiseLayer(conf2, keras2); } - private void buildGaussianNoiseLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_DROPOUT()); @@ -62,10 +65,8 @@ public class KerasGaussianNoiseTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_GAUSSIAN_VARIANCE(), STDDEV); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - DropoutLayer layer = new KerasGaussianNoise(layerConfig).getGaussianNoiseLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(new GaussianNoise(STDDEV), layer.getIDropout()); } - } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java index d07ac8fe1..c891cc022 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java @@ -17,7
+17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.normalization; import org.deeplearning4j.nn.conf.layers.BatchNormalization; @@ -25,41 +24,44 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; - import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasBatchNormalizationTest extends BaseDL4JTest { +@DisplayName("Keras Batch Normalization Test") +class KerasBatchNormalizationTest extends BaseDL4JTest { + public static final String PARAM_NAME_BETA = "beta"; + private final String LAYER_NAME = "batch_norm_layer"; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); - @Test - public void testBatchnormLayer() throws Exception { + @Test + @DisplayName("Test Batchnorm Layer") + void testBatchnormLayer() throws Exception { buildBatchNormalizationLayer(conf1, keras1); buildBatchNormalizationLayer(conf2, keras2); } - private void buildBatchNormalizationLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { double epsilon = 1E-5; double momentum = 0.99; - KerasBatchNormalization batchNormalization = new KerasBatchNormalization(kerasVersion); - Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_BATCHNORMALIZATION()); Map config = new HashMap<>(); @@ -72,25 +74,21 @@ public class KerasBatchNormalizationTest extends BaseDL4JTest { config.put(batchNormalization.getLAYER_FIELD_AXIS(), 3); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - BatchNormalization layer = new KerasBatchNormalization(layerConfig).getBatchNormalizationLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(epsilon, layer.getEps(), 0.0); assertEquals(momentum, layer.getDecay(), 0.0); - } @Test - public void testSetWeights() throws Exception { + @DisplayName("Test Set Weights") + void testSetWeights() throws Exception { Map weights = weightsWithoutGamma(); KerasBatchNormalization batchNormalization = new KerasBatchNormalization(keras2); - batchNormalization.setScale(false); batchNormalization.setWeights(weights); - int size = batchNormalization.getWeights().size(); assertEquals(4, size); - } private Map weightsWithoutGamma() { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java index c9ce8d8d2..8177eae46 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java
+++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.pooling; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -27,56 +26,70 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasPooling1DTest extends BaseDL4JTest { +@DisplayName("Keras Pooling 1 D Test") +class KerasPooling1DTest extends BaseDL4JTest { private final String LAYER_NAME = "test_layer"; - private final int[] KERNEL_SIZE = new int[]{2}; - private final int[] STRIDE = new int[]{4}; + + private final int[] KERNEL_SIZE = new int[] { 2 }; + + private final int[] STRIDE = new int[] { 4 }; + private final PoolingType POOLING_TYPE = PoolingType.MAX; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testPooling1DLayer() throws Exception { + @DisplayName("Test Pooling 1 D Layer") + void testPooling1DLayer() throws Exception { buildPooling1DLayer(conf1, keras1); buildPooling1DLayer(conf2, keras2); } - private void buildPooling1DLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_MAX_POOLING_1D()); Map config = new HashMap<>(); config.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); if (kerasVersion == 2) { - ArrayList kernel = new ArrayList() {{ - for (int i : KERNEL_SIZE) add(i); - }}; + ArrayList kernel = new ArrayList() { + + { + for (int i : KERNEL_SIZE) add(i); + } + }; config.put(conf.getLAYER_FIELD_POOL_1D_SIZE(), kernel); } else { config.put(conf.getLAYER_FIELD_POOL_1D_SIZE(), KERNEL_SIZE[0]); } - if (kerasVersion == 2) { - ArrayList stride = new ArrayList() {{ - for (int i : STRIDE) add(i); - }}; + ArrayList stride = new ArrayList() { + + { + for (int i : STRIDE) add(i); + } + }; config.put(conf.getLAYER_FIELD_POOL_1D_STRIDES(), stride); } else { config.put(conf.getLAYER_FIELD_POOL_1D_STRIDES(), STRIDE[0]); @@ -84,7 +97,6 @@ public class KerasPooling1DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Subsampling1DLayer layer = new KerasPooling1D(layerConfig).getSubsampling1DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(KERNEL_SIZE[0], layer.getKernelSize()[0]); diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java index 6dd8d015f..e1e35af5a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.pooling; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -27,35 +26,45 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasPooling2DTest extends BaseDL4JTest { +@DisplayName("Keras Pooling 2 D Test") +class KerasPooling2DTest extends BaseDL4JTest { private final String LAYER_NAME = "test_layer"; - private final int[] KERNEL_SIZE = new int[]{1, 2}; - private final int[] STRIDE = new int[]{3, 4}; + + private final int[] KERNEL_SIZE = new int[] { 1, 2 }; + + private final int[] STRIDE = new int[] { 3, 4 }; + private final PoolingType POOLING_TYPE = PoolingType.MAX; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testPooling2DLayer() throws Exception { + @DisplayName("Test Pooling 2 D Layer") + void testPooling2DLayer() throws Exception { buildPooling2DLayer(conf1, keras1); buildPooling2DLayer(conf2, keras2); } @@ -76,7 +85,6 @@ public class KerasPooling2DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - SubsamplingLayer layer = new KerasPooling2D(layerConfig).getSubsampling2DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertArrayEquals(KERNEL_SIZE, layer.getKernelSize()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java index f9bb4f667..24041930f 100644 --- 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.pooling; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -27,35 +26,45 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasPooling3DTest extends BaseDL4JTest { +@DisplayName("Keras Pooling 3 D Test") +class KerasPooling3DTest extends BaseDL4JTest { private final String LAYER_NAME = "pooling_3d"; - private final int[] KERNEL_SIZE = new int[]{2, 2, 2}; - private final int[] STRIDE = new int[]{1, 1, 1}; + + private final int[] KERNEL_SIZE = new int[] { 2, 2, 2 }; + + private final int[] STRIDE = new int[] { 1, 1, 1 }; + private final PoolingType POOLING_TYPE = PoolingType.MAX; + private final String BORDER_MODE_VALID = "valid"; - private final int[] VALID_PADDING = new int[]{0, 0, 0}; + + private final int[] VALID_PADDING = new int[] { 0, 0, 0 }; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testPooling3DLayer() throws Exception { + @DisplayName("Test Pooling 3 D Layer") + void testPooling3DLayer() throws Exception { buildPooling3DLayer(conf1, keras1); buildPooling3DLayer(conf2, keras2); } @@ -78,7 +87,6 @@ public class KerasPooling3DTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - Subsampling3DLayer layer = new KerasPooling3D(layerConfig).getSubsampling3DLayer(); assertEquals(LAYER_NAME, layer.getLayerName()); assertArrayEquals(KERNEL_SIZE, layer.getKernelSize()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java index e8b541b77..376d84c2e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * 
***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.recurrent; import org.deeplearning4j.nn.conf.dropout.Dropout; @@ -35,41 +34,57 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfig import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Assert; -import org.junit.Test; - +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasLSTMTest extends BaseDL4JTest { +@DisplayName("Keras LSTM Test") +class KerasLSTMTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "lstm_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; + private final int N_OUT = 13; - private Boolean[] returnSequences = new Boolean[]{true, false}; - private Boolean[] maskZero = new Boolean[]{true, false}; + private Boolean[] returnSequences = new Boolean[] { true, false }; + + private Boolean[] maskZero = new Boolean[] { true, false }; + private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testLstmLayer() throws Exception { + @DisplayName("Test Lstm Layer") + void testLstmLayer() throws Exception { for (Boolean rs : returnSequences) { buildLstmLayer(conf1, keras1, rs); buildLstmLayer(conf2, keras2, rs); @@ -85,7 +100,6 @@ public class KerasLSTMTest extends BaseDL4JTest { double lstmForgetBiasDouble = 1.0; String lstmForgetBiasString = "one"; boolean lstmUnroll = true; - Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_LSTM()); Map config = new HashMap<>(); @@ -95,7 +109,6 @@ public class KerasLSTMTest extends BaseDL4JTest { if (kerasVersion == 1) { config.put(conf.getLAYER_FIELD_INNER_INIT(), INIT_KERAS); config.put(conf.getLAYER_FIELD_INIT(), INIT_KERAS); - } else { Map init = new HashMap<>(); init.put("class_name", conf.getINIT_GLOROT_NORMAL()); @@ -107,7 +120,6 @@ public class KerasLSTMTest extends BaseDL4JTest { W_reg.put(conf.getREGULARIZATION_TYPE_L2(), L2_REGULARIZATION); config.put(conf.getLAYER_FIELD_W_REGULARIZER(), W_reg); config.put(conf.getLAYER_FIELD_RETURN_SEQUENCES(), rs); - config.put(conf.getLAYER_FIELD_DROPOUT_W(), DROPOUT_KERAS); config.put(conf.getLAYER_FIELD_DROPOUT_U(), 0.0); config.put(conf.getLAYER_FIELD_FORGET_BIAS_INIT(), lstmForgetBiasString); @@ -115,7 +127,6 @@ public class KerasLSTMTest extends BaseDL4JTest { config.put(conf.getLAYER_FIELD_UNROLL(), lstmUnroll); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); 
layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - LSTM layer; LastTimeStep lts; KerasLSTM kerasLstm = new KerasLSTM(layerConfig); @@ -137,15 +148,12 @@ public class KerasLSTMTest extends BaseDL4JTest { assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); assertEquals(lstmForgetBiasDouble, layer.getForgetGateBiasInit(), 0.0); assertEquals(N_OUT, layer.getNOut()); - } - private void buildMaskZeroLstmLayer(KerasLayerConfiguration conf, Integer kerasVersion, Boolean maskZero) - throws Exception { + private void buildMaskZeroLstmLayer(KerasLayerConfiguration conf, Integer kerasVersion, Boolean maskZero) throws Exception { String innerActivation = "hard_sigmoid"; String lstmForgetBiasString = "one"; boolean lstmUnroll = true; - Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_LSTM()); Map config = new HashMap<>(); @@ -155,7 +163,6 @@ public class KerasLSTMTest extends BaseDL4JTest { if (kerasVersion == 1) { config.put(conf.getLAYER_FIELD_INNER_INIT(), INIT_KERAS); config.put(conf.getLAYER_FIELD_INIT(), INIT_KERAS); - } else { Map init = new HashMap<>(); init.put("class_name", conf.getINIT_GLOROT_NORMAL()); @@ -166,28 +173,22 @@ public class KerasLSTMTest extends BaseDL4JTest { W_reg.put(conf.getREGULARIZATION_TYPE_L1(), L1_REGULARIZATION); W_reg.put(conf.getREGULARIZATION_TYPE_L2(), L2_REGULARIZATION); config.put(conf.getLAYER_FIELD_W_REGULARIZER(), W_reg); - config.put(conf.getLAYER_FIELD_DROPOUT_W(), DROPOUT_KERAS); config.put(conf.getLAYER_FIELD_DROPOUT_U(), 0.0); config.put(conf.getLAYER_FIELD_FORGET_BIAS_INIT(), lstmForgetBiasString); config.put(conf.getLAYER_FIELD_OUTPUT_DIM(), N_OUT); config.put(conf.getLAYER_FIELD_UNROLL(), lstmUnroll); config.put(conf.getLAYER_FIELD_RETURN_SEQUENCES(), true); - layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - layerConfig.put(conf.getLAYER_FIELD_INBOUND_NODES(), - Arrays.asList(Arrays.asList( - Arrays.asList("embedding")))); + layerConfig.put(conf.getLAYER_FIELD_INBOUND_NODES(), Arrays.asList(Arrays.asList(Arrays.asList("embedding")))); KerasEmbedding embedding = getEmbedding(maskZero); Map previousLayers = Collections.singletonMap("embedding", embedding); - KerasLSTM kerasLstm = new KerasLSTM(layerConfig, previousLayers); - Assert.assertEquals(kerasLstm.getLayer() instanceof MaskZeroLayer, maskZero); + Assertions.assertEquals(kerasLstm.getLayer() instanceof MaskZeroLayer, maskZero); } - private KerasEmbedding getEmbedding(boolean maskZero) - throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + private KerasEmbedding getEmbedding(boolean maskZero) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { KerasEmbedding embedding = new KerasEmbedding(); embedding.setZeroMasking(maskZero); return embedding; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java index 0da6edef8..9a6c24233 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java @@ -17,7 +17,6 @@ * * 
SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.recurrent; import org.deeplearning4j.nn.conf.dropout.Dropout; @@ -30,36 +29,50 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.junit.Test; - +import org.junit.jupiter.api.Test; import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasSimpleRnnTest extends BaseDL4JTest { +@DisplayName("Keras Simple Rnn Test") +class KerasSimpleRnnTest extends BaseDL4JTest { private final String ACTIVATION = "sigmoid"; + private final String LAYER_NAME = "simple_rnn_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final IWeightInit INIT_DL4J = new WeightInitXavier(); + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; + private final int N_OUT = 13; - private Boolean[] returnSequences = new Boolean[]{true, false}; + private Boolean[] returnSequences = new Boolean[] { true, false }; + private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testSimpleRnnLayer() throws Exception { + @DisplayName("Test Simple Rnn Layer") + void testSimpleRnnLayer() throws Exception { for (Boolean rs : returnSequences) { buildSimpleRnnLayer(conf1, keras1, rs); buildSimpleRnnLayer(conf2, keras2, rs); @@ -67,7 +80,6 @@ public class KerasSimpleRnnTest extends BaseDL4JTest { } private void buildSimpleRnnLayer(KerasLayerConfiguration conf, Integer kerasVersion, Boolean rs) throws Exception { - Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_LSTM()); Map config = new HashMap<>(); @@ -76,7 +88,6 @@ public class KerasSimpleRnnTest extends BaseDL4JTest { if (kerasVersion == 1) { config.put(conf.getLAYER_FIELD_INNER_INIT(), INIT_KERAS); config.put(conf.getLAYER_FIELD_INIT(), INIT_KERAS); - } else { Map init = new HashMap<>(); init.put("class_name", conf.getINIT_GLOROT_NORMAL()); @@ -88,17 +99,13 @@ public class KerasSimpleRnnTest extends BaseDL4JTest { W_reg.put(conf.getREGULARIZATION_TYPE_L2(), L2_REGULARIZATION); config.put(conf.getLAYER_FIELD_W_REGULARIZER(), W_reg); config.put(conf.getLAYER_FIELD_RETURN_SEQUENCES(), rs); - config.put(conf.getLAYER_FIELD_DROPOUT_W(), DROPOUT_KERAS); config.put(conf.getLAYER_FIELD_DROPOUT_U(), 0.0); config.put(conf.getLAYER_FIELD_OUTPUT_DIM(), N_OUT); config.put(conf.getLAYER_FIELD_UNROLL(), true); layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - - - SimpleRnn layer = rs ? (SimpleRnn) new KerasSimpleRnn(layerConfig).getSimpleRnnLayer() : - (SimpleRnn) ((LastTimeStep) new KerasSimpleRnn(layerConfig).getSimpleRnnLayer()).getUnderlying(); + SimpleRnn layer = rs ? 
(SimpleRnn) new KerasSimpleRnn(layerConfig).getSimpleRnnLayer() : (SimpleRnn) ((LastTimeStep) new KerasSimpleRnn(layerConfig).getSimpleRnnLayer()).getUnderlying(); assertEquals(ACTIVATION, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); assertEquals(INIT_DL4J, layer.getWeightInitFn()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java index ce78746d0..ed0cb7b01 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java @@ -17,7 +17,6 @@ * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ - package org.deeplearning4j.nn.modelimport.keras.layers.wrappers; import org.deeplearning4j.nn.conf.layers.LSTM; @@ -27,38 +26,53 @@ import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; - import java.util.HashMap; import java.util.Map; - -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.extension.ExtendWith; /** * @author Max Pumperla */ -public class KerasBidirectionalTest extends BaseDL4JTest { +@DisplayName("Keras Bidirectional Test") +class KerasBidirectionalTest extends BaseDL4JTest { private final String ACTIVATION_KERAS = "linear"; + private final String ACTIVATION_DL4J = "identity"; + private final String LAYER_NAME = "bidirectional_layer"; + private final String INIT_KERAS = "glorot_normal"; + private final WeightInit INIT_DL4J = WeightInit.XAVIER; + private final double L1_REGULARIZATION = 0.01; + private final double L2_REGULARIZATION = 0.02; + private final double DROPOUT_KERAS = 0.3; + private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; + private final int N_OUT = 13; + private final String mode = "sum"; private Integer keras1 = 1; + private Integer keras2 = 2; + private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test - public void testLstmLayer() throws Exception { + @DisplayName("Test Lstm Layer") + void testLstmLayer() throws Exception { buildLstmLayer(conf1, keras1); buildLstmLayer(conf2, keras2); } @@ -66,17 +80,17 @@ public class KerasBidirectionalTest extends BaseDL4JTest { private void buildLstmLayer(KerasLayerConfiguration conf, Integer kerasVersion) throws Exception { String innerActivation = "hard_sigmoid"; String lstmForgetBiasString = "one"; - Map layerConfig = new HashMap<>(); layerConfig.put(conf.getLAYER_FIELD_CLASS_NAME(), conf.getLAYER_CLASS_NAME_LSTM()); Map lstmConfig = new HashMap<>(); - lstmConfig.put(conf.getLAYER_FIELD_ACTIVATION(), ACTIVATION_KERAS); // keras linear -> dl4j identity - 
lstmConfig.put(conf.getLAYER_FIELD_INNER_ACTIVATION(), innerActivation); // keras linear -> dl4j identity + // keras linear -> dl4j identity + lstmConfig.put(conf.getLAYER_FIELD_ACTIVATION(), ACTIVATION_KERAS); + // keras linear -> dl4j identity + lstmConfig.put(conf.getLAYER_FIELD_INNER_ACTIVATION(), innerActivation); lstmConfig.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); if (kerasVersion == 1) { lstmConfig.put(conf.getLAYER_FIELD_INNER_INIT(), INIT_KERAS); lstmConfig.put(conf.getLAYER_FIELD_INIT(), INIT_KERAS); - } else { Map init = new HashMap<>(); init.put("class_name", conf.getINIT_GLOROT_NORMAL()); @@ -88,31 +102,23 @@ public class KerasBidirectionalTest extends BaseDL4JTest { W_reg.put(conf.getREGULARIZATION_TYPE_L2(), L2_REGULARIZATION); lstmConfig.put(conf.getLAYER_FIELD_W_REGULARIZER(), W_reg); lstmConfig.put(conf.getLAYER_FIELD_RETURN_SEQUENCES(), true); - lstmConfig.put(conf.getLAYER_FIELD_DROPOUT_W(), DROPOUT_KERAS); lstmConfig.put(conf.getLAYER_FIELD_DROPOUT_U(), 0.0); lstmConfig.put(conf.getLAYER_FIELD_FORGET_BIAS_INIT(), lstmForgetBiasString); lstmConfig.put(conf.getLAYER_FIELD_OUTPUT_DIM(), N_OUT); lstmConfig.put(conf.getLAYER_FIELD_UNROLL(), true); - Map innerRnnConfig = new HashMap<>(); innerRnnConfig.put("class_name", "LSTM"); innerRnnConfig.put("config", lstmConfig); - Map innerConfig = new HashMap<>(); innerConfig.put("merge_mode", mode); innerConfig.put("layer", innerRnnConfig); innerConfig.put(conf.getLAYER_FIELD_NAME(), LAYER_NAME); - layerConfig.put("config", innerConfig); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); - KerasBidirectional kerasBidirectional = new KerasBidirectional(layerConfig); Bidirectional layer = kerasBidirectional.getBidirectionalLayer(); - assertEquals(Bidirectional.Mode.ADD, layer.getMode()); - assertEquals(Activation.HARDSIGMOID.toString().toLowerCase(), - ((LSTM) kerasBidirectional.getUnderlyingRecurrentLayer()).getGateActivationFn().toString()); - + assertEquals(Activation.HARDSIGMOID.toString().toLowerCase(), ((LSTM) kerasBidirectional.getUnderlyingRecurrentLayer()).getGateActivationFn().toString()); } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml index 47d1a6432..0cd3e8071 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml @@ -57,10 +57,18 @@ org.threadly threadly ${threadly.version} + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test - junit - junit + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.mockito diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java deleted file mode 100644 index 8f47cff02..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. 
- * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.models.embeddings.reader.impl; - -import lombok.NonNull; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.deeplearning4j.clustering.vptree.VPTree; -import org.deeplearning4j.models.embeddings.WeightLookupTable; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.util.SetUtils; - -import java.util.*; - -public class TreeModelUtils extends BasicModelUtils { - protected VPTree vpTree; - - @Override - public void init(@NonNull WeightLookupTable lookupTable) { - super.init(lookupTable); - vpTree = null; - } - - protected synchronized void checkTree() { - // build new tree if it wasn't created before - if (vpTree == null) { - List points = new ArrayList<>(); - for (String word : vocabCache.words()) { - points.add(new DataPoint(vocabCache.indexOf(word), lookupTable.vector(word))); - } - vpTree = new VPTree(points); - } - } - - - /** - * This method returns nearest words for target word, based on tree structure. - * This method is recommended to use if you're going to call for nearest words multiple times. - * VPTree will be built upon firt call to this method - * - * @param label label of element we're looking nearest words to - * @param n number of nearest elements to return - * @return - */ - @Override - public Collection wordsNearest(String label, int n) { - if (!vocabCache.hasToken(label)) - return new ArrayList<>(); - - Collection collection = wordsNearest(Arrays.asList(label), new ArrayList(), n + 1); - if (collection.contains(label)) - collection.remove(label); - - return collection; - } - - @Override - public Collection wordsNearest(Collection positive, Collection negative, int top) { - - // Check every word is in the model - for (String p : SetUtils.union(new HashSet<>(positive), new HashSet<>(negative))) { - if (!vocabCache.containsWord(p)) { - return new ArrayList<>(); - } - } - - INDArray words = Nd4j.create(positive.size() + negative.size(), lookupTable.layerSize()); - int row = 0; - for (String s : positive) { - words.putRow(row++, lookupTable.vector(s)); - } - - for (String s : negative) { - words.putRow(row++, lookupTable.vector(s).mul(-1)); - } - - INDArray mean = words.isMatrix() ? 
words.mean(0) : words; - - return wordsNearest(mean, top); - } - - @Override - public Collection wordsNearest(INDArray words, int top) { - checkTree(); - words = adjustRank(words); - - List add = new ArrayList<>(); - List distances = new ArrayList<>(); - - // we need n+1 to address original datapoint removal - vpTree.search(words, top, add, distances); - - Collection ret = new ArrayList<>(); - for (DataPoint e : add) { - String word = vocabCache.wordAtIndex(e.getIndex()); - ret.add(word); - } - - return super.wordsNearest(words, top); - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtilsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtilsTest.java index f35ea7816..c0b73eddd 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtilsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtilsTest.java @@ -91,15 +91,7 @@ public class FlatModelUtilsTest extends BaseDL4JTest { assertEquals(arr1, arr2); } - @Test - @Ignore - public void testWordsNearestTree1() throws Exception { - vec.setModelUtils(new TreeModelUtils()); - Collection list = vec.wordsNearest("energy", 10); - log.info("Tree model results:"); - printWords("energy", list, vec); - } private static void printWords(String target, Collection list, WordVectors vec) { System.out.println("Words close to [" + target + "]:"); diff --git a/deeplearning4j/deeplearning4j-nn/pom.xml b/deeplearning4j/deeplearning4j-nn/pom.xml index e3e34d76b..62d092567 100644 --- a/deeplearning4j/deeplearning4j-nn/pom.xml +++ b/deeplearning4j/deeplearning4j-nn/pom.xml @@ -104,10 +104,18 @@ it.unimi.dsi fastutil ${fastutil.version} + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test - junit - junit + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.deeplearning4j diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml index 8aa886719..994364216 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml @@ -69,10 +69,18 @@ org.nd4j nd4j-parameter-server-node_2.11 ${nd4j.version} + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test - junit - junit + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.scala-lang diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml index b32a4807d..77e481c6a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml @@ -65,10 +65,18 @@ org.nd4j nd4j-parameter-server-client ${nd4j.version} + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test - junit - junit + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.deeplearning4j diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml index 214c7a271..850335cbf 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml @@ -62,8 +62,15 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} test diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml index d63d1e8b4..9e6f92e6b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml @@ -51,8 +51,16 @@ ${project.version} - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.datavec diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml index 75d8579fc..1068bda5c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml @@ -54,8 +54,16 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test ch.qos.logback diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml index a0c944ee9..3a96e8a4a 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml @@ -46,8 +46,16 @@ ${freemarker.version} - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test commons-io diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml index a2bd0595f..137d78fce 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml @@ -83,8 +83,16 @@ provided - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.deeplearning4j diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml index 4454adda8..53d11e05a 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml @@ -57,8 +57,13 @@ ${project.version} - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} org.deeplearning4j diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ApiTest.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ApiTest.java deleted file mode 100644 index 2b26b76ec..000000000 --- 
a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ApiTest.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.ui; - -import org.apache.commons.io.IOUtils; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.NDArrayIndex; -import org.nd4j.common.io.ClassPathResource; -import org.nd4j.common.resources.Resources; - -import java.io.File; -import java.util.List; - -/** - * @author Adam Gibson - */ -public class ApiTest { - - -} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java deleted file mode 100644 index b13aecaef..000000000 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java +++ /dev/null @@ -1,351 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.ui;
-
-import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.io.IOUtils;
-import org.datavec.image.loader.LFWLoader;
-import org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator;
-import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
-import org.deeplearning4j.eval.Evaluation;
-import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
-import org.deeplearning4j.models.word2vec.VocabWord;
-import org.deeplearning4j.models.word2vec.Word2Vec;
-import org.deeplearning4j.nn.conf.GradientNormalization;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
-import org.deeplearning4j.nn.conf.layers.DenseLayer;
-import org.deeplearning4j.nn.conf.layers.OutputLayer;
-import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
-import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
-import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
-import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
-import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
-import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
-import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
-import org.deeplearning4j.ui.api.UIServer;
-import org.deeplearning4j.ui.weights.ConvolutionalIterationListener;
-import org.junit.Ignore;
-import org.junit.Test;
-import org.nd4j.common.io.ClassPathResource;
-import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.dataset.SplitTestAndTrain;
-import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.learning.config.AdaGrad;
-import org.nd4j.linalg.learning.config.Nesterovs;
-import org.nd4j.linalg.lossfunctions.LossFunctions;
-import org.nd4j.common.resources.Resources;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import javax.imageio.ImageIO;
-import java.awt.image.BufferedImage;
-import java.io.File;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-import java.util.UUID;
-
-import static org.junit.Assert.fail;
-
-@Ignore
-@Slf4j
-public class ManualTests {
-
-
-    @Test
-    public void testLaunch() throws Exception {
-
-        // UiServer server = UiServer.getInstance();
-        //
-        // System.out.println("http://localhost:" + server.getPort()+ "/");
-
-        Thread.sleep(10000000000L);
-
-        new ScoreIterationListener(100);
-        fail("not implemented");
-    }
-
-
-
-    /**
-     * This test is for manual execution only, since it's here just to get a working CNN and visualize its layers
-     *
-     * @throws Exception
-     */
-    @Test
-    public void testCNNActivationsVisualization() throws Exception {
-        final int numRows = 40;
-        final int numColumns = 40;
-        int nChannels = 3;
-        int outputNum = LFWLoader.NUM_LABELS;
-        int numSamples = LFWLoader.NUM_IMAGES;
-        boolean useSubset = false;
-        int batchSize = 200;// numSamples/10;
-        int iterations = 5;
-        int splitTrainNum = (int) (batchSize * .8);
-        int seed = 123;
-        int listenerFreq = iterations / 5;
-        DataSet lfwNext;
-        SplitTestAndTrain trainTest;
-        DataSet trainInput;
-        List testInput = new ArrayList<>();
-        List testLabels = new ArrayList<>();
-
-        log.info("Load data....");
-        DataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples, new int[] {numRows, numColumns, nChannels},
-                        outputNum, useSubset, true, 1.0, new Random(seed));
-
-        log.info("Build model....");
-        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
-                        .activation(Activation.RELU).weightInit(WeightInit.XAVIER)
-                        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
-                        .updater(new AdaGrad(0.01)).weightNoise(new DropConnect(0.5)).list()
-                        .layer(0, new ConvolutionLayer.Builder(4, 4).name("cnn1").nIn(nChannels).stride(1, 1).nOut(20)
-                                        .build())
-                        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2})
-                                        .name("pool1").build())
-                        .layer(2, new ConvolutionLayer.Builder(3, 3).name("cnn2").stride(1, 1).nOut(40).build())
-                        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2})
-                                        .name("pool2").build())
-                        .layer(4, new ConvolutionLayer.Builder(3, 3).name("cnn3").stride(1, 1).nOut(60).build())
-                        .layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2})
-                                        .name("pool3").build())
-                        .layer(6, new ConvolutionLayer.Builder(2, 2).name("cnn3").stride(1, 1).nOut(80).build())
-                        .layer(7, new DenseLayer.Builder().name("ffn1").nOut(160).dropOut(0.5).build())
-                        .layer(8, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
-                                        .nOut(outputNum).activation(Activation.SOFTMAX).build())
-
-                        .setInputType(InputType.convolutional(numRows, numColumns, nChannels));
-
-        MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
-        model.init();
-
-        log.info("Train model....");
-
-        model.setListeners(new ScoreIterationListener(listenerFreq), new ConvolutionalIterationListener(listenerFreq));
-
-        while (lfw.hasNext()) {
-            lfwNext = lfw.next();
-            lfwNext.scale();
-            trainTest = lfwNext.splitTestAndTrain(splitTrainNum, new Random(seed)); // train set that is the result
-            trainInput = trainTest.getTrain(); // get feature matrix and labels for training
-            testInput.add(trainTest.getTest().getFeatures());
-            testLabels.add(trainTest.getTest().getLabels());
-            model.fit(trainInput);
-        }
-
-        log.info("Evaluate model....");
-        Evaluation eval = new Evaluation(lfw.getLabels());
-        for (int i = 0; i < testInput.size(); i++) {
-            INDArray output = model.output(testInput.get(i));
-            eval.eval(testLabels.get(i), output);
-        }
-        INDArray output = model.output(testInput.get(0));
-        eval.eval(testLabels.get(0), output);
-        log.info(eval.stats());
-        log.info("****************Example finished********************");
-
-    }
-
-    @Test(timeout = 300000)
-    public void testWord2VecPlot() throws Exception {
-        File inputFile = Resources.asFile("big/raw_sentences.txt");
-        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
-
-        TokenizerFactory t = new DefaultTokenizerFactory();
-        t.setTokenPreProcessor(new CommonPreprocessor());
-
-        Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(2).batchSize(1000).learningRate(0.025)
-                        .layerSize(100).seed(42).sampling(0).negativeSample(0).windowSize(5)
-                        .modelUtils(new BasicModelUtils()).useAdaGrad(false).iterate(iter).workers(10)
-                        .tokenizerFactory(t).build();
-
-        vec.fit();
-
-        // UiConnectionInfo connectionInfo = UiServer.getInstance().getConnectionInfo();
-
-        // vec.getLookupTable().plotVocab(100, connectionInfo);
-
-        Thread.sleep(10000000000L);
-        fail("Not implemented");
-    }
-
-    @Test
-    public void testImage() throws Exception {
-        INDArray array = Nd4j.create(11, 13);
-        for (int i = 0; i < array.rows(); i++) {
-            array.putRow(i, Nd4j.create(new double[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f,
-                            1.2f, 1.3f}));
-        }
-        writeImage(array, new File("test.png"));
-    }
-
-    private void writeImage(INDArray array, File file) {
-        // BufferedImage image = ImageLoader.toImage(array);
-
-        log.info("Array.rank(): " + array.rank());
-        log.info("Size(-1): " + array.size(-1));
-        log.info("Size(-2): " + array.size(-2));
-        BufferedImage imageToRender = new BufferedImage(array.columns(), array.rows(), BufferedImage.TYPE_BYTE_GRAY);
-        for (int x = 0; x < array.columns(); x++) {
-            for (int y = 0; y < array.rows(); y++) {
-                log.info("x: " + (x) + " y: " + y);
-                imageToRender.getRaster().setSample(x, y, 0, (int) (255 * array.getRow(y).getDouble(x)));
-            }
-        }
-
-        try {
-            ImageIO.write(imageToRender, "png", file);
-        } catch (IOException e) {
-            log.error("",e);
-        }
-
-    }
-
-    @Test
-    public void testCNNActivations2() throws Exception {
-
-        int nChannels = 1;
-        int outputNum = 10;
-        int batchSize = 64;
-        int nEpochs = 10;
-        int seed = 123;
-
-        log.info("Load data....");
-        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
-        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);
-
-        log.info("Build model....");
-        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
-                        .l2(0.0005)
-                        .weightInit(WeightInit.XAVIER)
-                        .updater(new Nesterovs(0.01, 0.9)).list()
-                        .layer(0, new ConvolutionLayer.Builder(5, 5)
-                                        //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
-                                        .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
-                        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
-                                        .stride(2, 2).build())
-                        .layer(2, new ConvolutionLayer.Builder(5, 5)
-                                        //Note that nIn need not be specified in later layers
-                                        .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
-                        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
-                                        .stride(2, 2).build())
-                        .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
-                        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
-                                        .nOut(outputNum).activation(Activation.SOFTMAX).build())
-                        .setInputType(InputType.convolutional(28, 28, nChannels));
-
-        MultiLayerConfiguration conf = builder.build();
-        MultiLayerNetwork model = new MultiLayerNetwork(conf);
-        model.init();
-        /*
-        ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
-            .averagingFrequency(1)
-            .prefetchBuffer(12)
-            .workers(2)
-            .reportScoreAfterAveraging(false)
-            .useLegacyAveraging(false)
-            .build();
-        */
-
-        log.info("Train model....");
-        model.setListeners(new ConvolutionalIterationListener(1));
-
-        //((NativeOpExecutioner) Nd4j.getExecutioner()).getLoop().setOmpNumThreads(8);
-
-        long timeX = System.currentTimeMillis();
-        // nEpochs = 2;
-        for (int i = 0; i < nEpochs; i++) {
-            long time1 = System.currentTimeMillis();
-            model.fit(mnistTrain);
-            //wrapper.fit(mnistTrain);
-            long time2 = System.currentTimeMillis();
-            log.info("*** Completed epoch {}, Time elapsed: {} ***", i, (time2 - time1));
-        }
-        long timeY = System.currentTimeMillis();
-
-        log.info("Evaluate model....");
-        Evaluation eval = new Evaluation(outputNum);
-        while (mnistTest.hasNext()) {
-            DataSet ds = mnistTest.next();
-            INDArray output = model.output(ds.getFeatures(), false);
-            eval.eval(ds.getLabels(), output);
-        }
-        log.info(eval.stats());
-        mnistTest.reset();
-
-        log.info("****************Example finished********************");
-    }
-
-    @Test
-    public void testCNNActivationsFrozen() throws Exception {
-
-        int nChannels = 1;
-        int outputNum = 10;
-        int batchSize = 64;
-        int nEpochs = 10;
-        int seed = 123;
-
-        log.info("Load data....");
-        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
-
-        log.info("Build model....");
-        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
-                        .l2(0.0005)
-                        .weightInit(WeightInit.XAVIER)
-                        .updater(new Nesterovs(0.01, 0.9)).list()
-                        .layer(0, new FrozenLayer(new ConvolutionLayer.Builder(5, 5)
-                                        //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
-                                        .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()))
-                        .layer(1, new FrozenLayer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
-                                        .stride(2, 2).build()))
-                        .layer(2, new FrozenLayer(new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()))
-                        .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
-                                        .nOut(outputNum).activation(Activation.SOFTMAX).build())
-                        .setInputType(InputType.convolutionalFlat(28, 28, nChannels));
-
-        MultiLayerConfiguration conf = builder.build();
-        MultiLayerNetwork model = new MultiLayerNetwork(conf);
-        model.init();
-
-        log.info("Train model....");
-        model.setListeners(new ConvolutionalIterationListener(1));
-
-        for (int i = 0; i < nEpochs; i++) {
-            model.fit(mnistTrain);
-        }
-    }
-}
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/weights/HistogramBinTest.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/weights/HistogramBinTest.java
index 1db17c0a2..dc9219629 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/weights/HistogramBinTest.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/weights/HistogramBinTest.java
@@ -21,21 +21,16 @@ package org.deeplearning4j.ui.weights;
 import org.deeplearning4j.ui.model.weights.HistogramBin;
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import java.math.BigDecimal;
-import static org.junit.Assert.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 public class HistogramBinTest {
-    @Before
-    public void setUp() throws Exception {
-
-    }
     @Test
     public void testGetBins() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java
index d44bb3496..d3c0e04a9 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java
@@ -32,8 +32,9 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
-import org.junit.Ignore;
-import org.junit.Test;
+
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.learning.config.Nesterovs;
@@ -42,7 +43,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
 public class TestConvolutionalListener {
     @Test
-    @Ignore //Should be run manually
+    @Disabled
     public void testUI() throws Exception {
         int nChannels = 1; // Number of input channels
diff --git a/deeplearning4j/deeplearning4j-zoo/pom.xml b/deeplearning4j/deeplearning4j-zoo/pom.xml
index 5d26ea0b2..b93606710 100644
---
a/deeplearning4j/deeplearning4j-zoo/pom.xml +++ b/deeplearning4j/deeplearning4j-zoo/pom.xml @@ -55,8 +55,16 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test ch.qos.logback diff --git a/deeplearning4j/dl4j-integration-tests/pom.xml b/deeplearning4j/dl4j-integration-tests/pom.xml index a8240c828..461d013a7 100644 --- a/deeplearning4j/dl4j-integration-tests/pom.xml +++ b/deeplearning4j/dl4j-integration-tests/pom.xml @@ -64,9 +64,17 @@ ${project.version} + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + - junit - junit + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test ch.qos.logback diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index a9687116e..625bafe6b 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -92,8 +92,14 @@ ${slf4j.version} - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.vintage + junit-vintage-engine ${junit.version} test @@ -102,8 +108,16 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.projectlombok @@ -315,7 +329,7 @@ **/*Test.java **/*TestCase.java - junit:junit + org.junit.jupiter:junit-jupiter-engine org.nd4j.linalg.cpu.nativecpu.CpuBackend @@ -364,9 +378,9 @@ maven-surefire-plugin - org.apache.maven.surefire - surefire-junit47 - 2.19.1 + org.junit + surefire-junit5 + 5.0.0-ALPHA diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-preset/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-preset/pom.xml index 2e3c63dad..6c0122349 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-preset/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-preset/pom.xml @@ -77,9 +77,17 @@ ${dependency.platform} + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + - junit - junit + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.nd4j diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml index 0b4220f8b..cdb5035aa 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml @@ -83,9 +83,17 @@ + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + - junit - junit + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test org.nd4j diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native-preset/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native-preset/pom.xml index e75b69649..9e748c353 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native-preset/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native-preset/pom.xml @@ -203,6 +203,7 @@ org.bytedeco javacpp + ${javacpp.version} ${javacpp.platform}-mingw @@ -226,6 +227,7 @@ org.bytedeco javacpp + ${javacpp.version} ${javacpp.platform}-mingw diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-added-new.txt b/nd4j/nd4j-backends/nd4j-tests/ops-added-new.txt deleted file mode 100644 index ea6bc3c53..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/ops-added-new.txt +++ /dev/null @@ -1,704 +0,0 @@ -Placeholder,input_tensor -Const,transpose/perm -Const,Pad/paddings -Const,conv2d/kernel -Const,batch_normalization/gamma -Const,batch_normalization/beta -Const,batch_normalization/moving_mean -Const,batch_normalization/moving_variance -Const,conv2d_1/kernel 
-Const,conv2d_2/kernel -Const,batch_normalization_1/gamma -Const,batch_normalization_1/beta -Const,batch_normalization_1/moving_mean -Const,batch_normalization_1/moving_variance -Const,conv2d_3/kernel -Const,batch_normalization_2/gamma -Const,batch_normalization_2/beta -Const,batch_normalization_2/moving_mean -Const,batch_normalization_2/moving_variance -Const,conv2d_4/kernel -Const,batch_normalization_3/gamma -Const,batch_normalization_3/beta -Const,batch_normalization_3/moving_mean -Const,batch_normalization_3/moving_variance -Const,conv2d_5/kernel -Const,batch_normalization_4/gamma -Const,batch_normalization_4/beta -Const,batch_normalization_4/moving_mean -Const,batch_normalization_4/moving_variance -Const,conv2d_6/kernel -Const,batch_normalization_5/gamma -Const,batch_normalization_5/beta -Const,batch_normalization_5/moving_mean -Const,batch_normalization_5/moving_variance -Const,conv2d_7/kernel -Const,batch_normalization_6/gamma -Const,batch_normalization_6/beta -Const,batch_normalization_6/moving_mean -Const,batch_normalization_6/moving_variance -Const,conv2d_8/kernel -Const,batch_normalization_7/gamma -Const,batch_normalization_7/beta -Const,batch_normalization_7/moving_mean -Const,batch_normalization_7/moving_variance -Const,conv2d_9/kernel -Const,batch_normalization_8/gamma -Const,batch_normalization_8/beta -Const,batch_normalization_8/moving_mean -Const,batch_normalization_8/moving_variance -Const,conv2d_10/kernel -Const,batch_normalization_9/gamma -Const,batch_normalization_9/beta -Const,batch_normalization_9/moving_mean -Const,batch_normalization_9/moving_variance -Const,Pad_1/paddings -Const,conv2d_11/kernel -Const,conv2d_12/kernel -Const,batch_normalization_10/gamma -Const,batch_normalization_10/beta -Const,batch_normalization_10/moving_mean -Const,batch_normalization_10/moving_variance -Const,Pad_2/paddings -Const,conv2d_13/kernel -Const,batch_normalization_11/gamma -Const,batch_normalization_11/beta -Const,batch_normalization_11/moving_mean -Const,batch_normalization_11/moving_variance -Const,conv2d_14/kernel -Const,batch_normalization_12/gamma -Const,batch_normalization_12/beta -Const,batch_normalization_12/moving_mean -Const,batch_normalization_12/moving_variance -Const,conv2d_15/kernel -Const,batch_normalization_13/gamma -Const,batch_normalization_13/beta -Const,batch_normalization_13/moving_mean -Const,batch_normalization_13/moving_variance -Const,conv2d_16/kernel -Const,batch_normalization_14/gamma -Const,batch_normalization_14/beta -Const,batch_normalization_14/moving_mean -Const,batch_normalization_14/moving_variance -Const,conv2d_17/kernel -Const,batch_normalization_15/gamma -Const,batch_normalization_15/beta -Const,batch_normalization_15/moving_mean -Const,batch_normalization_15/moving_variance -Const,conv2d_18/kernel -Const,batch_normalization_16/gamma -Const,batch_normalization_16/beta -Const,batch_normalization_16/moving_mean -Const,batch_normalization_16/moving_variance -Const,conv2d_19/kernel -Const,batch_normalization_17/gamma -Const,batch_normalization_17/beta -Const,batch_normalization_17/moving_mean -Const,batch_normalization_17/moving_variance -Const,conv2d_20/kernel -Const,batch_normalization_18/gamma -Const,batch_normalization_18/beta -Const,batch_normalization_18/moving_mean -Const,batch_normalization_18/moving_variance -Const,conv2d_21/kernel -Const,batch_normalization_19/gamma -Const,batch_normalization_19/beta -Const,batch_normalization_19/moving_mean -Const,batch_normalization_19/moving_variance -Const,conv2d_22/kernel 
-Const,batch_normalization_20/gamma -Const,batch_normalization_20/beta -Const,batch_normalization_20/moving_mean -Const,batch_normalization_20/moving_variance -Const,conv2d_23/kernel -Const,batch_normalization_21/gamma -Const,batch_normalization_21/beta -Const,batch_normalization_21/moving_mean -Const,batch_normalization_21/moving_variance -Const,Pad_3/paddings -Const,conv2d_24/kernel -Const,conv2d_25/kernel -Const,batch_normalization_22/gamma -Const,batch_normalization_22/beta -Const,batch_normalization_22/moving_mean -Const,batch_normalization_22/moving_variance -Const,Pad_4/paddings -Const,conv2d_26/kernel -Const,batch_normalization_23/gamma -Const,batch_normalization_23/beta -Const,batch_normalization_23/moving_mean -Const,batch_normalization_23/moving_variance -Const,conv2d_27/kernel -Const,batch_normalization_24/gamma -Const,batch_normalization_24/beta -Const,batch_normalization_24/moving_mean -Const,batch_normalization_24/moving_variance -Const,conv2d_28/kernel -Const,batch_normalization_25/gamma -Const,batch_normalization_25/beta -Const,batch_normalization_25/moving_mean -Const,batch_normalization_25/moving_variance -Const,conv2d_29/kernel -Const,batch_normalization_26/gamma -Const,batch_normalization_26/beta -Const,batch_normalization_26/moving_mean -Const,batch_normalization_26/moving_variance -Const,conv2d_30/kernel -Const,batch_normalization_27/gamma -Const,batch_normalization_27/beta -Const,batch_normalization_27/moving_mean -Const,batch_normalization_27/moving_variance -Const,conv2d_31/kernel -Const,batch_normalization_28/gamma -Const,batch_normalization_28/beta -Const,batch_normalization_28/moving_mean -Const,batch_normalization_28/moving_variance -Const,conv2d_32/kernel -Const,batch_normalization_29/gamma -Const,batch_normalization_29/beta -Const,batch_normalization_29/moving_mean -Const,batch_normalization_29/moving_variance -Const,conv2d_33/kernel -Const,batch_normalization_30/gamma -Const,batch_normalization_30/beta -Const,batch_normalization_30/moving_mean -Const,batch_normalization_30/moving_variance -Const,conv2d_34/kernel -Const,batch_normalization_31/gamma -Const,batch_normalization_31/beta -Const,batch_normalization_31/moving_mean -Const,batch_normalization_31/moving_variance -Const,conv2d_35/kernel -Const,batch_normalization_32/gamma -Const,batch_normalization_32/beta -Const,batch_normalization_32/moving_mean -Const,batch_normalization_32/moving_variance -Const,conv2d_36/kernel -Const,batch_normalization_33/gamma -Const,batch_normalization_33/beta -Const,batch_normalization_33/moving_mean -Const,batch_normalization_33/moving_variance -Const,conv2d_37/kernel -Const,batch_normalization_34/gamma -Const,batch_normalization_34/beta -Const,batch_normalization_34/moving_mean -Const,batch_normalization_34/moving_variance -Const,conv2d_38/kernel -Const,batch_normalization_35/gamma -Const,batch_normalization_35/beta -Const,batch_normalization_35/moving_mean -Const,batch_normalization_35/moving_variance -Const,conv2d_39/kernel -Const,batch_normalization_36/gamma -Const,batch_normalization_36/beta -Const,batch_normalization_36/moving_mean -Const,batch_normalization_36/moving_variance -Const,conv2d_40/kernel -Const,batch_normalization_37/gamma -Const,batch_normalization_37/beta -Const,batch_normalization_37/moving_mean -Const,batch_normalization_37/moving_variance -Const,conv2d_41/kernel -Const,batch_normalization_38/gamma -Const,batch_normalization_38/beta -Const,batch_normalization_38/moving_mean -Const,batch_normalization_38/moving_variance -Const,conv2d_42/kernel 
-Const,batch_normalization_39/gamma -Const,batch_normalization_39/beta -Const,batch_normalization_39/moving_mean -Const,batch_normalization_39/moving_variance -Const,Pad_5/paddings -Const,conv2d_43/kernel -Const,conv2d_44/kernel -Const,batch_normalization_40/gamma -Const,batch_normalization_40/beta -Const,batch_normalization_40/moving_mean -Const,batch_normalization_40/moving_variance -Const,Pad_6/paddings -Const,conv2d_45/kernel -Const,batch_normalization_41/gamma -Const,batch_normalization_41/beta -Const,batch_normalization_41/moving_mean -Const,batch_normalization_41/moving_variance -Const,conv2d_46/kernel -Const,batch_normalization_42/gamma -Const,batch_normalization_42/beta -Const,batch_normalization_42/moving_mean -Const,batch_normalization_42/moving_variance -Const,conv2d_47/kernel -Const,batch_normalization_43/gamma -Const,batch_normalization_43/beta -Const,batch_normalization_43/moving_mean -Const,batch_normalization_43/moving_variance -Const,conv2d_48/kernel -Const,batch_normalization_44/gamma -Const,batch_normalization_44/beta -Const,batch_normalization_44/moving_mean -Const,batch_normalization_44/moving_variance -Const,conv2d_49/kernel -Const,batch_normalization_45/gamma -Const,batch_normalization_45/beta -Const,batch_normalization_45/moving_mean -Const,batch_normalization_45/moving_variance -Const,conv2d_50/kernel -Const,batch_normalization_46/gamma -Const,batch_normalization_46/beta -Const,batch_normalization_46/moving_mean -Const,batch_normalization_46/moving_variance -Const,conv2d_51/kernel -Const,batch_normalization_47/gamma -Const,batch_normalization_47/beta -Const,batch_normalization_47/moving_mean -Const,batch_normalization_47/moving_variance -Const,conv2d_52/kernel -Const,batch_normalization_48/gamma -Const,batch_normalization_48/beta -Const,batch_normalization_48/moving_mean -Const,batch_normalization_48/moving_variance -Const,Mean/reduction_indices -Const,Reshape/shape -Const,dense/kernel -Const,dense/bias -Const,ArgMax/dimension -Transpose,transpose -Identity,conv2d/kernel/read -Identity,batch_normalization/gamma/read -Identity,batch_normalization/beta/read -Identity,batch_normalization/moving_mean/read -Identity,batch_normalization/moving_variance/read -Identity,conv2d_1/kernel/read -Identity,conv2d_2/kernel/read -Identity,batch_normalization_1/gamma/read -Identity,batch_normalization_1/beta/read -Identity,batch_normalization_1/moving_mean/read -Identity,batch_normalization_1/moving_variance/read -Identity,conv2d_3/kernel/read -Identity,batch_normalization_2/gamma/read -Identity,batch_normalization_2/beta/read -Identity,batch_normalization_2/moving_mean/read -Identity,batch_normalization_2/moving_variance/read -Identity,conv2d_4/kernel/read -Identity,batch_normalization_3/gamma/read -Identity,batch_normalization_3/beta/read -Identity,batch_normalization_3/moving_mean/read -Identity,batch_normalization_3/moving_variance/read -Identity,conv2d_5/kernel/read -Identity,batch_normalization_4/gamma/read -Identity,batch_normalization_4/beta/read -Identity,batch_normalization_4/moving_mean/read -Identity,batch_normalization_4/moving_variance/read -Identity,conv2d_6/kernel/read -Identity,batch_normalization_5/gamma/read -Identity,batch_normalization_5/beta/read -Identity,batch_normalization_5/moving_mean/read -Identity,batch_normalization_5/moving_variance/read -Identity,conv2d_7/kernel/read -Identity,batch_normalization_6/gamma/read -Identity,batch_normalization_6/beta/read -Identity,batch_normalization_6/moving_mean/read -Identity,batch_normalization_6/moving_variance/read 
-Identity,conv2d_8/kernel/read -Identity,batch_normalization_7/gamma/read -Identity,batch_normalization_7/beta/read -Identity,batch_normalization_7/moving_mean/read -Identity,batch_normalization_7/moving_variance/read -Identity,conv2d_9/kernel/read -Identity,batch_normalization_8/gamma/read -Identity,batch_normalization_8/beta/read -Identity,batch_normalization_8/moving_mean/read -Identity,batch_normalization_8/moving_variance/read -Identity,conv2d_10/kernel/read -Identity,batch_normalization_9/gamma/read -Identity,batch_normalization_9/beta/read -Identity,batch_normalization_9/moving_mean/read -Identity,batch_normalization_9/moving_variance/read -Identity,conv2d_11/kernel/read -Identity,conv2d_12/kernel/read -Identity,batch_normalization_10/gamma/read -Identity,batch_normalization_10/beta/read -Identity,batch_normalization_10/moving_mean/read -Identity,batch_normalization_10/moving_variance/read -Identity,conv2d_13/kernel/read -Identity,batch_normalization_11/gamma/read -Identity,batch_normalization_11/beta/read -Identity,batch_normalization_11/moving_mean/read -Identity,batch_normalization_11/moving_variance/read -Identity,conv2d_14/kernel/read -Identity,batch_normalization_12/gamma/read -Identity,batch_normalization_12/beta/read -Identity,batch_normalization_12/moving_mean/read -Identity,batch_normalization_12/moving_variance/read -Identity,conv2d_15/kernel/read -Identity,batch_normalization_13/gamma/read -Identity,batch_normalization_13/beta/read -Identity,batch_normalization_13/moving_mean/read -Identity,batch_normalization_13/moving_variance/read -Identity,conv2d_16/kernel/read -Identity,batch_normalization_14/gamma/read -Identity,batch_normalization_14/beta/read -Identity,batch_normalization_14/moving_mean/read -Identity,batch_normalization_14/moving_variance/read -Identity,conv2d_17/kernel/read -Identity,batch_normalization_15/gamma/read -Identity,batch_normalization_15/beta/read -Identity,batch_normalization_15/moving_mean/read -Identity,batch_normalization_15/moving_variance/read -Identity,conv2d_18/kernel/read -Identity,batch_normalization_16/gamma/read -Identity,batch_normalization_16/beta/read -Identity,batch_normalization_16/moving_mean/read -Identity,batch_normalization_16/moving_variance/read -Identity,conv2d_19/kernel/read -Identity,batch_normalization_17/gamma/read -Identity,batch_normalization_17/beta/read -Identity,batch_normalization_17/moving_mean/read -Identity,batch_normalization_17/moving_variance/read -Identity,conv2d_20/kernel/read -Identity,batch_normalization_18/gamma/read -Identity,batch_normalization_18/beta/read -Identity,batch_normalization_18/moving_mean/read -Identity,batch_normalization_18/moving_variance/read -Identity,conv2d_21/kernel/read -Identity,batch_normalization_19/gamma/read -Identity,batch_normalization_19/beta/read -Identity,batch_normalization_19/moving_mean/read -Identity,batch_normalization_19/moving_variance/read -Identity,conv2d_22/kernel/read -Identity,batch_normalization_20/gamma/read -Identity,batch_normalization_20/beta/read -Identity,batch_normalization_20/moving_mean/read -Identity,batch_normalization_20/moving_variance/read -Identity,conv2d_23/kernel/read -Identity,batch_normalization_21/gamma/read -Identity,batch_normalization_21/beta/read -Identity,batch_normalization_21/moving_mean/read -Identity,batch_normalization_21/moving_variance/read -Identity,conv2d_24/kernel/read -Identity,conv2d_25/kernel/read -Identity,batch_normalization_22/gamma/read -Identity,batch_normalization_22/beta/read 
-Identity,batch_normalization_22/moving_mean/read -Identity,batch_normalization_22/moving_variance/read -Identity,conv2d_26/kernel/read -Identity,batch_normalization_23/gamma/read -Identity,batch_normalization_23/beta/read -Identity,batch_normalization_23/moving_mean/read -Identity,batch_normalization_23/moving_variance/read -Identity,conv2d_27/kernel/read -Identity,batch_normalization_24/gamma/read -Identity,batch_normalization_24/beta/read -Identity,batch_normalization_24/moving_mean/read -Identity,batch_normalization_24/moving_variance/read -Identity,conv2d_28/kernel/read -Identity,batch_normalization_25/gamma/read -Identity,batch_normalization_25/beta/read -Identity,batch_normalization_25/moving_mean/read -Identity,batch_normalization_25/moving_variance/read -Identity,conv2d_29/kernel/read -Identity,batch_normalization_26/gamma/read -Identity,batch_normalization_26/beta/read -Identity,batch_normalization_26/moving_mean/read -Identity,batch_normalization_26/moving_variance/read -Identity,conv2d_30/kernel/read -Identity,batch_normalization_27/gamma/read -Identity,batch_normalization_27/beta/read -Identity,batch_normalization_27/moving_mean/read -Identity,batch_normalization_27/moving_variance/read -Identity,conv2d_31/kernel/read -Identity,batch_normalization_28/gamma/read -Identity,batch_normalization_28/beta/read -Identity,batch_normalization_28/moving_mean/read -Identity,batch_normalization_28/moving_variance/read -Identity,conv2d_32/kernel/read -Identity,batch_normalization_29/gamma/read -Identity,batch_normalization_29/beta/read -Identity,batch_normalization_29/moving_mean/read -Identity,batch_normalization_29/moving_variance/read -Identity,conv2d_33/kernel/read -Identity,batch_normalization_30/gamma/read -Identity,batch_normalization_30/beta/read -Identity,batch_normalization_30/moving_mean/read -Identity,batch_normalization_30/moving_variance/read -Identity,conv2d_34/kernel/read -Identity,batch_normalization_31/gamma/read -Identity,batch_normalization_31/beta/read -Identity,batch_normalization_31/moving_mean/read -Identity,batch_normalization_31/moving_variance/read -Identity,conv2d_35/kernel/read -Identity,batch_normalization_32/gamma/read -Identity,batch_normalization_32/beta/read -Identity,batch_normalization_32/moving_mean/read -Identity,batch_normalization_32/moving_variance/read -Identity,conv2d_36/kernel/read -Identity,batch_normalization_33/gamma/read -Identity,batch_normalization_33/beta/read -Identity,batch_normalization_33/moving_mean/read -Identity,batch_normalization_33/moving_variance/read -Identity,conv2d_37/kernel/read -Identity,batch_normalization_34/gamma/read -Identity,batch_normalization_34/beta/read -Identity,batch_normalization_34/moving_mean/read -Identity,batch_normalization_34/moving_variance/read -Identity,conv2d_38/kernel/read -Identity,batch_normalization_35/gamma/read -Identity,batch_normalization_35/beta/read -Identity,batch_normalization_35/moving_mean/read -Identity,batch_normalization_35/moving_variance/read -Identity,conv2d_39/kernel/read -Identity,batch_normalization_36/gamma/read -Identity,batch_normalization_36/beta/read -Identity,batch_normalization_36/moving_mean/read -Identity,batch_normalization_36/moving_variance/read -Identity,conv2d_40/kernel/read -Identity,batch_normalization_37/gamma/read -Identity,batch_normalization_37/beta/read -Identity,batch_normalization_37/moving_mean/read -Identity,batch_normalization_37/moving_variance/read -Identity,conv2d_41/kernel/read -Identity,batch_normalization_38/gamma/read 
-Identity,batch_normalization_38/beta/read -Identity,batch_normalization_38/moving_mean/read -Identity,batch_normalization_38/moving_variance/read -Identity,conv2d_42/kernel/read -Identity,batch_normalization_39/gamma/read -Identity,batch_normalization_39/beta/read -Identity,batch_normalization_39/moving_mean/read -Identity,batch_normalization_39/moving_variance/read -Identity,conv2d_43/kernel/read -Identity,conv2d_44/kernel/read -Identity,batch_normalization_40/gamma/read -Identity,batch_normalization_40/beta/read -Identity,batch_normalization_40/moving_mean/read -Identity,batch_normalization_40/moving_variance/read -Identity,conv2d_45/kernel/read -Identity,batch_normalization_41/gamma/read -Identity,batch_normalization_41/beta/read -Identity,batch_normalization_41/moving_mean/read -Identity,batch_normalization_41/moving_variance/read -Identity,conv2d_46/kernel/read -Identity,batch_normalization_42/gamma/read -Identity,batch_normalization_42/beta/read -Identity,batch_normalization_42/moving_mean/read -Identity,batch_normalization_42/moving_variance/read -Identity,conv2d_47/kernel/read -Identity,batch_normalization_43/gamma/read -Identity,batch_normalization_43/beta/read -Identity,batch_normalization_43/moving_mean/read -Identity,batch_normalization_43/moving_variance/read -Identity,conv2d_48/kernel/read -Identity,batch_normalization_44/gamma/read -Identity,batch_normalization_44/beta/read -Identity,batch_normalization_44/moving_mean/read -Identity,batch_normalization_44/moving_variance/read -Identity,conv2d_49/kernel/read -Identity,batch_normalization_45/gamma/read -Identity,batch_normalization_45/beta/read -Identity,batch_normalization_45/moving_mean/read -Identity,batch_normalization_45/moving_variance/read -Identity,conv2d_50/kernel/read -Identity,batch_normalization_46/gamma/read -Identity,batch_normalization_46/beta/read -Identity,batch_normalization_46/moving_mean/read -Identity,batch_normalization_46/moving_variance/read -Identity,conv2d_51/kernel/read -Identity,batch_normalization_47/gamma/read -Identity,batch_normalization_47/beta/read -Identity,batch_normalization_47/moving_mean/read -Identity,batch_normalization_47/moving_variance/read -Identity,conv2d_52/kernel/read -Identity,batch_normalization_48/gamma/read -Identity,batch_normalization_48/beta/read -Identity,batch_normalization_48/moving_mean/read -Identity,batch_normalization_48/moving_variance/read -Identity,dense/kernel/read -Identity,dense/bias/read -Pad,Pad -Conv2D,conv2d/Conv2D -Identity,initial_conv -MaxPool,max_pooling2d/MaxPool -Identity,initial_max_pool -FusedBatchNorm,batch_normalization/FusedBatchNorm -Relu,Relu -Conv2D,conv2d_1/Conv2D -Conv2D,conv2d_2/Conv2D -FusedBatchNorm,batch_normalization_1/FusedBatchNorm -Relu,Relu_1 -Conv2D,conv2d_3/Conv2D -FusedBatchNorm,batch_normalization_2/FusedBatchNorm -Relu,Relu_2 -Conv2D,conv2d_4/Conv2D -Add,add -FusedBatchNorm,batch_normalization_3/FusedBatchNorm -Relu,Relu_3 -Conv2D,conv2d_5/Conv2D -FusedBatchNorm,batch_normalization_4/FusedBatchNorm -Relu,Relu_4 -Conv2D,conv2d_6/Conv2D -FusedBatchNorm,batch_normalization_5/FusedBatchNorm -Relu,Relu_5 -Conv2D,conv2d_7/Conv2D -Add,add_1 -FusedBatchNorm,batch_normalization_6/FusedBatchNorm -Relu,Relu_6 -Conv2D,conv2d_8/Conv2D -FusedBatchNorm,batch_normalization_7/FusedBatchNorm -Relu,Relu_7 -Conv2D,conv2d_9/Conv2D -FusedBatchNorm,batch_normalization_8/FusedBatchNorm -Relu,Relu_8 -Conv2D,conv2d_10/Conv2D -Add,add_2 -Identity,block_layer1 -FusedBatchNorm,batch_normalization_9/FusedBatchNorm -Relu,Relu_9 -Pad,Pad_1 
-Conv2D,conv2d_12/Conv2D -Conv2D,conv2d_11/Conv2D -FusedBatchNorm,batch_normalization_10/FusedBatchNorm -Relu,Relu_10 -Pad,Pad_2 -Conv2D,conv2d_13/Conv2D -FusedBatchNorm,batch_normalization_11/FusedBatchNorm -Relu,Relu_11 -Conv2D,conv2d_14/Conv2D -Add,add_3 -FusedBatchNorm,batch_normalization_12/FusedBatchNorm -Relu,Relu_12 -Conv2D,conv2d_15/Conv2D -FusedBatchNorm,batch_normalization_13/FusedBatchNorm -Relu,Relu_13 -Conv2D,conv2d_16/Conv2D -FusedBatchNorm,batch_normalization_14/FusedBatchNorm -Relu,Relu_14 -Conv2D,conv2d_17/Conv2D -Add,add_4 -FusedBatchNorm,batch_normalization_15/FusedBatchNorm -Relu,Relu_15 -Conv2D,conv2d_18/Conv2D -FusedBatchNorm,batch_normalization_16/FusedBatchNorm -Relu,Relu_16 -Conv2D,conv2d_19/Conv2D -FusedBatchNorm,batch_normalization_17/FusedBatchNorm -Relu,Relu_17 -Conv2D,conv2d_20/Conv2D -Add,add_5 -FusedBatchNorm,batch_normalization_18/FusedBatchNorm -Relu,Relu_18 -Conv2D,conv2d_21/Conv2D -FusedBatchNorm,batch_normalization_19/FusedBatchNorm -Relu,Relu_19 -Conv2D,conv2d_22/Conv2D -FusedBatchNorm,batch_normalization_20/FusedBatchNorm -Relu,Relu_20 -Conv2D,conv2d_23/Conv2D -Add,add_6 -Identity,block_layer2 -FusedBatchNorm,batch_normalization_21/FusedBatchNorm -Relu,Relu_21 -Pad,Pad_3 -Conv2D,conv2d_25/Conv2D -Conv2D,conv2d_24/Conv2D -FusedBatchNorm,batch_normalization_22/FusedBatchNorm -Relu,Relu_22 -Pad,Pad_4 -Conv2D,conv2d_26/Conv2D -FusedBatchNorm,batch_normalization_23/FusedBatchNorm -Relu,Relu_23 -Conv2D,conv2d_27/Conv2D -Add,add_7 -FusedBatchNorm,batch_normalization_24/FusedBatchNorm -Relu,Relu_24 -Conv2D,conv2d_28/Conv2D -FusedBatchNorm,batch_normalization_25/FusedBatchNorm -Relu,Relu_25 -Conv2D,conv2d_29/Conv2D -FusedBatchNorm,batch_normalization_26/FusedBatchNorm -Relu,Relu_26 -Conv2D,conv2d_30/Conv2D -Add,add_8 -FusedBatchNorm,batch_normalization_27/FusedBatchNorm -Relu,Relu_27 -Conv2D,conv2d_31/Conv2D -FusedBatchNorm,batch_normalization_28/FusedBatchNorm -Relu,Relu_28 -Conv2D,conv2d_32/Conv2D -FusedBatchNorm,batch_normalization_29/FusedBatchNorm -Relu,Relu_29 -Conv2D,conv2d_33/Conv2D -Add,add_9 -FusedBatchNorm,batch_normalization_30/FusedBatchNorm -Relu,Relu_30 -Conv2D,conv2d_34/Conv2D -FusedBatchNorm,batch_normalization_31/FusedBatchNorm -Relu,Relu_31 -Conv2D,conv2d_35/Conv2D -FusedBatchNorm,batch_normalization_32/FusedBatchNorm -Relu,Relu_32 -Conv2D,conv2d_36/Conv2D -Add,add_10 -FusedBatchNorm,batch_normalization_33/FusedBatchNorm -Relu,Relu_33 -Conv2D,conv2d_37/Conv2D -FusedBatchNorm,batch_normalization_34/FusedBatchNorm -Relu,Relu_34 -Conv2D,conv2d_38/Conv2D -FusedBatchNorm,batch_normalization_35/FusedBatchNorm -Relu,Relu_35 -Conv2D,conv2d_39/Conv2D -Add,add_11 -FusedBatchNorm,batch_normalization_36/FusedBatchNorm -Relu,Relu_36 -Conv2D,conv2d_40/Conv2D -FusedBatchNorm,batch_normalization_37/FusedBatchNorm -Relu,Relu_37 -Conv2D,conv2d_41/Conv2D -FusedBatchNorm,batch_normalization_38/FusedBatchNorm -Relu,Relu_38 -Conv2D,conv2d_42/Conv2D -Add,add_12 -Identity,block_layer3 -FusedBatchNorm,batch_normalization_39/FusedBatchNorm -Relu,Relu_39 -Pad,Pad_5 -Conv2D,conv2d_44/Conv2D -Conv2D,conv2d_43/Conv2D -FusedBatchNorm,batch_normalization_40/FusedBatchNorm -Relu,Relu_40 -Pad,Pad_6 -Conv2D,conv2d_45/Conv2D -FusedBatchNorm,batch_normalization_41/FusedBatchNorm -Relu,Relu_41 -Conv2D,conv2d_46/Conv2D -Add,add_13 -FusedBatchNorm,batch_normalization_42/FusedBatchNorm -Relu,Relu_42 -Conv2D,conv2d_47/Conv2D -FusedBatchNorm,batch_normalization_43/FusedBatchNorm -Relu,Relu_43 -Conv2D,conv2d_48/Conv2D -FusedBatchNorm,batch_normalization_44/FusedBatchNorm 
-Relu,Relu_44 -Conv2D,conv2d_49/Conv2D -Add,add_14 -FusedBatchNorm,batch_normalization_45/FusedBatchNorm -Relu,Relu_45 -Conv2D,conv2d_50/Conv2D -FusedBatchNorm,batch_normalization_46/FusedBatchNorm -Relu,Relu_46 -Conv2D,conv2d_51/Conv2D -FusedBatchNorm,batch_normalization_47/FusedBatchNorm -Relu,Relu_47 -Conv2D,conv2d_52/Conv2D -Add,add_15 -Identity,block_layer4 -FusedBatchNorm,batch_normalization_48/FusedBatchNorm -Relu,Relu_48 -Mean,Mean -Identity,final_reduce_mean -Reshape,Reshape -MatMul,dense/MatMul -BiasAdd,dense/BiasAdd -Identity,final_dense -ArgMax,ArgMax -Softmax,softmax_tensor diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-added-old.txt b/nd4j/nd4j-backends/nd4j-tests/ops-added-old.txt deleted file mode 100644 index 04b25fc95..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/ops-added-old.txt +++ /dev/null @@ -1,3 +0,0 @@ -Const,alpha -Const,Sum/reduction_indices -Sum,Sum diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-imported-new.txt b/nd4j/nd4j-backends/nd4j-tests/ops-imported-new.txt deleted file mode 100644 index dc60391dd..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/ops-imported-new.txt +++ /dev/null @@ -1,441 +0,0 @@ -Transpose,transpose -Identity,conv2d/kernel/read -Identity,batch_normalization/gamma/read -Identity,batch_normalization/beta/read -Identity,batch_normalization/moving_mean/read -Identity,batch_normalization/moving_variance/read -Identity,conv2d_1/kernel/read -Identity,conv2d_2/kernel/read -Identity,batch_normalization_1/gamma/read -Identity,batch_normalization_1/beta/read -Identity,batch_normalization_1/moving_mean/read -Identity,batch_normalization_1/moving_variance/read -Identity,conv2d_3/kernel/read -Identity,batch_normalization_2/gamma/read -Identity,batch_normalization_2/beta/read -Identity,batch_normalization_2/moving_mean/read -Identity,batch_normalization_2/moving_variance/read -Identity,conv2d_4/kernel/read -Identity,batch_normalization_3/gamma/read -Identity,batch_normalization_3/beta/read -Identity,batch_normalization_3/moving_mean/read -Identity,batch_normalization_3/moving_variance/read -Identity,conv2d_5/kernel/read -Identity,batch_normalization_4/gamma/read -Identity,batch_normalization_4/beta/read -Identity,batch_normalization_4/moving_mean/read -Identity,batch_normalization_4/moving_variance/read -Identity,conv2d_6/kernel/read -Identity,batch_normalization_5/gamma/read -Identity,batch_normalization_5/beta/read -Identity,batch_normalization_5/moving_mean/read -Identity,batch_normalization_5/moving_variance/read -Identity,conv2d_7/kernel/read -Identity,batch_normalization_6/gamma/read -Identity,batch_normalization_6/beta/read -Identity,batch_normalization_6/moving_mean/read -Identity,batch_normalization_6/moving_variance/read -Identity,conv2d_8/kernel/read -Identity,batch_normalization_7/gamma/read -Identity,batch_normalization_7/beta/read -Identity,batch_normalization_7/moving_mean/read -Identity,batch_normalization_7/moving_variance/read -Identity,conv2d_9/kernel/read -Identity,batch_normalization_8/gamma/read -Identity,batch_normalization_8/beta/read -Identity,batch_normalization_8/moving_mean/read -Identity,batch_normalization_8/moving_variance/read -Identity,conv2d_10/kernel/read -Identity,batch_normalization_9/gamma/read -Identity,batch_normalization_9/beta/read -Identity,batch_normalization_9/moving_mean/read -Identity,batch_normalization_9/moving_variance/read -Identity,conv2d_11/kernel/read -Identity,conv2d_12/kernel/read -Identity,batch_normalization_10/gamma/read -Identity,batch_normalization_10/beta/read 
-Identity,batch_normalization_10/moving_mean/read -Identity,batch_normalization_10/moving_variance/read -Identity,conv2d_13/kernel/read -Identity,batch_normalization_11/gamma/read -Identity,batch_normalization_11/beta/read -Identity,batch_normalization_11/moving_mean/read -Identity,batch_normalization_11/moving_variance/read -Identity,conv2d_14/kernel/read -Identity,batch_normalization_12/gamma/read -Identity,batch_normalization_12/beta/read -Identity,batch_normalization_12/moving_mean/read -Identity,batch_normalization_12/moving_variance/read -Identity,conv2d_15/kernel/read -Identity,batch_normalization_13/gamma/read -Identity,batch_normalization_13/beta/read -Identity,batch_normalization_13/moving_mean/read -Identity,batch_normalization_13/moving_variance/read -Identity,conv2d_16/kernel/read -Identity,batch_normalization_14/gamma/read -Identity,batch_normalization_14/beta/read -Identity,batch_normalization_14/moving_mean/read -Identity,batch_normalization_14/moving_variance/read -Identity,conv2d_17/kernel/read -Identity,batch_normalization_15/gamma/read -Identity,batch_normalization_15/beta/read -Identity,batch_normalization_15/moving_mean/read -Identity,batch_normalization_15/moving_variance/read -Identity,conv2d_18/kernel/read -Identity,batch_normalization_16/gamma/read -Identity,batch_normalization_16/beta/read -Identity,batch_normalization_16/moving_mean/read -Identity,batch_normalization_16/moving_variance/read -Identity,conv2d_19/kernel/read -Identity,batch_normalization_17/gamma/read -Identity,batch_normalization_17/beta/read -Identity,batch_normalization_17/moving_mean/read -Identity,batch_normalization_17/moving_variance/read -Identity,conv2d_20/kernel/read -Identity,batch_normalization_18/gamma/read -Identity,batch_normalization_18/beta/read -Identity,batch_normalization_18/moving_mean/read -Identity,batch_normalization_18/moving_variance/read -Identity,conv2d_21/kernel/read -Identity,batch_normalization_19/gamma/read -Identity,batch_normalization_19/beta/read -Identity,batch_normalization_19/moving_mean/read -Identity,batch_normalization_19/moving_variance/read -Identity,conv2d_22/kernel/read -Identity,batch_normalization_20/gamma/read -Identity,batch_normalization_20/beta/read -Identity,batch_normalization_20/moving_mean/read -Identity,batch_normalization_20/moving_variance/read -Identity,conv2d_23/kernel/read -Identity,batch_normalization_21/gamma/read -Identity,batch_normalization_21/beta/read -Identity,batch_normalization_21/moving_mean/read -Identity,batch_normalization_21/moving_variance/read -Identity,conv2d_24/kernel/read -Identity,conv2d_25/kernel/read -Identity,batch_normalization_22/gamma/read -Identity,batch_normalization_22/beta/read -Identity,batch_normalization_22/moving_mean/read -Identity,batch_normalization_22/moving_variance/read -Identity,conv2d_26/kernel/read -Identity,batch_normalization_23/gamma/read -Identity,batch_normalization_23/beta/read -Identity,batch_normalization_23/moving_mean/read -Identity,batch_normalization_23/moving_variance/read -Identity,conv2d_27/kernel/read -Identity,batch_normalization_24/gamma/read -Identity,batch_normalization_24/beta/read -Identity,batch_normalization_24/moving_mean/read -Identity,batch_normalization_24/moving_variance/read -Identity,conv2d_28/kernel/read -Identity,batch_normalization_25/gamma/read -Identity,batch_normalization_25/beta/read -Identity,batch_normalization_25/moving_mean/read -Identity,batch_normalization_25/moving_variance/read -Identity,conv2d_29/kernel/read 
-Identity,batch_normalization_26/gamma/read -Identity,batch_normalization_26/beta/read -Identity,batch_normalization_26/moving_mean/read -Identity,batch_normalization_26/moving_variance/read -Identity,conv2d_30/kernel/read -Identity,batch_normalization_27/gamma/read -Identity,batch_normalization_27/beta/read -Identity,batch_normalization_27/moving_mean/read -Identity,batch_normalization_27/moving_variance/read -Identity,conv2d_31/kernel/read -Identity,batch_normalization_28/gamma/read -Identity,batch_normalization_28/beta/read -Identity,batch_normalization_28/moving_mean/read -Identity,batch_normalization_28/moving_variance/read -Identity,conv2d_32/kernel/read -Identity,batch_normalization_29/gamma/read -Identity,batch_normalization_29/beta/read -Identity,batch_normalization_29/moving_mean/read -Identity,batch_normalization_29/moving_variance/read -Identity,conv2d_33/kernel/read -Identity,batch_normalization_30/gamma/read -Identity,batch_normalization_30/beta/read -Identity,batch_normalization_30/moving_mean/read -Identity,batch_normalization_30/moving_variance/read -Identity,conv2d_34/kernel/read -Identity,batch_normalization_31/gamma/read -Identity,batch_normalization_31/beta/read -Identity,batch_normalization_31/moving_mean/read -Identity,batch_normalization_31/moving_variance/read -Identity,conv2d_35/kernel/read -Identity,batch_normalization_32/gamma/read -Identity,batch_normalization_32/beta/read -Identity,batch_normalization_32/moving_mean/read -Identity,batch_normalization_32/moving_variance/read -Identity,conv2d_36/kernel/read -Identity,batch_normalization_33/gamma/read -Identity,batch_normalization_33/beta/read -Identity,batch_normalization_33/moving_mean/read -Identity,batch_normalization_33/moving_variance/read -Identity,conv2d_37/kernel/read -Identity,batch_normalization_34/gamma/read -Identity,batch_normalization_34/beta/read -Identity,batch_normalization_34/moving_mean/read -Identity,batch_normalization_34/moving_variance/read -Identity,conv2d_38/kernel/read -Identity,batch_normalization_35/gamma/read -Identity,batch_normalization_35/beta/read -Identity,batch_normalization_35/moving_mean/read -Identity,batch_normalization_35/moving_variance/read -Identity,conv2d_39/kernel/read -Identity,batch_normalization_36/gamma/read -Identity,batch_normalization_36/beta/read -Identity,batch_normalization_36/moving_mean/read -Identity,batch_normalization_36/moving_variance/read -Identity,conv2d_40/kernel/read -Identity,batch_normalization_37/gamma/read -Identity,batch_normalization_37/beta/read -Identity,batch_normalization_37/moving_mean/read -Identity,batch_normalization_37/moving_variance/read -Identity,conv2d_41/kernel/read -Identity,batch_normalization_38/gamma/read -Identity,batch_normalization_38/beta/read -Identity,batch_normalization_38/moving_mean/read -Identity,batch_normalization_38/moving_variance/read -Identity,conv2d_42/kernel/read -Identity,batch_normalization_39/gamma/read -Identity,batch_normalization_39/beta/read -Identity,batch_normalization_39/moving_mean/read -Identity,batch_normalization_39/moving_variance/read -Identity,conv2d_43/kernel/read -Identity,conv2d_44/kernel/read -Identity,batch_normalization_40/gamma/read -Identity,batch_normalization_40/beta/read -Identity,batch_normalization_40/moving_mean/read -Identity,batch_normalization_40/moving_variance/read -Identity,conv2d_45/kernel/read -Identity,batch_normalization_41/gamma/read -Identity,batch_normalization_41/beta/read -Identity,batch_normalization_41/moving_mean/read 
-Identity,batch_normalization_41/moving_variance/read -Identity,conv2d_46/kernel/read -Identity,batch_normalization_42/gamma/read -Identity,batch_normalization_42/beta/read -Identity,batch_normalization_42/moving_mean/read -Identity,batch_normalization_42/moving_variance/read -Identity,conv2d_47/kernel/read -Identity,batch_normalization_43/gamma/read -Identity,batch_normalization_43/beta/read -Identity,batch_normalization_43/moving_mean/read -Identity,batch_normalization_43/moving_variance/read -Identity,conv2d_48/kernel/read -Identity,batch_normalization_44/gamma/read -Identity,batch_normalization_44/beta/read -Identity,batch_normalization_44/moving_mean/read -Identity,batch_normalization_44/moving_variance/read -Identity,conv2d_49/kernel/read -Identity,batch_normalization_45/gamma/read -Identity,batch_normalization_45/beta/read -Identity,batch_normalization_45/moving_mean/read -Identity,batch_normalization_45/moving_variance/read -Identity,conv2d_50/kernel/read -Identity,batch_normalization_46/gamma/read -Identity,batch_normalization_46/beta/read -Identity,batch_normalization_46/moving_mean/read -Identity,batch_normalization_46/moving_variance/read -Identity,conv2d_51/kernel/read -Identity,batch_normalization_47/gamma/read -Identity,batch_normalization_47/beta/read -Identity,batch_normalization_47/moving_mean/read -Identity,batch_normalization_47/moving_variance/read -Identity,conv2d_52/kernel/read -Identity,batch_normalization_48/gamma/read -Identity,batch_normalization_48/beta/read -Identity,batch_normalization_48/moving_mean/read -Identity,batch_normalization_48/moving_variance/read -Identity,dense/kernel/read -Identity,dense/bias/read -Pad,Pad -Conv2D,conv2d/Conv2D -Identity,initial_conv -MaxPool,max_pooling2d/MaxPool -Identity,initial_max_pool -FusedBatchNorm,batch_normalization/FusedBatchNorm -Relu,Relu -Conv2D,conv2d_1/Conv2D -Conv2D,conv2d_2/Conv2D -FusedBatchNorm,batch_normalization_1/FusedBatchNorm -Relu,Relu_1 -Conv2D,conv2d_3/Conv2D -FusedBatchNorm,batch_normalization_2/FusedBatchNorm -Relu,Relu_2 -Conv2D,conv2d_4/Conv2D -Add,add -FusedBatchNorm,batch_normalization_3/FusedBatchNorm -Relu,Relu_3 -Conv2D,conv2d_5/Conv2D -FusedBatchNorm,batch_normalization_4/FusedBatchNorm -Relu,Relu_4 -Conv2D,conv2d_6/Conv2D -FusedBatchNorm,batch_normalization_5/FusedBatchNorm -Relu,Relu_5 -Conv2D,conv2d_7/Conv2D -Add,add_1 -FusedBatchNorm,batch_normalization_6/FusedBatchNorm -Relu,Relu_6 -Conv2D,conv2d_8/Conv2D -FusedBatchNorm,batch_normalization_7/FusedBatchNorm -Relu,Relu_7 -Conv2D,conv2d_9/Conv2D -FusedBatchNorm,batch_normalization_8/FusedBatchNorm -Relu,Relu_8 -Conv2D,conv2d_10/Conv2D -Add,add_2 -Identity,block_layer1 -FusedBatchNorm,batch_normalization_9/FusedBatchNorm -Relu,Relu_9 -Pad,Pad_1 -Conv2D,conv2d_12/Conv2D -Conv2D,conv2d_11/Conv2D -FusedBatchNorm,batch_normalization_10/FusedBatchNorm -Relu,Relu_10 -Pad,Pad_2 -Conv2D,conv2d_13/Conv2D -FusedBatchNorm,batch_normalization_11/FusedBatchNorm -Relu,Relu_11 -Conv2D,conv2d_14/Conv2D -Add,add_3 -FusedBatchNorm,batch_normalization_12/FusedBatchNorm -Relu,Relu_12 -Conv2D,conv2d_15/Conv2D -FusedBatchNorm,batch_normalization_13/FusedBatchNorm -Relu,Relu_13 -Conv2D,conv2d_16/Conv2D -FusedBatchNorm,batch_normalization_14/FusedBatchNorm -Relu,Relu_14 -Conv2D,conv2d_17/Conv2D -Add,add_4 -FusedBatchNorm,batch_normalization_15/FusedBatchNorm -Relu,Relu_15 -Conv2D,conv2d_18/Conv2D -FusedBatchNorm,batch_normalization_16/FusedBatchNorm -Relu,Relu_16 -Conv2D,conv2d_19/Conv2D -FusedBatchNorm,batch_normalization_17/FusedBatchNorm -Relu,Relu_17 
-Conv2D,conv2d_20/Conv2D -Add,add_5 -FusedBatchNorm,batch_normalization_18/FusedBatchNorm -Relu,Relu_18 -Conv2D,conv2d_21/Conv2D -FusedBatchNorm,batch_normalization_19/FusedBatchNorm -Relu,Relu_19 -Conv2D,conv2d_22/Conv2D -FusedBatchNorm,batch_normalization_20/FusedBatchNorm -Relu,Relu_20 -Conv2D,conv2d_23/Conv2D -Add,add_6 -Identity,block_layer2 -FusedBatchNorm,batch_normalization_21/FusedBatchNorm -Relu,Relu_21 -Pad,Pad_3 -Conv2D,conv2d_25/Conv2D -Conv2D,conv2d_24/Conv2D -FusedBatchNorm,batch_normalization_22/FusedBatchNorm -Relu,Relu_22 -Pad,Pad_4 -Conv2D,conv2d_26/Conv2D -FusedBatchNorm,batch_normalization_23/FusedBatchNorm -Relu,Relu_23 -Conv2D,conv2d_27/Conv2D -Add,add_7 -FusedBatchNorm,batch_normalization_24/FusedBatchNorm -Relu,Relu_24 -Conv2D,conv2d_28/Conv2D -FusedBatchNorm,batch_normalization_25/FusedBatchNorm -Relu,Relu_25 -Conv2D,conv2d_29/Conv2D -FusedBatchNorm,batch_normalization_26/FusedBatchNorm -Relu,Relu_26 -Conv2D,conv2d_30/Conv2D -Add,add_8 -FusedBatchNorm,batch_normalization_27/FusedBatchNorm -Relu,Relu_27 -Conv2D,conv2d_31/Conv2D -FusedBatchNorm,batch_normalization_28/FusedBatchNorm -Relu,Relu_28 -Conv2D,conv2d_32/Conv2D -FusedBatchNorm,batch_normalization_29/FusedBatchNorm -Relu,Relu_29 -Conv2D,conv2d_33/Conv2D -Add,add_9 -FusedBatchNorm,batch_normalization_30/FusedBatchNorm -Relu,Relu_30 -Conv2D,conv2d_34/Conv2D -FusedBatchNorm,batch_normalization_31/FusedBatchNorm -Relu,Relu_31 -Conv2D,conv2d_35/Conv2D -FusedBatchNorm,batch_normalization_32/FusedBatchNorm -Relu,Relu_32 -Conv2D,conv2d_36/Conv2D -Add,add_10 -FusedBatchNorm,batch_normalization_33/FusedBatchNorm -Relu,Relu_33 -Conv2D,conv2d_37/Conv2D -FusedBatchNorm,batch_normalization_34/FusedBatchNorm -Relu,Relu_34 -Conv2D,conv2d_38/Conv2D -FusedBatchNorm,batch_normalization_35/FusedBatchNorm -Relu,Relu_35 -Conv2D,conv2d_39/Conv2D -Add,add_11 -FusedBatchNorm,batch_normalization_36/FusedBatchNorm -Relu,Relu_36 -Conv2D,conv2d_40/Conv2D -FusedBatchNorm,batch_normalization_37/FusedBatchNorm -Relu,Relu_37 -Conv2D,conv2d_41/Conv2D -FusedBatchNorm,batch_normalization_38/FusedBatchNorm -Relu,Relu_38 -Conv2D,conv2d_42/Conv2D -Add,add_12 -Identity,block_layer3 -FusedBatchNorm,batch_normalization_39/FusedBatchNorm -Relu,Relu_39 -Pad,Pad_5 -Conv2D,conv2d_44/Conv2D -Conv2D,conv2d_43/Conv2D -FusedBatchNorm,batch_normalization_40/FusedBatchNorm -Relu,Relu_40 -Pad,Pad_6 -Conv2D,conv2d_45/Conv2D -FusedBatchNorm,batch_normalization_41/FusedBatchNorm -Relu,Relu_41 -Conv2D,conv2d_46/Conv2D -Add,add_13 -FusedBatchNorm,batch_normalization_42/FusedBatchNorm -Relu,Relu_42 -Conv2D,conv2d_47/Conv2D -FusedBatchNorm,batch_normalization_43/FusedBatchNorm -Relu,Relu_43 -Conv2D,conv2d_48/Conv2D -FusedBatchNorm,batch_normalization_44/FusedBatchNorm -Relu,Relu_44 -Conv2D,conv2d_49/Conv2D -Add,add_14 -FusedBatchNorm,batch_normalization_45/FusedBatchNorm -Relu,Relu_45 -Conv2D,conv2d_50/Conv2D -FusedBatchNorm,batch_normalization_46/FusedBatchNorm -Relu,Relu_46 -Conv2D,conv2d_51/Conv2D -FusedBatchNorm,batch_normalization_47/FusedBatchNorm -Relu,Relu_47 -Conv2D,conv2d_52/Conv2D -Add,add_15 -Identity,block_layer4 -FusedBatchNorm,batch_normalization_48/FusedBatchNorm -Relu,Relu_48 -Mean,Mean -Identity,final_reduce_mean -Reshape,Reshape -MatMul,dense/MatMul -BiasAdd,dense/BiasAdd -Identity,final_dense -ArgMax,ArgMax -Softmax,softmax_tensor diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-imported-old.txt b/nd4j/nd4j-backends/nd4j-tests/ops-imported-old.txt deleted file mode 100644 index 17b33c1bb..000000000 --- 
a/nd4j/nd4j-backends/nd4j-tests/ops-imported-old.txt +++ /dev/null @@ -1 +0,0 @@ -Sum,Sum diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-removed-new.txt b/nd4j/nd4j-backends/nd4j-tests/ops-removed-new.txt deleted file mode 100644 index 0b36fa236..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/ops-removed-new.txt +++ /dev/null @@ -1,7 +0,0 @@ -Variable -Variable_1 -Variable/read -Variable_1/read -floordiv/x -floordiv/y -floordiv diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt b/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt deleted file mode 100644 index 870f040eb..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt +++ /dev/null @@ -1,3 +0,0 @@ -alpha -Sum/reduction_indices -Sum diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index f19b78df3..60452023f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -39,7 +39,7 @@ 1.4.30-M1 1.8 true - 4.13 + 5.8.0-M1 5.4.2 UTF-8 1.8 @@ -239,14 +239,10 @@ org.junit.jupiter junit-jupiter-api - ${junit-jupiter.version} - test org.junit.jupiter junit-jupiter-engine - ${junit-jupiter.version} - test @@ -271,10 +267,6 @@ samediff-import-onnx ${project.version} - - junit - junit - org.nd4j nd4j-api diff --git a/nd4j/nd4j-backends/nd4j-tests/variables-added-new.txt b/nd4j/nd4j-backends/nd4j-tests/variables-added-new.txt deleted file mode 100644 index f2634f706..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/variables-added-new.txt +++ /dev/null @@ -1,539 +0,0 @@ -transpose,transpose -conv2d/kernel/read,conv2d/kernel/read -batch_normalization/gamma/read,batch_normalization/gamma/read -batch_normalization/beta/read,batch_normalization/beta/read -batch_normalization/moving_mean/read,batch_normalization/moving_mean/read -batch_normalization/moving_variance/read,batch_normalization/moving_variance/read -conv2d_1/kernel/read,conv2d_1/kernel/read -conv2d_2/kernel/read,conv2d_2/kernel/read -batch_normalization_1/gamma/read,batch_normalization_1/gamma/read -batch_normalization_1/beta/read,batch_normalization_1/beta/read -batch_normalization_1/moving_mean/read,batch_normalization_1/moving_mean/read -batch_normalization_1/moving_variance/read,batch_normalization_1/moving_variance/read -conv2d_3/kernel/read,conv2d_3/kernel/read -batch_normalization_2/gamma/read,batch_normalization_2/gamma/read -batch_normalization_2/beta/read,batch_normalization_2/beta/read -batch_normalization_2/moving_mean/read,batch_normalization_2/moving_mean/read -batch_normalization_2/moving_variance/read,batch_normalization_2/moving_variance/read -conv2d_4/kernel/read,conv2d_4/kernel/read -batch_normalization_3/gamma/read,batch_normalization_3/gamma/read -batch_normalization_3/beta/read,batch_normalization_3/beta/read -batch_normalization_3/moving_mean/read,batch_normalization_3/moving_mean/read -batch_normalization_3/moving_variance/read,batch_normalization_3/moving_variance/read -conv2d_5/kernel/read,conv2d_5/kernel/read -batch_normalization_4/gamma/read,batch_normalization_4/gamma/read -batch_normalization_4/beta/read,batch_normalization_4/beta/read -batch_normalization_4/moving_mean/read,batch_normalization_4/moving_mean/read -batch_normalization_4/moving_variance/read,batch_normalization_4/moving_variance/read -conv2d_6/kernel/read,conv2d_6/kernel/read -batch_normalization_5/gamma/read,batch_normalization_5/gamma/read -batch_normalization_5/beta/read,batch_normalization_5/beta/read -batch_normalization_5/moving_mean/read,batch_normalization_5/moving_mean/read 
-batch_normalization_5/moving_variance/read,batch_normalization_5/moving_variance/read -conv2d_7/kernel/read,conv2d_7/kernel/read -batch_normalization_6/gamma/read,batch_normalization_6/gamma/read -batch_normalization_6/beta/read,batch_normalization_6/beta/read -batch_normalization_6/moving_mean/read,batch_normalization_6/moving_mean/read -batch_normalization_6/moving_variance/read,batch_normalization_6/moving_variance/read -conv2d_8/kernel/read,conv2d_8/kernel/read -batch_normalization_7/gamma/read,batch_normalization_7/gamma/read -batch_normalization_7/beta/read,batch_normalization_7/beta/read -batch_normalization_7/moving_mean/read,batch_normalization_7/moving_mean/read -batch_normalization_7/moving_variance/read,batch_normalization_7/moving_variance/read -conv2d_9/kernel/read,conv2d_9/kernel/read -batch_normalization_8/gamma/read,batch_normalization_8/gamma/read -batch_normalization_8/beta/read,batch_normalization_8/beta/read -batch_normalization_8/moving_mean/read,batch_normalization_8/moving_mean/read -batch_normalization_8/moving_variance/read,batch_normalization_8/moving_variance/read -conv2d_10/kernel/read,conv2d_10/kernel/read -batch_normalization_9/gamma/read,batch_normalization_9/gamma/read -batch_normalization_9/beta/read,batch_normalization_9/beta/read -batch_normalization_9/moving_mean/read,batch_normalization_9/moving_mean/read -batch_normalization_9/moving_variance/read,batch_normalization_9/moving_variance/read -conv2d_11/kernel/read,conv2d_11/kernel/read -conv2d_12/kernel/read,conv2d_12/kernel/read -batch_normalization_10/gamma/read,batch_normalization_10/gamma/read -batch_normalization_10/beta/read,batch_normalization_10/beta/read -batch_normalization_10/moving_mean/read,batch_normalization_10/moving_mean/read -batch_normalization_10/moving_variance/read,batch_normalization_10/moving_variance/read -conv2d_13/kernel/read,conv2d_13/kernel/read -batch_normalization_11/gamma/read,batch_normalization_11/gamma/read -batch_normalization_11/beta/read,batch_normalization_11/beta/read -batch_normalization_11/moving_mean/read,batch_normalization_11/moving_mean/read -batch_normalization_11/moving_variance/read,batch_normalization_11/moving_variance/read -conv2d_14/kernel/read,conv2d_14/kernel/read -batch_normalization_12/gamma/read,batch_normalization_12/gamma/read -batch_normalization_12/beta/read,batch_normalization_12/beta/read -batch_normalization_12/moving_mean/read,batch_normalization_12/moving_mean/read -batch_normalization_12/moving_variance/read,batch_normalization_12/moving_variance/read -conv2d_15/kernel/read,conv2d_15/kernel/read -batch_normalization_13/gamma/read,batch_normalization_13/gamma/read -batch_normalization_13/beta/read,batch_normalization_13/beta/read -batch_normalization_13/moving_mean/read,batch_normalization_13/moving_mean/read -batch_normalization_13/moving_variance/read,batch_normalization_13/moving_variance/read -conv2d_16/kernel/read,conv2d_16/kernel/read -batch_normalization_14/gamma/read,batch_normalization_14/gamma/read -batch_normalization_14/beta/read,batch_normalization_14/beta/read -batch_normalization_14/moving_mean/read,batch_normalization_14/moving_mean/read -batch_normalization_14/moving_variance/read,batch_normalization_14/moving_variance/read -conv2d_17/kernel/read,conv2d_17/kernel/read -batch_normalization_15/gamma/read,batch_normalization_15/gamma/read -batch_normalization_15/beta/read,batch_normalization_15/beta/read -batch_normalization_15/moving_mean/read,batch_normalization_15/moving_mean/read 
-batch_normalization_15/moving_variance/read,batch_normalization_15/moving_variance/read -conv2d_18/kernel/read,conv2d_18/kernel/read -batch_normalization_16/gamma/read,batch_normalization_16/gamma/read -batch_normalization_16/beta/read,batch_normalization_16/beta/read -batch_normalization_16/moving_mean/read,batch_normalization_16/moving_mean/read -batch_normalization_16/moving_variance/read,batch_normalization_16/moving_variance/read -conv2d_19/kernel/read,conv2d_19/kernel/read -batch_normalization_17/gamma/read,batch_normalization_17/gamma/read -batch_normalization_17/beta/read,batch_normalization_17/beta/read -batch_normalization_17/moving_mean/read,batch_normalization_17/moving_mean/read -batch_normalization_17/moving_variance/read,batch_normalization_17/moving_variance/read -conv2d_20/kernel/read,conv2d_20/kernel/read -batch_normalization_18/gamma/read,batch_normalization_18/gamma/read -batch_normalization_18/beta/read,batch_normalization_18/beta/read -batch_normalization_18/moving_mean/read,batch_normalization_18/moving_mean/read -batch_normalization_18/moving_variance/read,batch_normalization_18/moving_variance/read -conv2d_21/kernel/read,conv2d_21/kernel/read -batch_normalization_19/gamma/read,batch_normalization_19/gamma/read -batch_normalization_19/beta/read,batch_normalization_19/beta/read -batch_normalization_19/moving_mean/read,batch_normalization_19/moving_mean/read -batch_normalization_19/moving_variance/read,batch_normalization_19/moving_variance/read -conv2d_22/kernel/read,conv2d_22/kernel/read -batch_normalization_20/gamma/read,batch_normalization_20/gamma/read -batch_normalization_20/beta/read,batch_normalization_20/beta/read -batch_normalization_20/moving_mean/read,batch_normalization_20/moving_mean/read -batch_normalization_20/moving_variance/read,batch_normalization_20/moving_variance/read -conv2d_23/kernel/read,conv2d_23/kernel/read -batch_normalization_21/gamma/read,batch_normalization_21/gamma/read -batch_normalization_21/beta/read,batch_normalization_21/beta/read -batch_normalization_21/moving_mean/read,batch_normalization_21/moving_mean/read -batch_normalization_21/moving_variance/read,batch_normalization_21/moving_variance/read -conv2d_24/kernel/read,conv2d_24/kernel/read -conv2d_25/kernel/read,conv2d_25/kernel/read -batch_normalization_22/gamma/read,batch_normalization_22/gamma/read -batch_normalization_22/beta/read,batch_normalization_22/beta/read -batch_normalization_22/moving_mean/read,batch_normalization_22/moving_mean/read -batch_normalization_22/moving_variance/read,batch_normalization_22/moving_variance/read -conv2d_26/kernel/read,conv2d_26/kernel/read -batch_normalization_23/gamma/read,batch_normalization_23/gamma/read -batch_normalization_23/beta/read,batch_normalization_23/beta/read -batch_normalization_23/moving_mean/read,batch_normalization_23/moving_mean/read -batch_normalization_23/moving_variance/read,batch_normalization_23/moving_variance/read -conv2d_27/kernel/read,conv2d_27/kernel/read -batch_normalization_24/gamma/read,batch_normalization_24/gamma/read -batch_normalization_24/beta/read,batch_normalization_24/beta/read -batch_normalization_24/moving_mean/read,batch_normalization_24/moving_mean/read -batch_normalization_24/moving_variance/read,batch_normalization_24/moving_variance/read -conv2d_28/kernel/read,conv2d_28/kernel/read -batch_normalization_25/gamma/read,batch_normalization_25/gamma/read -batch_normalization_25/beta/read,batch_normalization_25/beta/read -batch_normalization_25/moving_mean/read,batch_normalization_25/moving_mean/read 
-batch_normalization_25/moving_variance/read,batch_normalization_25/moving_variance/read -conv2d_29/kernel/read,conv2d_29/kernel/read -batch_normalization_26/gamma/read,batch_normalization_26/gamma/read -batch_normalization_26/beta/read,batch_normalization_26/beta/read -batch_normalization_26/moving_mean/read,batch_normalization_26/moving_mean/read -batch_normalization_26/moving_variance/read,batch_normalization_26/moving_variance/read -conv2d_30/kernel/read,conv2d_30/kernel/read -batch_normalization_27/gamma/read,batch_normalization_27/gamma/read -batch_normalization_27/beta/read,batch_normalization_27/beta/read -batch_normalization_27/moving_mean/read,batch_normalization_27/moving_mean/read -batch_normalization_27/moving_variance/read,batch_normalization_27/moving_variance/read -conv2d_31/kernel/read,conv2d_31/kernel/read -batch_normalization_28/gamma/read,batch_normalization_28/gamma/read -batch_normalization_28/beta/read,batch_normalization_28/beta/read -batch_normalization_28/moving_mean/read,batch_normalization_28/moving_mean/read -batch_normalization_28/moving_variance/read,batch_normalization_28/moving_variance/read -conv2d_32/kernel/read,conv2d_32/kernel/read -batch_normalization_29/gamma/read,batch_normalization_29/gamma/read -batch_normalization_29/beta/read,batch_normalization_29/beta/read -batch_normalization_29/moving_mean/read,batch_normalization_29/moving_mean/read -batch_normalization_29/moving_variance/read,batch_normalization_29/moving_variance/read -conv2d_33/kernel/read,conv2d_33/kernel/read -batch_normalization_30/gamma/read,batch_normalization_30/gamma/read -batch_normalization_30/beta/read,batch_normalization_30/beta/read -batch_normalization_30/moving_mean/read,batch_normalization_30/moving_mean/read -batch_normalization_30/moving_variance/read,batch_normalization_30/moving_variance/read -conv2d_34/kernel/read,conv2d_34/kernel/read -batch_normalization_31/gamma/read,batch_normalization_31/gamma/read -batch_normalization_31/beta/read,batch_normalization_31/beta/read -batch_normalization_31/moving_mean/read,batch_normalization_31/moving_mean/read -batch_normalization_31/moving_variance/read,batch_normalization_31/moving_variance/read -conv2d_35/kernel/read,conv2d_35/kernel/read -batch_normalization_32/gamma/read,batch_normalization_32/gamma/read -batch_normalization_32/beta/read,batch_normalization_32/beta/read -batch_normalization_32/moving_mean/read,batch_normalization_32/moving_mean/read -batch_normalization_32/moving_variance/read,batch_normalization_32/moving_variance/read -conv2d_36/kernel/read,conv2d_36/kernel/read -batch_normalization_33/gamma/read,batch_normalization_33/gamma/read -batch_normalization_33/beta/read,batch_normalization_33/beta/read -batch_normalization_33/moving_mean/read,batch_normalization_33/moving_mean/read -batch_normalization_33/moving_variance/read,batch_normalization_33/moving_variance/read -conv2d_37/kernel/read,conv2d_37/kernel/read -batch_normalization_34/gamma/read,batch_normalization_34/gamma/read -batch_normalization_34/beta/read,batch_normalization_34/beta/read -batch_normalization_34/moving_mean/read,batch_normalization_34/moving_mean/read -batch_normalization_34/moving_variance/read,batch_normalization_34/moving_variance/read -conv2d_38/kernel/read,conv2d_38/kernel/read -batch_normalization_35/gamma/read,batch_normalization_35/gamma/read -batch_normalization_35/beta/read,batch_normalization_35/beta/read -batch_normalization_35/moving_mean/read,batch_normalization_35/moving_mean/read 
-batch_normalization_35/moving_variance/read,batch_normalization_35/moving_variance/read -conv2d_39/kernel/read,conv2d_39/kernel/read -batch_normalization_36/gamma/read,batch_normalization_36/gamma/read -batch_normalization_36/beta/read,batch_normalization_36/beta/read -batch_normalization_36/moving_mean/read,batch_normalization_36/moving_mean/read -batch_normalization_36/moving_variance/read,batch_normalization_36/moving_variance/read -conv2d_40/kernel/read,conv2d_40/kernel/read -batch_normalization_37/gamma/read,batch_normalization_37/gamma/read -batch_normalization_37/beta/read,batch_normalization_37/beta/read -batch_normalization_37/moving_mean/read,batch_normalization_37/moving_mean/read -batch_normalization_37/moving_variance/read,batch_normalization_37/moving_variance/read -conv2d_41/kernel/read,conv2d_41/kernel/read -batch_normalization_38/gamma/read,batch_normalization_38/gamma/read -batch_normalization_38/beta/read,batch_normalization_38/beta/read -batch_normalization_38/moving_mean/read,batch_normalization_38/moving_mean/read -batch_normalization_38/moving_variance/read,batch_normalization_38/moving_variance/read -conv2d_42/kernel/read,conv2d_42/kernel/read -batch_normalization_39/gamma/read,batch_normalization_39/gamma/read -batch_normalization_39/beta/read,batch_normalization_39/beta/read -batch_normalization_39/moving_mean/read,batch_normalization_39/moving_mean/read -batch_normalization_39/moving_variance/read,batch_normalization_39/moving_variance/read -conv2d_43/kernel/read,conv2d_43/kernel/read -conv2d_44/kernel/read,conv2d_44/kernel/read -batch_normalization_40/gamma/read,batch_normalization_40/gamma/read -batch_normalization_40/beta/read,batch_normalization_40/beta/read -batch_normalization_40/moving_mean/read,batch_normalization_40/moving_mean/read -batch_normalization_40/moving_variance/read,batch_normalization_40/moving_variance/read -conv2d_45/kernel/read,conv2d_45/kernel/read -batch_normalization_41/gamma/read,batch_normalization_41/gamma/read -batch_normalization_41/beta/read,batch_normalization_41/beta/read -batch_normalization_41/moving_mean/read,batch_normalization_41/moving_mean/read -batch_normalization_41/moving_variance/read,batch_normalization_41/moving_variance/read -conv2d_46/kernel/read,conv2d_46/kernel/read -batch_normalization_42/gamma/read,batch_normalization_42/gamma/read -batch_normalization_42/beta/read,batch_normalization_42/beta/read -batch_normalization_42/moving_mean/read,batch_normalization_42/moving_mean/read -batch_normalization_42/moving_variance/read,batch_normalization_42/moving_variance/read -conv2d_47/kernel/read,conv2d_47/kernel/read -batch_normalization_43/gamma/read,batch_normalization_43/gamma/read -batch_normalization_43/beta/read,batch_normalization_43/beta/read -batch_normalization_43/moving_mean/read,batch_normalization_43/moving_mean/read -batch_normalization_43/moving_variance/read,batch_normalization_43/moving_variance/read -conv2d_48/kernel/read,conv2d_48/kernel/read -batch_normalization_44/gamma/read,batch_normalization_44/gamma/read -batch_normalization_44/beta/read,batch_normalization_44/beta/read -batch_normalization_44/moving_mean/read,batch_normalization_44/moving_mean/read -batch_normalization_44/moving_variance/read,batch_normalization_44/moving_variance/read -conv2d_49/kernel/read,conv2d_49/kernel/read -batch_normalization_45/gamma/read,batch_normalization_45/gamma/read -batch_normalization_45/beta/read,batch_normalization_45/beta/read -batch_normalization_45/moving_mean/read,batch_normalization_45/moving_mean/read 
-batch_normalization_45/moving_variance/read,batch_normalization_45/moving_variance/read -conv2d_50/kernel/read,conv2d_50/kernel/read -batch_normalization_46/gamma/read,batch_normalization_46/gamma/read -batch_normalization_46/beta/read,batch_normalization_46/beta/read -batch_normalization_46/moving_mean/read,batch_normalization_46/moving_mean/read -batch_normalization_46/moving_variance/read,batch_normalization_46/moving_variance/read -conv2d_51/kernel/read,conv2d_51/kernel/read -batch_normalization_47/gamma/read,batch_normalization_47/gamma/read -batch_normalization_47/beta/read,batch_normalization_47/beta/read -batch_normalization_47/moving_mean/read,batch_normalization_47/moving_mean/read -batch_normalization_47/moving_variance/read,batch_normalization_47/moving_variance/read -conv2d_52/kernel/read,conv2d_52/kernel/read -batch_normalization_48/gamma/read,batch_normalization_48/gamma/read -batch_normalization_48/beta/read,batch_normalization_48/beta/read -batch_normalization_48/moving_mean/read,batch_normalization_48/moving_mean/read -batch_normalization_48/moving_variance/read,batch_normalization_48/moving_variance/read -dense/kernel/read,dense/kernel/read -dense/bias/read,dense/bias/read -Pad,Pad -conv2d/Conv2D,conv2d/Conv2D -initial_conv,initial_conv -max_pooling2d/MaxPool,max_pooling2d/MaxPool -initial_max_pool,initial_max_pool -batch_normalization/FusedBatchNorm,batch_normalization/FusedBatchNorm -batch_normalization/FusedBatchNorm:1,batch_normalization/FusedBatchNorm -batch_normalization/FusedBatchNorm:2,batch_normalization/FusedBatchNorm -Relu,Relu -conv2d_1/Conv2D,conv2d_1/Conv2D -conv2d_2/Conv2D,conv2d_2/Conv2D -batch_normalization_1/FusedBatchNorm,batch_normalization_1/FusedBatchNorm -batch_normalization_1/FusedBatchNorm:1,batch_normalization_1/FusedBatchNorm -batch_normalization_1/FusedBatchNorm:2,batch_normalization_1/FusedBatchNorm -Relu_1,Relu_1 -conv2d_3/Conv2D,conv2d_3/Conv2D -batch_normalization_2/FusedBatchNorm,batch_normalization_2/FusedBatchNorm -batch_normalization_2/FusedBatchNorm:1,batch_normalization_2/FusedBatchNorm -batch_normalization_2/FusedBatchNorm:2,batch_normalization_2/FusedBatchNorm -Relu_2,Relu_2 -conv2d_4/Conv2D,conv2d_4/Conv2D -add,add -batch_normalization_3/FusedBatchNorm,batch_normalization_3/FusedBatchNorm -batch_normalization_3/FusedBatchNorm:1,batch_normalization_3/FusedBatchNorm -batch_normalization_3/FusedBatchNorm:2,batch_normalization_3/FusedBatchNorm -Relu_3,Relu_3 -conv2d_5/Conv2D,conv2d_5/Conv2D -batch_normalization_4/FusedBatchNorm,batch_normalization_4/FusedBatchNorm -batch_normalization_4/FusedBatchNorm:1,batch_normalization_4/FusedBatchNorm -batch_normalization_4/FusedBatchNorm:2,batch_normalization_4/FusedBatchNorm -Relu_4,Relu_4 -conv2d_6/Conv2D,conv2d_6/Conv2D -batch_normalization_5/FusedBatchNorm,batch_normalization_5/FusedBatchNorm -batch_normalization_5/FusedBatchNorm:1,batch_normalization_5/FusedBatchNorm -batch_normalization_5/FusedBatchNorm:2,batch_normalization_5/FusedBatchNorm -Relu_5,Relu_5 -conv2d_7/Conv2D,conv2d_7/Conv2D -add_1,add_1 -batch_normalization_6/FusedBatchNorm,batch_normalization_6/FusedBatchNorm -batch_normalization_6/FusedBatchNorm:1,batch_normalization_6/FusedBatchNorm -batch_normalization_6/FusedBatchNorm:2,batch_normalization_6/FusedBatchNorm -Relu_6,Relu_6 -conv2d_8/Conv2D,conv2d_8/Conv2D -batch_normalization_7/FusedBatchNorm,batch_normalization_7/FusedBatchNorm -batch_normalization_7/FusedBatchNorm:1,batch_normalization_7/FusedBatchNorm 
-batch_normalization_7/FusedBatchNorm:2,batch_normalization_7/FusedBatchNorm -Relu_7,Relu_7 -conv2d_9/Conv2D,conv2d_9/Conv2D -batch_normalization_8/FusedBatchNorm,batch_normalization_8/FusedBatchNorm -batch_normalization_8/FusedBatchNorm:1,batch_normalization_8/FusedBatchNorm -batch_normalization_8/FusedBatchNorm:2,batch_normalization_8/FusedBatchNorm -Relu_8,Relu_8 -conv2d_10/Conv2D,conv2d_10/Conv2D -add_2,add_2 -block_layer1,block_layer1 -batch_normalization_9/FusedBatchNorm,batch_normalization_9/FusedBatchNorm -batch_normalization_9/FusedBatchNorm:1,batch_normalization_9/FusedBatchNorm -batch_normalization_9/FusedBatchNorm:2,batch_normalization_9/FusedBatchNorm -Relu_9,Relu_9 -Pad_1,Pad_1 -conv2d_12/Conv2D,conv2d_12/Conv2D -conv2d_11/Conv2D,conv2d_11/Conv2D -batch_normalization_10/FusedBatchNorm,batch_normalization_10/FusedBatchNorm -batch_normalization_10/FusedBatchNorm:1,batch_normalization_10/FusedBatchNorm -batch_normalization_10/FusedBatchNorm:2,batch_normalization_10/FusedBatchNorm -Relu_10,Relu_10 -Pad_2,Pad_2 -conv2d_13/Conv2D,conv2d_13/Conv2D -batch_normalization_11/FusedBatchNorm,batch_normalization_11/FusedBatchNorm -batch_normalization_11/FusedBatchNorm:1,batch_normalization_11/FusedBatchNorm -batch_normalization_11/FusedBatchNorm:2,batch_normalization_11/FusedBatchNorm -Relu_11,Relu_11 -conv2d_14/Conv2D,conv2d_14/Conv2D -add_3,add_3 -batch_normalization_12/FusedBatchNorm,batch_normalization_12/FusedBatchNorm -batch_normalization_12/FusedBatchNorm:1,batch_normalization_12/FusedBatchNorm -batch_normalization_12/FusedBatchNorm:2,batch_normalization_12/FusedBatchNorm -Relu_12,Relu_12 -conv2d_15/Conv2D,conv2d_15/Conv2D -batch_normalization_13/FusedBatchNorm,batch_normalization_13/FusedBatchNorm -batch_normalization_13/FusedBatchNorm:1,batch_normalization_13/FusedBatchNorm -batch_normalization_13/FusedBatchNorm:2,batch_normalization_13/FusedBatchNorm -Relu_13,Relu_13 -conv2d_16/Conv2D,conv2d_16/Conv2D -batch_normalization_14/FusedBatchNorm,batch_normalization_14/FusedBatchNorm -batch_normalization_14/FusedBatchNorm:1,batch_normalization_14/FusedBatchNorm -batch_normalization_14/FusedBatchNorm:2,batch_normalization_14/FusedBatchNorm -Relu_14,Relu_14 -conv2d_17/Conv2D,conv2d_17/Conv2D -add_4,add_4 -batch_normalization_15/FusedBatchNorm,batch_normalization_15/FusedBatchNorm -batch_normalization_15/FusedBatchNorm:1,batch_normalization_15/FusedBatchNorm -batch_normalization_15/FusedBatchNorm:2,batch_normalization_15/FusedBatchNorm -Relu_15,Relu_15 -conv2d_18/Conv2D,conv2d_18/Conv2D -batch_normalization_16/FusedBatchNorm,batch_normalization_16/FusedBatchNorm -batch_normalization_16/FusedBatchNorm:1,batch_normalization_16/FusedBatchNorm -batch_normalization_16/FusedBatchNorm:2,batch_normalization_16/FusedBatchNorm -Relu_16,Relu_16 -conv2d_19/Conv2D,conv2d_19/Conv2D -batch_normalization_17/FusedBatchNorm,batch_normalization_17/FusedBatchNorm -batch_normalization_17/FusedBatchNorm:1,batch_normalization_17/FusedBatchNorm -batch_normalization_17/FusedBatchNorm:2,batch_normalization_17/FusedBatchNorm -Relu_17,Relu_17 -conv2d_20/Conv2D,conv2d_20/Conv2D -add_5,add_5 -batch_normalization_18/FusedBatchNorm,batch_normalization_18/FusedBatchNorm -batch_normalization_18/FusedBatchNorm:1,batch_normalization_18/FusedBatchNorm -batch_normalization_18/FusedBatchNorm:2,batch_normalization_18/FusedBatchNorm -Relu_18,Relu_18 -conv2d_21/Conv2D,conv2d_21/Conv2D -batch_normalization_19/FusedBatchNorm,batch_normalization_19/FusedBatchNorm 
-batch_normalization_19/FusedBatchNorm:1,batch_normalization_19/FusedBatchNorm -batch_normalization_19/FusedBatchNorm:2,batch_normalization_19/FusedBatchNorm -Relu_19,Relu_19 -conv2d_22/Conv2D,conv2d_22/Conv2D -batch_normalization_20/FusedBatchNorm,batch_normalization_20/FusedBatchNorm -batch_normalization_20/FusedBatchNorm:1,batch_normalization_20/FusedBatchNorm -batch_normalization_20/FusedBatchNorm:2,batch_normalization_20/FusedBatchNorm -Relu_20,Relu_20 -conv2d_23/Conv2D,conv2d_23/Conv2D -add_6,add_6 -block_layer2,block_layer2 -batch_normalization_21/FusedBatchNorm,batch_normalization_21/FusedBatchNorm -batch_normalization_21/FusedBatchNorm:1,batch_normalization_21/FusedBatchNorm -batch_normalization_21/FusedBatchNorm:2,batch_normalization_21/FusedBatchNorm -Relu_21,Relu_21 -Pad_3,Pad_3 -conv2d_25/Conv2D,conv2d_25/Conv2D -conv2d_24/Conv2D,conv2d_24/Conv2D -batch_normalization_22/FusedBatchNorm,batch_normalization_22/FusedBatchNorm -batch_normalization_22/FusedBatchNorm:1,batch_normalization_22/FusedBatchNorm -batch_normalization_22/FusedBatchNorm:2,batch_normalization_22/FusedBatchNorm -Relu_22,Relu_22 -Pad_4,Pad_4 -conv2d_26/Conv2D,conv2d_26/Conv2D -batch_normalization_23/FusedBatchNorm,batch_normalization_23/FusedBatchNorm -batch_normalization_23/FusedBatchNorm:1,batch_normalization_23/FusedBatchNorm -batch_normalization_23/FusedBatchNorm:2,batch_normalization_23/FusedBatchNorm -Relu_23,Relu_23 -conv2d_27/Conv2D,conv2d_27/Conv2D -add_7,add_7 -batch_normalization_24/FusedBatchNorm,batch_normalization_24/FusedBatchNorm -batch_normalization_24/FusedBatchNorm:1,batch_normalization_24/FusedBatchNorm -batch_normalization_24/FusedBatchNorm:2,batch_normalization_24/FusedBatchNorm -Relu_24,Relu_24 -conv2d_28/Conv2D,conv2d_28/Conv2D -batch_normalization_25/FusedBatchNorm,batch_normalization_25/FusedBatchNorm -batch_normalization_25/FusedBatchNorm:1,batch_normalization_25/FusedBatchNorm -batch_normalization_25/FusedBatchNorm:2,batch_normalization_25/FusedBatchNorm -Relu_25,Relu_25 -conv2d_29/Conv2D,conv2d_29/Conv2D -batch_normalization_26/FusedBatchNorm,batch_normalization_26/FusedBatchNorm -batch_normalization_26/FusedBatchNorm:1,batch_normalization_26/FusedBatchNorm -batch_normalization_26/FusedBatchNorm:2,batch_normalization_26/FusedBatchNorm -Relu_26,Relu_26 -conv2d_30/Conv2D,conv2d_30/Conv2D -add_8,add_8 -batch_normalization_27/FusedBatchNorm,batch_normalization_27/FusedBatchNorm -batch_normalization_27/FusedBatchNorm:1,batch_normalization_27/FusedBatchNorm -batch_normalization_27/FusedBatchNorm:2,batch_normalization_27/FusedBatchNorm -Relu_27,Relu_27 -conv2d_31/Conv2D,conv2d_31/Conv2D -batch_normalization_28/FusedBatchNorm,batch_normalization_28/FusedBatchNorm -batch_normalization_28/FusedBatchNorm:1,batch_normalization_28/FusedBatchNorm -batch_normalization_28/FusedBatchNorm:2,batch_normalization_28/FusedBatchNorm -Relu_28,Relu_28 -conv2d_32/Conv2D,conv2d_32/Conv2D -batch_normalization_29/FusedBatchNorm,batch_normalization_29/FusedBatchNorm -batch_normalization_29/FusedBatchNorm:1,batch_normalization_29/FusedBatchNorm -batch_normalization_29/FusedBatchNorm:2,batch_normalization_29/FusedBatchNorm -Relu_29,Relu_29 -conv2d_33/Conv2D,conv2d_33/Conv2D -add_9,add_9 -batch_normalization_30/FusedBatchNorm,batch_normalization_30/FusedBatchNorm -batch_normalization_30/FusedBatchNorm:1,batch_normalization_30/FusedBatchNorm -batch_normalization_30/FusedBatchNorm:2,batch_normalization_30/FusedBatchNorm -Relu_30,Relu_30 -conv2d_34/Conv2D,conv2d_34/Conv2D 
-batch_normalization_31/FusedBatchNorm,batch_normalization_31/FusedBatchNorm -batch_normalization_31/FusedBatchNorm:1,batch_normalization_31/FusedBatchNorm -batch_normalization_31/FusedBatchNorm:2,batch_normalization_31/FusedBatchNorm -Relu_31,Relu_31 -conv2d_35/Conv2D,conv2d_35/Conv2D -batch_normalization_32/FusedBatchNorm,batch_normalization_32/FusedBatchNorm -batch_normalization_32/FusedBatchNorm:1,batch_normalization_32/FusedBatchNorm -batch_normalization_32/FusedBatchNorm:2,batch_normalization_32/FusedBatchNorm -Relu_32,Relu_32 -conv2d_36/Conv2D,conv2d_36/Conv2D -add_10,add_10 -batch_normalization_33/FusedBatchNorm,batch_normalization_33/FusedBatchNorm -batch_normalization_33/FusedBatchNorm:1,batch_normalization_33/FusedBatchNorm -batch_normalization_33/FusedBatchNorm:2,batch_normalization_33/FusedBatchNorm -Relu_33,Relu_33 -conv2d_37/Conv2D,conv2d_37/Conv2D -batch_normalization_34/FusedBatchNorm,batch_normalization_34/FusedBatchNorm -batch_normalization_34/FusedBatchNorm:1,batch_normalization_34/FusedBatchNorm -batch_normalization_34/FusedBatchNorm:2,batch_normalization_34/FusedBatchNorm -Relu_34,Relu_34 -conv2d_38/Conv2D,conv2d_38/Conv2D -batch_normalization_35/FusedBatchNorm,batch_normalization_35/FusedBatchNorm -batch_normalization_35/FusedBatchNorm:1,batch_normalization_35/FusedBatchNorm -batch_normalization_35/FusedBatchNorm:2,batch_normalization_35/FusedBatchNorm -Relu_35,Relu_35 -conv2d_39/Conv2D,conv2d_39/Conv2D -add_11,add_11 -batch_normalization_36/FusedBatchNorm,batch_normalization_36/FusedBatchNorm -batch_normalization_36/FusedBatchNorm:1,batch_normalization_36/FusedBatchNorm -batch_normalization_36/FusedBatchNorm:2,batch_normalization_36/FusedBatchNorm -Relu_36,Relu_36 -conv2d_40/Conv2D,conv2d_40/Conv2D -batch_normalization_37/FusedBatchNorm,batch_normalization_37/FusedBatchNorm -batch_normalization_37/FusedBatchNorm:1,batch_normalization_37/FusedBatchNorm -batch_normalization_37/FusedBatchNorm:2,batch_normalization_37/FusedBatchNorm -Relu_37,Relu_37 -conv2d_41/Conv2D,conv2d_41/Conv2D -batch_normalization_38/FusedBatchNorm,batch_normalization_38/FusedBatchNorm -batch_normalization_38/FusedBatchNorm:1,batch_normalization_38/FusedBatchNorm -batch_normalization_38/FusedBatchNorm:2,batch_normalization_38/FusedBatchNorm -Relu_38,Relu_38 -conv2d_42/Conv2D,conv2d_42/Conv2D -add_12,add_12 -block_layer3,block_layer3 -batch_normalization_39/FusedBatchNorm,batch_normalization_39/FusedBatchNorm -batch_normalization_39/FusedBatchNorm:1,batch_normalization_39/FusedBatchNorm -batch_normalization_39/FusedBatchNorm:2,batch_normalization_39/FusedBatchNorm -Relu_39,Relu_39 -Pad_5,Pad_5 -conv2d_44/Conv2D,conv2d_44/Conv2D -conv2d_43/Conv2D,conv2d_43/Conv2D -batch_normalization_40/FusedBatchNorm,batch_normalization_40/FusedBatchNorm -batch_normalization_40/FusedBatchNorm:1,batch_normalization_40/FusedBatchNorm -batch_normalization_40/FusedBatchNorm:2,batch_normalization_40/FusedBatchNorm -Relu_40,Relu_40 -Pad_6,Pad_6 -conv2d_45/Conv2D,conv2d_45/Conv2D -batch_normalization_41/FusedBatchNorm,batch_normalization_41/FusedBatchNorm -batch_normalization_41/FusedBatchNorm:1,batch_normalization_41/FusedBatchNorm -batch_normalization_41/FusedBatchNorm:2,batch_normalization_41/FusedBatchNorm -Relu_41,Relu_41 -conv2d_46/Conv2D,conv2d_46/Conv2D -add_13,add_13 -batch_normalization_42/FusedBatchNorm,batch_normalization_42/FusedBatchNorm -batch_normalization_42/FusedBatchNorm:1,batch_normalization_42/FusedBatchNorm -batch_normalization_42/FusedBatchNorm:2,batch_normalization_42/FusedBatchNorm 
-Relu_42,Relu_42 -conv2d_47/Conv2D,conv2d_47/Conv2D -batch_normalization_43/FusedBatchNorm,batch_normalization_43/FusedBatchNorm -batch_normalization_43/FusedBatchNorm:1,batch_normalization_43/FusedBatchNorm -batch_normalization_43/FusedBatchNorm:2,batch_normalization_43/FusedBatchNorm -Relu_43,Relu_43 -conv2d_48/Conv2D,conv2d_48/Conv2D -batch_normalization_44/FusedBatchNorm,batch_normalization_44/FusedBatchNorm -batch_normalization_44/FusedBatchNorm:1,batch_normalization_44/FusedBatchNorm -batch_normalization_44/FusedBatchNorm:2,batch_normalization_44/FusedBatchNorm -Relu_44,Relu_44 -conv2d_49/Conv2D,conv2d_49/Conv2D -add_14,add_14 -batch_normalization_45/FusedBatchNorm,batch_normalization_45/FusedBatchNorm -batch_normalization_45/FusedBatchNorm:1,batch_normalization_45/FusedBatchNorm -batch_normalization_45/FusedBatchNorm:2,batch_normalization_45/FusedBatchNorm -Relu_45,Relu_45 -conv2d_50/Conv2D,conv2d_50/Conv2D -batch_normalization_46/FusedBatchNorm,batch_normalization_46/FusedBatchNorm -batch_normalization_46/FusedBatchNorm:1,batch_normalization_46/FusedBatchNorm -batch_normalization_46/FusedBatchNorm:2,batch_normalization_46/FusedBatchNorm -Relu_46,Relu_46 -conv2d_51/Conv2D,conv2d_51/Conv2D -batch_normalization_47/FusedBatchNorm,batch_normalization_47/FusedBatchNorm -batch_normalization_47/FusedBatchNorm:1,batch_normalization_47/FusedBatchNorm -batch_normalization_47/FusedBatchNorm:2,batch_normalization_47/FusedBatchNorm -Relu_47,Relu_47 -conv2d_52/Conv2D,conv2d_52/Conv2D -add_15,add_15 -block_layer4,block_layer4 -batch_normalization_48/FusedBatchNorm,batch_normalization_48/FusedBatchNorm -batch_normalization_48/FusedBatchNorm:1,batch_normalization_48/FusedBatchNorm -batch_normalization_48/FusedBatchNorm:2,batch_normalization_48/FusedBatchNorm -Relu_48,Relu_48 -Mean,Mean -final_reduce_mean,final_reduce_mean -Reshape,Reshape -dense/MatMul,dense/MatMul -dense/BiasAdd,dense/BiasAdd -final_dense,final_dense -ArgMax,ArgMax -softmax_tensor,softmax_tensor diff --git a/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt b/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt deleted file mode 100644 index c273a0be4..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt +++ /dev/null @@ -1 +0,0 @@ -Sum,Sum diff --git a/nd4j/nd4j-common-tests/pom.xml b/nd4j/nd4j-common-tests/pom.xml index 9134e21cc..61cdbb1a3 100644 --- a/nd4j/nd4j-common-tests/pom.xml +++ b/nd4j/nd4j-common-tests/pom.xml @@ -40,9 +40,24 @@ - junit - junit - provided + org.junit.jupiter + junit-jupiter-api + compile + + + org.junit.jupiter + junit-jupiter-engine + compile + + + + org.junit.jupiter + junit-jupiter + + + org.junit.vintage + junit-vintage-engine + compile org.nd4j diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/AbstractAssertTestsClass.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/AbstractAssertTestsClass.java index 2c531ee61..ff5251175 100644 --- a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/AbstractAssertTestsClass.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/AbstractAssertTestsClass.java @@ -20,7 +20,6 @@ package org.nd4j.common.tests; import lombok.extern.slf4j.Slf4j; -import org.junit.Test; import org.reflections.Reflections; import org.reflections.scanners.MethodAnnotationsScanner; import org.reflections.util.ClasspathHelper; @@ -28,8 +27,8 @@ import org.reflections.util.ConfigurationBuilder; import java.lang.reflect.Method; import java.util.*; +import org.junit.jupiter.api.Test; -import static 
org.junit.Assert.assertEquals; @Slf4j public abstract class AbstractAssertTestsClass extends BaseND4JTest { @@ -46,7 +45,7 @@ public abstract class AbstractAssertTestsClass extends BaseND4JTest { } @Test - public void checkTestClasses(){ + public void checkTestClasses() { Reflections reflections = new Reflections(new ConfigurationBuilder() .setUrls(ClasspathHelper.forPackage(getPackageName())) .setScanners(new MethodAnnotationsScanner())); diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/BaseND4JTest.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/BaseND4JTest.java index e105cf706..b7fb96fb5 100644 --- a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/BaseND4JTest.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/BaseND4JTest.java @@ -23,9 +23,9 @@ package org.nd4j.common.tests; import ch.qos.logback.classic.LoggerContext; import lombok.extern.slf4j.Slf4j; import org.bytedeco.javacpp.Pointer; -import org.junit.*; -import org.junit.rules.TestName; -import org.junit.rules.Timeout; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; import org.nd4j.common.base.Preconditions; import org.nd4j.common.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.DataType; @@ -41,15 +41,12 @@ import java.util.List; import java.util.Map; import java.util.Properties; -import static org.junit.Assume.assumeTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + @Slf4j public abstract class BaseND4JTest { - @Rule - public TestName name = new TestName(); - @Rule - public Timeout timeout = Timeout.millis(getTimeoutMilliseconds()); protected long startTime; protected int threadCountBefore; @@ -111,13 +108,13 @@ public abstract class BaseND4JTest { * This can be used to dynamically skip integration tests when the integration test profile is not enabled. * Note that the integration test profile is not enabled by default - "integration-tests" profile */ - public void skipUnlessIntegrationTests(){ - assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests()); + public void skipUnlessIntegrationTests() { + assumeTrue( isIntegrationTests(),"Skipping integration test - integration profile is not enabled"); } - @Before - public void beforeTest(){ - log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); + @BeforeEach + public void beforeTest(TestInfo testInfo) { + log.info("{}.{}", getClass().getSimpleName(), testInfo.getTestMethod().get().getName()); //Suppress ND4J initialization - don't need this logged for every test... 
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true"); @@ -136,8 +133,8 @@ public abstract class BaseND4JTest { threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); } - @After - public void afterTest(){ + @AfterEach + public void afterTest(TestInfo testInfo) { //Attempt to keep workspaces isolated between tests Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); @@ -170,7 +167,7 @@ public abstract class BaseND4JTest { int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) + sb.append(getClass().getSimpleName()).append(".").append( testInfo.getTestMethod().get().getName()) .append(": ").append(duration).append(" ms") .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") .append(", jvmTotal=").append(jvmTotal) diff --git a/nd4j/nd4j-common/pom.xml b/nd4j/nd4j-common/pom.xml index e92faac77..4b211dbaa 100644 --- a/nd4j/nd4j-common/pom.xml +++ b/nd4j/nd4j-common/pom.xml @@ -56,8 +56,16 @@ slf4j-api - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine + + + org.junit.vintage + junit-vintage-engine commons-io diff --git a/nd4j/nd4j-onnxruntime/pom.xml b/nd4j/nd4j-onnxruntime/pom.xml index 013d87616..213348627 100644 --- a/nd4j/nd4j-onnxruntime/pom.xml +++ b/nd4j/nd4j-onnxruntime/pom.xml @@ -66,15 +66,18 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine org.nd4j nd4j-native ${project.version} - test diff --git a/nd4j/nd4j-onnxruntime/src/test/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunnerTests.java b/nd4j/nd4j-onnxruntime/src/test/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunnerTests.java index 31ee661ba..1cb1859d3 100644 --- a/nd4j/nd4j-onnxruntime/src/test/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunnerTests.java +++ b/nd4j/nd4j-onnxruntime/src/test/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunnerTests.java @@ -19,17 +19,17 @@ */ package org.nd4j.onnxruntime.runner; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.io.File; -import java.util.Arrays; import java.util.LinkedHashMap; import java.util.Map; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + public class OnnxRuntimeRunnerTests { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml index bc00bb88f..ab0fa3096 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml @@ -45,8 +45,12 @@ nd4j-parameter-server-model - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine org.nd4j diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml index de219f99b..6b0de214f 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml +++ 
b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml @@ -43,8 +43,12 @@ test - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine commons-io diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml index a929e89fe..919ea3b91 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml @@ -46,8 +46,12 @@ nd4j-parameter-server - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine org.nd4j diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml index d860a8eb4..d29df2bde 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml @@ -53,8 +53,12 @@ nd4j-parameter-server - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine com.typesafe.play diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml index aa6f52514..d24533025 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml @@ -50,8 +50,12 @@ test - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine com.beust diff --git a/nd4j/nd4j-serde/pom.xml b/nd4j/nd4j-serde/pom.xml index c65de9137..853488442 100644 --- a/nd4j/nd4j-serde/pom.xml +++ b/nd4j/nd4j-serde/pom.xml @@ -46,9 +46,16 @@ nd4j-api - junit - junit - test + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine + + + org.junit.vintage + junit-vintage-engine org.nd4j diff --git a/nd4j/nd4j-tensorflow/pom.xml b/nd4j/nd4j-tensorflow/pom.xml index 288d3e1ad..245a0999e 100644 --- a/nd4j/nd4j-tensorflow/pom.xml +++ b/nd4j/nd4j-tensorflow/pom.xml @@ -65,8 +65,12 @@ ${gson.version} - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine diff --git a/nd4j/nd4j-tvm/pom.xml b/nd4j/nd4j-tvm/pom.xml index 566032748..6f61a2c15 100644 --- a/nd4j/nd4j-tvm/pom.xml +++ b/nd4j/nd4j-tvm/pom.xml @@ -62,8 +62,12 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine diff --git a/nd4j/nd4j-tvm/src/test/java/org/nd4j/tvm/runner/TvmRunnerTests.java b/nd4j/nd4j-tvm/src/test/java/org/nd4j/tvm/runner/TvmRunnerTests.java index 147ccbcaa..567b6f192 100644 --- a/nd4j/nd4j-tvm/src/test/java/org/nd4j/tvm/runner/TvmRunnerTests.java +++ b/nd4j/nd4j-tvm/src/test/java/org/nd4j/tvm/runner/TvmRunnerTests.java @@ -19,32 +19,27 @@ */ package org.nd4j.tvm.runner; -import org.bytedeco.javacpp.*; import org.bytedeco.cpython.*; -import org.bytedeco.numpy.*; -import org.bytedeco.tvm.*; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.common.io.ClassPathResource; + + +import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.io.File; -import java.util.Arrays; +import java.nio.file.Path; import java.util.LinkedHashMap; import java.util.Map; import static 
org.bytedeco.cpython.global.python.*; import static org.bytedeco.numpy.global.numpy.*; -import static org.bytedeco.tvm.global.tvm_runtime.*; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.io.TempDir; + public class TvmRunnerTests { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - static void PrepareTestLibs(String libPath) throws Exception { Py_AddPath(org.bytedeco.tvm.presets.tvm.cachePackages()); Py_Initialize(); @@ -81,11 +76,11 @@ public class TvmRunnerTests { } @Test - public void testAdd() throws Exception { + public void testAdd(@TempDir Path tempDir) throws Exception { /* try to use MKL when available */ System.setProperty("org.bytedeco.openblas.load", "mkl"); - File libPath = testDir.newFolder("lib"); + File libPath = tempDir.resolve("lib").toFile(); PrepareTestLibs(libPath.getAbsolutePath().replace(File.separatorChar, '/')); File f = new File(libPath, "test_relay_add.so"); INDArray x = Nd4j.scalar(1.0f).reshape(1,1); diff --git a/nd4j/pom.xml b/nd4j/pom.xml index 613c05f63..4836109b8 100644 --- a/nd4j/pom.xml +++ b/nd4j/pom.xml @@ -73,12 +73,6 @@ slf4j-log4j12 ${slf4j.version} - - junit - junit - ${junit.version} - test - org.nd4j nd4j-native-api diff --git a/nd4j/samediff-import/pom.xml b/nd4j/samediff-import/pom.xml index 1b395213f..931016732 100644 --- a/nd4j/samediff-import/pom.xml +++ b/nd4j/samediff-import/pom.xml @@ -49,8 +49,7 @@ 1.4.30 1.8 true - 4.13 - 5.4.2 + 5.8.0-M1 UTF-8 1.8 1.8 @@ -63,21 +62,17 @@ - junit - junit + org.junit.jupiter + junit-jupiter-api + + + org.junit.vintage + junit-vintage-engine - - org.junit.jupiter - junit-jupiter-api - ${junit-jupiter.version} - test - org.junit.jupiter junit-jupiter-engine - ${junit-jupiter.version} - test diff --git a/pom.xml b/pom.xml index 6691f87ea..bf2503468 100644 --- a/pom.xml +++ b/pom.xml @@ -95,8 +95,8 @@ - 1.7 - 1.7 + 1.8 + 1.8 1.8 1.8 UTF-8 @@ -202,7 +202,7 @@ 1.15.5 ${tensorflow.version}-${javacpp-presets.version} - 0.14.1 + 0.17 1.18 3.5 3.6 @@ -224,7 +224,7 @@ 2 2.0.29 1.7.21 - 4.13 + 5.8.0-M1 0.14.1 1.2.3 2.10.1 @@ -234,7 +234,7 @@ 1.18.16 2.0.0 7.7.1 - 20131018 + 20131018 3.8.0 2.6.1 false @@ -327,6 +327,30 @@ + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.vintage + junit-vintage-engine + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + org.jetbrains.kotlin kotlin-stdlib-jdk8 @@ -613,28 +637,6 @@ - - - org.commonjava.maven.plugins - directory-maven-plugin - 0.3.1 - - - native-dir - - directory-of - - initialize - - nd4j.basedir - - org.nd4j - nd4j - - - - - org.apache.maven.plugins maven-source-plugin @@ -783,9 +785,6 @@ true - true - true - true true true true @@ -801,9 +800,6 @@ true - true - true - true true true true @@ -827,10 +823,6 @@ ${dl4j-test-resources.classifier} test - - org.walkmod - junit4git - diff --git a/python4j/pom.xml b/python4j/pom.xml index 36841acb1..c6b9e2165 100644 --- a/python4j/pom.xml +++ b/python4j/pom.xml @@ -59,8 +59,14 @@ test - junit - junit + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.vintage + junit-vintage-engine ${junit.version} test diff --git a/rl4j/pom.xml b/rl4j/pom.xml index 46dde6766..8fd079262 100644 --- a/rl4j/pom.xml +++ b/rl4j/pom.xml @@ -58,10 +58,12 @@ - junit - junit - ${junit.version} - test + org.junit.jupiter + junit-jupiter-api + + + org.junit.vintage + 
junit-vintage-engine org.projectlombok
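For readers following the JUnit 4 to JUnit 5 hunks above, here is a minimal, self-contained sketch of the Jupiter idioms this migration converges on. It is not part of the patch; the class and method names are hypothetical and chosen only for illustration. It shows @DisplayName on the class and test, TestInfo injected into a lifecycle method in place of the JUnit 4 TestName rule, @TempDir in place of the TemporaryFolder rule, Assumptions.assumeTrue with the condition before the message, and the org.junit.jupiter.api.Assertions imports.

import java.io.File;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.io.TempDir;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

// Hypothetical example class, not part of the patch above.
@DisplayName("Migration Pattern Example")
class MigrationPatternExample {

    // TestInfo replaces the JUnit 4 @Rule TestName; Jupiter injects it into lifecycle methods.
    @BeforeEach
    void logTestName(TestInfo testInfo) {
        System.out.println(getClass().getSimpleName() + "." + testInfo.getTestMethod().get().getName());
    }

    // @TempDir replaces the JUnit 4 TemporaryFolder rule; the Path is created and cleaned up by JUnit.
    @Test
    @DisplayName("Writes And Reads A Temp File")
    void writesAndReadsTempFile(@TempDir Path tempDir) throws Exception {
        // Jupiter's assumeTrue takes the condition first and the message second,
        // the reverse of JUnit 4's Assume.assumeTrue(String, boolean).
        assumeTrue(true, "Skipping when the precondition does not hold");

        File f = tempDir.resolve("temp.csv").toFile();
        Files.write(f.toPath(), "a,b,c".getBytes(StandardCharsets.UTF_8));

        String content = new String(Files.readAllBytes(f.toPath()), StandardCharsets.UTF_8);
        // Assertions.assertEquals expects (expected, actual).
        assertEquals("a,b,c", content);
    }
}

On the build side, the pom.xml hunks above replace the junit:junit artifact with org.junit.jupiter:junit-jupiter-api and junit-jupiter-engine (plus org.junit.vintage:junit-vintage-engine where remaining JUnit 4 tests must still run), with versions managed from the root pom's junit.version property (5.8.0-M1 in this patch).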