Upgrade DL4J to JUnit 5

master
agibsonccc 2021-03-15 13:02:01 +09:00
parent 552c5d0b72
commit fa1a31c877
259 changed files with 10586 additions and 21460 deletions


@@ -15,7 +15,7 @@
<commons.dbutils.version>1.7</commons.dbutils.version>
<lombok.version>1.18.8</lombok.version>
<logback.version>1.1.7</logback.version>
-<junit.version>4.12</junit.version>
+<junit.version>5.8.0-M1</junit.version>
<junit-jupiter.version>5.4.2</junit-jupiter.version>
<java.version>1.8</java.version>
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
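Two details worth flagging in this change: 5.8.0-M1 is a milestone (pre-release) build of JUnit rather than a final release, and the pre-existing junit-jupiter.version property stays at 5.4.2, so modules keyed to the two different properties will not necessarily resolve matching Jupiter versions.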


@@ -17,13 +17,14 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.codegen.ir;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-public class SerializationTest {
+@DisplayName("Serialization Test")
+class SerializationTest {
-public static void main(String...args) {
+public static void main(String... args) {
}
}
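Migration note: the pattern above recurs through the whole commit. Jupiter discovers tests reflectively, so test classes and methods no longer need to be public, and @DisplayName carries the human-readable name. A minimal sketch of the migrated shape (illustrative only; this class body is not from the commit):

    import org.junit.jupiter.api.DisplayName;
    import org.junit.jupiter.api.Test;

    import static org.junit.jupiter.api.Assertions.assertEquals;

    @DisplayName("Serialization Test")
    class SerializationTestSketch {

        @Test
        @DisplayName("parses an int")
        void parsesAnInt() {
            // package-private visibility is enough for JUnit 5
            assertEquals(42, Integer.parseInt("42"));
        }
    }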


@@ -17,29 +17,23 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.codegen.dsl;
import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Test;
import org.nd4j.codegen.impl.java.DocsGenerator;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
-public class DocsGeneratorTest {
+@DisplayName("Docs Generator Test")
+class DocsGeneratorTest {
@Test
-public void testJDtoMDAdapter() {
-String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" +
-" eye:\n" +
-" [ 1, 0]\n" +
-" [ 0, 1]\n" +
-" [ 0, 0]}";
-String expected = "{ INDArray eye = eye(3,2)\n" +
-" eye:\n" +
-" [ 1, 0]\n" +
-" [ 0, 1]\n" +
-" [ 0, 0]}";
+@DisplayName("Test J Dto MD Adapter")
+void testJDtoMDAdapter() {
+String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
+String expected = "{ INDArray eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original);
String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString();
assertEquals(out, expected);
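A note on the final assertion above: Assertions.assertEquals is declared as assertEquals(expected, actual), so assertEquals(out, expected) passes the arguments in swapped order. The test passes or fails identically either way; only a failure message would label the two values backwards. The conventional form would be:

    // expected value first, so failure messages read correctly
    assertEquals(expected, out);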


@@ -34,6 +34,14 @@
<artifactId>datavec-api</artifactId>
<dependencies>
+<dependency>
+<groupId>org.junit.jupiter</groupId>
+<artifactId>junit-jupiter-api</artifactId>
+</dependency>
+<dependency>
+<groupId>org.junit.vintage</groupId>
+<artifactId>junit-vintage-engine</artifactId>
+</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
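The two added test dependencies split the work: junit-jupiter-api is what the migrated tests compile against (@Test, @DisplayName, Assertions), while junit-vintage-engine keeps any not-yet-migrated JUnit 4 tests running on the JUnit 5 platform during the transition. No version tags appear here, so versions are presumably managed centrally; a hedged sketch of what that parent-level management typically looks like (assumed for illustration, not shown in this commit):

    <!-- hypothetical parent-pom snippet; not part of this diff -->
    <dependencyManagement>
      <dependencies>
        <dependency>
          <groupId>org.junit</groupId>
          <artifactId>junit-bom</artifactId>
          <version>${junit.version}</version>
          <type>pom</type>
          <scope>import</scope>
        </dependency>
      </dependencies>
    </dependencyManagement>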


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@@ -27,46 +26,37 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
+@DisplayName("Csv Line Sequence Record Reader Test")
+class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
-public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
-@Rule
-public TemporaryFolder testDir = new TemporaryFolder();
+@TempDir
+public Path testDir;
@Test
-public void test() throws Exception {
-File f = testDir.newFolder();
+@DisplayName("Test")
+void test(@TempDir Path testDir) throws Exception {
+File f = testDir.toFile();
File source = new File(f, "temp.csv");
String str = "a,b,c\n1,2,3,4";
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
SequenceRecordReader rr = new CSVLineSequenceRecordReader();
rr.initialize(new FileSplit(source));
-List<List<Writable>> exp0 = Arrays.asList(
-Collections.<Writable>singletonList(new Text("a")),
-Collections.<Writable>singletonList(new Text("b")),
-Collections.<Writable>singletonList(new Text("c")));
-List<List<Writable>> exp1 = Arrays.asList(
-Collections.<Writable>singletonList(new Text("1")),
-Collections.<Writable>singletonList(new Text("2")),
-Collections.<Writable>singletonList(new Text("3")),
-Collections.<Writable>singletonList(new Text("4")));
-for( int i=0; i<3; i++ ) {
+List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
+List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
+for (int i = 0; i < 3; i++) {
int count = 0;
while (rr.hasNext()) {
List<List<Writable>> next = rr.sequenceRecord();
@@ -76,9 +66,7 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
assertEquals(exp1, next);
}
}
assertEquals(2, count);
rr.reset();
}
}
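Migration note: this file shows the commit's standard replacement for JUnit 4 temporary folders: the @Rule TemporaryFolder becomes a @TempDir java.nio.file.Path, which Jupiter (5.4 and later) can inject as a field or as a test-method parameter. Both appear above, so the field is left unused by test(). A condensed sketch of the two injection styles (illustrative, not from the commit):

    import java.io.File;
    import java.nio.file.Path;

    import org.junit.jupiter.api.Test;
    import org.junit.jupiter.api.io.TempDir;

    import static org.junit.jupiter.api.Assertions.assertTrue;

    class TempDirStylesTest {

        @TempDir
        Path fieldDir; // field injection: a fresh directory per test

        @Test
        void parameterInjection(@TempDir Path methodDir) throws Exception {
            // parameter injection: a directory scoped to this method
            File f = methodDir.resolve("temp.csv").toFile();
            assertTrue(f.createNewFile());
        }
    }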


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@@ -27,32 +26,34 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
+@DisplayName("Csv Multi Sequence Record Reader Test")
+class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
-public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
-@Rule
-public TemporaryFolder testDir = new TemporaryFolder();
+@TempDir
+public Path testDir;
@Test
-public void testConcatMode() throws Exception {
-for( int i=0; i<3; i++ ) {
+@DisplayName("Test Concat Mode")
+void testConcatMode() throws Exception {
+for (int i = 0; i < 3; i++) {
String seqSep;
String seqSepRegex;
-switch (i){
+switch(i) {
case 0:
seqSep = "";
seqSepRegex = "^$";
@@ -68,31 +69,23 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default:
throw new RuntimeException();
}
String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C";
-File f = testDir.newFile();
+File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT);
seqRR.initialize(new FileSplit(f));
List<List<Writable>> exp0 = new ArrayList<>();
for (String s : "a,b,c,1,2,3,4,x,y".split(",")) {
exp0.add(Collections.<Writable>singletonList(new Text(s)));
}
List<List<Writable>> exp1 = new ArrayList<>();
for (String s : "A,B,C".split(",")) {
exp1.add(Collections.<Writable>singletonList(new Text(s)));
}
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
seqRR.reset();
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
@@ -100,13 +93,12 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
-public void testEqualLength() throws Exception {
-for( int i=0; i<3; i++ ) {
+@DisplayName("Test Equal Length")
+void testEqualLength() throws Exception {
+for (int i = 0; i < 3; i++) {
String seqSep;
String seqSepRegex;
-switch (i) {
+switch(i) {
case 0:
seqSep = "";
seqSepRegex = "^$";
@@ -122,27 +114,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default:
throw new RuntimeException();
}
String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC";
-File f = testDir.newFile();
+File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH);
seqRR.initialize(new FileSplit(f));
-List<List<Writable>> exp0 = Arrays.asList(
-Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
-Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
+List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
seqRR.reset();
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
@@ -150,13 +132,12 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
-public void testPadding() throws Exception {
-for( int i=0; i<3; i++ ) {
+@DisplayName("Test Padding")
+void testPadding() throws Exception {
+for (int i = 0; i < 3; i++) {
String seqSep;
String seqSepRegex;
-switch (i) {
+switch(i) {
case 0:
seqSep = "";
seqSepRegex = "^$";
@@ -172,27 +153,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default:
throw new RuntimeException();
}
String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC";
-File f = testDir.newFile();
+File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD"));
seqRR.initialize(new FileSplit(f));
-List<List<Writable>> exp0 = Arrays.asList(
-Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
-Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
+List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
seqRR.reset();
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.SequenceRecord;
@@ -27,61 +26,53 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
+@DisplayName("Csvn Lines Sequence Record Reader Test")
+class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
@Test
-public void testCSVNLinesSequenceRecordReader() throws Exception {
+@DisplayName("Test CSVN Lines Sequence Record Reader")
+void testCSVNLinesSequenceRecordReader() throws Exception {
int nLinesPerSequence = 10;
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int count = 0;
while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord();
List<List<Writable>> expected = new ArrayList<>();
for (int i = 0; i < nLinesPerSequence; i++) {
expected.add(rr.next());
}
assertEquals(10, next.size());
assertEquals(expected, next);
count++;
}
assertEquals(150 / nLinesPerSequence, count);
}
@Test
-public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
+@DisplayName("Test CSV Nlines Sequence Record Reader Meta Data")
+void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
int nLinesPerSequence = 10;
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
List<List<List<Writable>>> out = new ArrayList<>();
while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord();
out.add(next);
}
seqRR.reset();
List<List<List<Writable>>> out2 = new ArrayList<>();
List<SequenceRecord> out3 = new ArrayList<>();
@@ -92,11 +83,8 @@ public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
meta.add(seq.getMetaData());
out3.add(seq);
}
assertEquals(out, out2);
List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
assertEquals(out3, out4);
}
}


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@@ -34,10 +33,10 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
@@ -47,41 +46,44 @@ import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
-import static org.junit.Assert.*;
+import static org.junit.jupiter.api.Assertions.*;
+@DisplayName("Csv Record Reader Test")
+class CSVRecordReaderTest extends BaseND4JTest {
-public class CSVRecordReaderTest extends BaseND4JTest {
@Test
-public void testNext() throws Exception {
+@DisplayName("Test Next")
+void testNext() throws Exception {
CSVRecordReader reader = new CSVRecordReader();
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1"));
while (reader.hasNext()) {
List<Writable> vals = reader.next();
List<Writable> arr = new ArrayList<>(vals);
assertEquals("Entry count", 23, vals.size());
assertEquals(23, vals.size(), "Entry count");
Text lastEntry = (Text) arr.get(arr.size() - 1);
assertEquals("Last entry garbage", 1, lastEntry.getLength());
assertEquals(1, lastEntry.getLength(), "Last entry garbage");
}
}
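Migration note: the reordering above is the mechanical core of this commit. JUnit 4's Assert takes the optional message first; Jupiter's Assertions takes it last, and also accepts a Supplier<String> so the message is only built on failure. A self-contained sketch (illustrative):

    import java.util.List;

    import static org.junit.jupiter.api.Assertions.assertEquals;

    class AssertMessageOrderSketch {
        static void check(List<?> vals) {
            // JUnit 4 form was: assertEquals("Entry count", 23, vals.size());
            // Jupiter takes the message last, optionally lazily:
            assertEquals(23, vals.size(), () -> "Entry count");
        }
    }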
@Test
-public void testEmptyEntries() throws Exception {
+@DisplayName("Test Empty Entries")
+void testEmptyEntries() throws Exception {
CSVRecordReader reader = new CSVRecordReader();
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,"));
while (reader.hasNext()) {
List<Writable> vals = reader.next();
assertEquals("Entry count", 23, vals.size());
assertEquals(23, vals.size(), "Entry count");
}
}
@Test
-public void testReset() throws Exception {
+@DisplayName("Test Reset")
+void testReset() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int nResets = 5;
for (int i = 0; i < nResets; i++) {
int lineCount = 0;
while (rr.hasNext()) {
List<Writable> line = rr.next();
@@ -95,7 +97,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
@Test
-public void testResetWithSkipLines() throws Exception {
+@DisplayName("Test Reset With Skip Lines")
+void testResetWithSkipLines() throws Exception {
CSVRecordReader rr = new CSVRecordReader(10, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int lineCount = 0;
@@ -114,7 +117,8 @@
}
@Test
-public void testWrite() throws Exception {
+@DisplayName("Test Write")
+void testWrite() throws Exception {
List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder();
for (int i = 0; i < 10; i++) {
@@ -130,81 +134,72 @@
}
list.add(temp);
}
String expected = sb.toString();
Path p = Files.createTempFile("csvwritetest", "csv");
p.toFile().deleteOnExit();
FileRecordWriter writer = new CSVRecordWriter();
FileSplit fileSplit = new FileSplit(p.toFile());
-writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
+writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
for (List<Writable> c : list) {
writer.write(c);
}
writer.close();
-//Read file back in; compare
+// Read file back in; compare
String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name());
-// System.out.println(expected);
-// System.out.println("----------");
-// System.out.println(fileContents);
+// System.out.println(expected);
+// System.out.println("----------");
+// System.out.println(fileContents);
assertEquals(expected, fileContents);
}
@Test
-public void testTabsAsSplit1() throws Exception {
+@DisplayName("Test Tabs As Split 1")
+void testTabsAsSplit1() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, '\t');
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile()));
while (reader.hasNext()) {
List<Writable> list = new ArrayList<>(reader.next());
assertEquals(2, list.size());
}
}
@Test
-public void testPipesAsSplit() throws Exception {
+@DisplayName("Test Pipes As Split")
+void testPipesAsSplit() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, '|');
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/issue414.csv").getFile()));
int lineidx = 0;
List<Integer> sixthColumn = Arrays.asList(13, 95, 15, 25);
while (reader.hasNext()) {
List<Writable> list = new ArrayList<>(reader.next());
assertEquals(10, list.size());
-assertEquals((long)sixthColumn.get(lineidx), list.get(5).toInt());
+assertEquals((long) sixthColumn.get(lineidx), list.get(5).toInt());
lineidx++;
}
}
@Test
-public void testWithQuotes() throws Exception {
+@DisplayName("Test With Quotes")
+void testWithQuotes() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, ',', '\"');
reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\""));
while (reader.hasNext()) {
List<Writable> vals = reader.next();
assertEquals("Entry count", 6, vals.size());
assertEquals("1", vals.get(0).toString());
assertEquals("0", vals.get(1).toString());
assertEquals("3", vals.get(2).toString());
assertEquals("Braund, Mr. Owen Harris", vals.get(3).toString());
assertEquals("male", vals.get(4).toString());
assertEquals("\"", vals.get(5).toString());
assertEquals(6, vals.size(), "Entry count");
assertEquals(vals.get(0).toString(), "1");
assertEquals(vals.get(1).toString(), "0");
assertEquals(vals.get(2).toString(), "3");
assertEquals(vals.get(3).toString(), "Braund, Mr. Owen Harris");
assertEquals(vals.get(4).toString(), "male");
assertEquals(vals.get(5).toString(), "\"");
}
}
@Test
-public void testMeta() throws Exception {
+@DisplayName("Test Meta")
+void testMeta() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int lineCount = 0;
List<RecordMetaData> metaList = new ArrayList<>();
List<List<Writable>> writables = new ArrayList<>();
@@ -213,30 +208,25 @@ public class CSVRecordReaderTest extends BaseND4JTest {
assertEquals(5, r.getRecord().size());
lineCount++;
RecordMetaData meta = r.getMetaData();
-// System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
+// System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
metaList.add(meta);
writables.add(r.getRecord());
}
assertFalse(rr.hasNext());
assertEquals(150, lineCount);
rr.reset();
System.out.println("\n\n\n--------------------------------");
List<Record> contents = rr.loadFromMetaData(metaList);
assertEquals(150, contents.size());
-// for(Record r : contents ){
-// System.out.println(r);
-// }
+// for(Record r : contents ){
+// System.out.println(r);
+// }
List<RecordMetaData> meta2 = new ArrayList<>();
meta2.add(metaList.get(100));
meta2.add(metaList.get(90));
meta2.add(metaList.get(80));
meta2.add(metaList.get(70));
meta2.add(metaList.get(60));
List<Record> contents2 = rr.loadFromMetaData(meta2);
assertEquals(writables.get(100), contents2.get(0).getRecord());
assertEquals(writables.get(90), contents2.get(1).getRecord());
@@ -246,50 +236,49 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
@Test
-public void testRegex() throws Exception {
-CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] {null, "(.+) (.+) (.+)"});
+@DisplayName("Test Regex")
+void testRegex() throws Exception {
+CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] { null, "(.+) (.+) (.+)" });
reader.initialize(new StringSplit("normal,1.2.3.4 space separator"));
while (reader.hasNext()) {
List<Writable> vals = reader.next();
assertEquals("Entry count", 4, vals.size());
assertEquals("normal", vals.get(0).toString());
assertEquals("1.2.3.4", vals.get(1).toString());
assertEquals("space", vals.get(2).toString());
assertEquals("separator", vals.get(3).toString());
assertEquals(4, vals.size(), "Entry count");
assertEquals(vals.get(0).toString(), "normal");
assertEquals(vals.get(1).toString(), "1.2.3.4");
assertEquals(vals.get(2).toString(), "space");
assertEquals(vals.get(3).toString(), "separator");
}
}
-@Test(expected = NoSuchElementException.class)
-public void testCsvSkipAllLines() throws IOException, InterruptedException {
-final int numLines = 4;
-final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1),
-(Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
-String header = ",one,two,three";
-List<String> lines = new ArrayList<>();
-for (int i = 0; i < numLines; i++)
-lines.add(Integer.toString(i) + header);
-File tempFile = File.createTempFile("csvSkipLines", ".csv");
-FileUtils.writeLines(tempFile, lines);
-CSVRecordReader rr = new CSVRecordReader(numLines, ',');
-rr.initialize(new FileSplit(tempFile));
-rr.reset();
-assertTrue(!rr.hasNext());
-rr.next();
+@Test
+@DisplayName("Test Csv Skip All Lines")
+void testCsvSkipAllLines() {
+assertThrows(NoSuchElementException.class, () -> {
+final int numLines = 4;
+final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
+String header = ",one,two,three";
+List<String> lines = new ArrayList<>();
+for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
+File tempFile = File.createTempFile("csvSkipLines", ".csv");
+FileUtils.writeLines(tempFile, lines);
+CSVRecordReader rr = new CSVRecordReader(numLines, ',');
+rr.initialize(new FileSplit(tempFile));
+rr.reset();
+assertTrue(!rr.hasNext());
+rr.next();
+});
}
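Migration note: Jupiter's @Test has no expected attribute, so @Test(expected = NoSuchElementException.class) becomes the explicit assertThrows wrapper above. A minimal standalone illustration of the idiom (not from the commit):

    import java.util.Collections;
    import java.util.Iterator;
    import java.util.NoSuchElementException;

    import org.junit.jupiter.api.Test;

    import static org.junit.jupiter.api.Assertions.assertThrows;

    class ThrowsIdiomSketch {

        @Test
        void nextPastEndThrows() {
            Iterator<Object> empty = Collections.emptyIterator();
            // assertThrows also returns the exception for further checks
            assertThrows(NoSuchElementException.class, empty::next);
        }
    }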
@Test
-public void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
+@DisplayName("Test Csv Skip All But One Line")
+void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
final int numLines = 4;
-final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)),
-new Text("one"), new Text("two"), new Text("three"));
+final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)), new Text("one"), new Text("two"), new Text("three"));
String header = ",one,two,three";
List<String> lines = new ArrayList<>();
-for (int i = 0; i < numLines; i++)
-lines.add(Integer.toString(i) + header);
+for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
File tempFile = File.createTempFile("csvSkipLines", ".csv");
FileUtils.writeLines(tempFile, lines);
CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
rr.initialize(new FileSplit(tempFile));
rr.reset();
@@ -297,50 +286,45 @@ public class CSVRecordReaderTest extends BaseND4JTest {
assertEquals(rr.next(), lineList);
}
@Test
-public void testStreamReset() throws Exception {
+@DisplayName("Test Stream Reset")
+void testStreamReset() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));
int count = 0;
-while(rr.hasNext()){
+while (rr.hasNext()) {
assertNotNull(rr.next());
count++;
}
assertEquals(150, count);
assertFalse(rr.resetSupported());
-try{
+try {
rr.reset();
fail("Expected exception");
-} catch (Exception e){
+} catch (Exception e) {
String msg = e.getMessage();
String msg2 = e.getCause().getMessage();
-assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
-assertTrue(msg2, msg2.contains("Reset not supported from streams"));
-// e.printStackTrace();
+assertTrue(msg.contains("Error during LineRecordReader reset"),msg);
+assertTrue(msg2.contains("Reset not supported from streams"),msg2);
+// e.printStackTrace();
}
}
@Test
-public void testUsefulExceptionNoInit(){
+@DisplayName("Test Useful Exception No Init")
+void testUsefulExceptionNoInit() {
CSVRecordReader rr = new CSVRecordReader(0, ',');
-try{
+try {
rr.hasNext();
fail("Expected exception");
-} catch (Exception e){
-assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
+} catch (Exception e) {
+assertTrue( e.getMessage().contains("initialized"),e.getMessage());
}
-try{
+try {
rr.next();
fail("Expected exception");
-} catch (Exception e){
-assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
+} catch (Exception e) {
+assertTrue(e.getMessage().contains("initialized"),e.getMessage());
}
}
}
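The last two tests keep the JUnit 4 try/fail/catch idiom and only reorder the assertTrue arguments. Under Jupiter the same checks could be collapsed with assertThrows; a hedged sketch of the equivalent for the hasNext() case (a suggestion, not what the commit does):

    import org.datavec.api.records.reader.impl.csv.CSVRecordReader;

    import org.junit.jupiter.api.Test;

    import static org.junit.jupiter.api.Assertions.assertThrows;
    import static org.junit.jupiter.api.Assertions.assertTrue;

    class UninitializedReaderSketch {

        @Test
        void hasNextWithoutInitFails() {
            CSVRecordReader rr = new CSVRecordReader(0, ',');
            // replaces: try { rr.hasNext(); fail(...); } catch (Exception e) { ... }
            Exception e = assertThrows(Exception.class, rr::hasNext);
            assertTrue(e.getMessage().contains("initialized"), e.getMessage());
        }
    }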


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.SequenceRecord;
@@ -28,11 +27,10 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Writable;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.io.InputStream;
import java.io.OutputStream;
@@ -41,25 +39,27 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
+@DisplayName("Csv Sequence Record Reader Test")
+class CSVSequenceRecordReaderTest extends BaseND4JTest {
-public class CSVSequenceRecordReaderTest extends BaseND4JTest {
-@Rule
-public TemporaryFolder tempDir = new TemporaryFolder();
+@TempDir
+public Path tempDir;
@Test
-public void test() throws Exception {
+@DisplayName("Test")
+void test() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit());
int sequenceCount = 0;
while (seqReader.hasNext()) {
List<List<Writable>> sequence = seqReader.sequenceRecord();
-assertEquals(4, sequence.size()); //4 lines, plus 1 header line
+// 4 lines, plus 1 header line
+assertEquals(4, sequence.size());
Iterator<List<Writable>> timeStepIter = sequence.iterator();
int lineCount = 0;
while (timeStepIter.hasNext()) {
@@ -80,19 +80,18 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
-public void testReset() throws Exception {
+@DisplayName("Test Reset")
+void testReset() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit());
int nTests = 5;
for (int i = 0; i < nTests; i++) {
seqReader.reset();
int sequenceCount = 0;
while (seqReader.hasNext()) {
List<List<Writable>> sequence = seqReader.sequenceRecord();
-assertEquals(4, sequence.size()); //4 lines, plus 1 header line
+// 4 lines, plus 1 header line
+assertEquals(4, sequence.size());
Iterator<List<Writable>> timeStepIter = sequence.iterator();
int lineCount = 0;
while (timeStepIter.hasNext()) {
@@ -107,15 +106,15 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
-public void testMetaData() throws Exception {
+@DisplayName("Test Meta Data")
+void testMetaData() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit());
List<List<List<Writable>>> l = new ArrayList<>();
while (seqReader.hasNext()) {
List<List<Writable>> sequence = seqReader.sequenceRecord();
-assertEquals(4, sequence.size()); //4 lines, plus 1 header line
+// 4 lines, plus 1 header line
+assertEquals(4, sequence.size());
Iterator<List<Writable>> timeStepIter = sequence.iterator();
int lineCount = 0;
while (timeStepIter.hasNext()) {
@@ -123,10 +122,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
lineCount++;
}
assertEquals(4, lineCount);
l.add(sequence);
}
List<SequenceRecord> l2 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>();
seqReader.reset();
@@ -136,7 +133,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
meta.add(sr.getMetaData());
}
assertEquals(3, l2.size());
List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta);
for (int i = 0; i < 3; i++) {
assertEquals(l.get(i), l2.get(i).getSequenceRecord());
@@ -144,8 +140,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
}
}
-private static class TestInputSplit implements InputSplit {
+@DisplayName("Test Input Split")
+private static class TestInputSplit implements InputSplit {
@Override
public boolean canWriteToLocation(URI location) {
@@ -164,7 +160,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override
public void updateSplitLocations(boolean reset) {
}
@Override
@@ -174,7 +169,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override
public void bootStrapForWrite() {
}
@Override
@@ -222,38 +216,30 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override
public void reset() {
-//No op
+// No op
}
@Override
public boolean resetSupported() {
return true;
}
}
@Test
-public void testCsvSeqAndNumberedFileSplit() throws Exception {
-File baseDir = tempDir.newFolder();
-//Simple sanity check unit test
+@DisplayName("Test Csv Seq And Numbered File Split")
+void testCsvSeqAndNumberedFileSplit(@TempDir Path tempDir) throws Exception {
+File baseDir = tempDir.toFile();
+// Simple sanity check unit test
for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
}
-//Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
+// Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
-while(featureReader.hasNext()){
+while (featureReader.hasNext()) {
featureReader.nextSequence();
}
}
}
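Note that after the migration this class carries both a @TempDir field (tempDir) and a @TempDir method parameter in testCsvSeqAndNumberedFileSplit; the parameter shadows the field, leaving the field unused in that test. Either injection style works on its own, so keeping both is redundant.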


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.SequenceRecordReader;
@@ -25,94 +24,87 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.util.LinkedList;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
+@DisplayName("Csv Variable Sliding Window Record Reader Test")
+class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
@Test
-public void testCSVVariableSlidingWindowRecordReader() throws Exception {
+@DisplayName("Test CSV Variable Sliding Window Record Reader")
+void testCSVVariableSlidingWindowRecordReader() throws Exception {
int maxLinesPerSequence = 3;
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int count = 0;
while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord();
-if(count==maxLinesPerSequence-1) {
+if (count == maxLinesPerSequence - 1) {
LinkedList<List<Writable>> expected = new LinkedList<>();
for (int i = 0; i < maxLinesPerSequence; i++) {
expected.addFirst(rr.next());
}
assertEquals(expected, next);
}
-if(count==maxLinesPerSequence) {
+if (count == maxLinesPerSequence) {
assertEquals(maxLinesPerSequence, next.size());
}
-if(count==0) { // first seq should be length 1
+if (count == 0) {
+// first seq should be length 1
assertEquals(1, next.size());
}
-if(count>151) { // last seq should be length 1
+if (count > 151) {
+// last seq should be length 1
assertEquals(1, next.size());
}
count++;
}
assertEquals(152, count);
}
@Test
-public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
+@DisplayName("Test CSV Variable Sliding Window Record Reader Stride")
+void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
int maxLinesPerSequence = 3;
int stride = 2;
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int count = 0;
while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord();
-if(count==maxLinesPerSequence-1) {
+if (count == maxLinesPerSequence - 1) {
LinkedList<List<Writable>> expected = new LinkedList<>();
-for(int s = 0; s < stride; s++) {
+for (int s = 0; s < stride; s++) {
expected = new LinkedList<>();
for (int i = 0; i < maxLinesPerSequence; i++) {
expected.addFirst(rr.next());
}
}
assertEquals(expected, next);
}
-if(count==maxLinesPerSequence) {
+if (count == maxLinesPerSequence) {
assertEquals(maxLinesPerSequence, next.size());
}
-if(count==0) { // first seq should be length 2
+if (count == 0) {
+// first seq should be length 2
assertEquals(2, next.size());
}
-if(count>151) { // last seq should be length 1
+if (count > 151) {
+// last seq should be length 1
assertEquals(1, next.size());
}
count++;
}
assertEquals(76, count);
}
}


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@@ -29,44 +28,38 @@ import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
import org.datavec.api.writable.Writable;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.loader.FileBatch;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.*;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.*;
-public class FileBatchRecordReaderTest extends BaseND4JTest {
-@Rule
-public TemporaryFolder testDir = new TemporaryFolder();
+@DisplayName("File Batch Record Reader Test")
+class FileBatchRecordReaderTest extends BaseND4JTest {
@Test
-public void testCsv() throws Exception {
-//This is an unrealistic use case - one line/record per CSV
-File baseDir = testDir.newFolder();
+@DisplayName("Test Csv")
+void testCsv(@TempDir Path testDir) throws Exception {
+// This is an unrealistic use case - one line/record per CSV
+File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>();
-for( int i=0; i<10; i++ ){
+for (int i = 0; i < 10; i++) {
String s = "file_" + i + "," + i + "," + i;
File f = new File(baseDir, "origFile" + i + ".csv");
FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8);
fileList.add(f);
}
FileBatch fb = FileBatch.forFiles(fileList);
RecordReader rr = new CSVRecordReader();
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
-for( int test=0; test<3; test++) {
+for (int test = 0; test < 3; test++) {
for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext());
List<Writable> next = fbrr.next();
@@ -83,15 +76,15 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
}
@Test
-public void testCsvSequence() throws Exception {
-//CSV sequence - 3 lines per file, 10 files
-File baseDir = testDir.newFolder();
+@DisplayName("Test Csv Sequence")
+void testCsvSequence(@TempDir Path testDir) throws Exception {
+// CSV sequence - 3 lines per file, 10 files
+File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>();
-for( int i=0; i<10; i++ ){
+for (int i = 0; i < 10; i++) {
StringBuilder sb = new StringBuilder();
-for( int j=0; j<3; j++ ){
-if(j > 0)
+for (int j = 0; j < 3; j++) {
+if (j > 0)
sb.append("\n");
sb.append("file_" + i + "," + i + "," + j);
}
@@ -99,19 +92,16 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8);
fileList.add(f);
}
FileBatch fb = FileBatch.forFiles(fileList);
SequenceRecordReader rr = new CSVSequenceRecordReader();
FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb);
-for( int test=0; test<3; test++) {
+for (int test = 0; test < 3; test++) {
for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext());
List<List<Writable>> next = fbrr.sequenceRecord();
assertEquals(3, next.size());
int count = 0;
-for(List<Writable> step : next ){
+for (List<Writable> step : next) {
String s1 = "file_" + i;
assertEquals(s1, step.get(0).toString());
assertEquals(String.valueOf(i), step.get(1).toString());
@@ -123,5 +113,4 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
fbrr.reset();
}
}
}


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record;
@@ -26,28 +25,28 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-public class FileRecordReaderTest extends BaseND4JTest {
+@DisplayName("File Record Reader Test")
+class FileRecordReaderTest extends BaseND4JTest {
@Test
-public void testReset() throws Exception {
+@DisplayName("Test Reset")
+void testReset() throws Exception {
FileRecordReader rr = new FileRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int nResets = 5;
for (int i = 0; i < nResets; i++) {
int lineCount = 0;
while (rr.hasNext()) {
List<Writable> line = rr.next();
@@ -61,25 +60,20 @@ public class FileRecordReaderTest extends BaseND4JTest {
}
@Test
-public void testMeta() throws Exception {
+@DisplayName("Test Meta")
+void testMeta() throws Exception {
FileRecordReader rr = new FileRecordReader();
URI[] arr = new URI[3];
arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI();
arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI();
arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI();
InputSplit is = new CollectionInputSplit(Arrays.asList(arr));
rr.initialize(is);
List<List<Writable>> out = new ArrayList<>();
while (rr.hasNext()) {
out.add(rr.next());
}
assertEquals(3, out.size());
rr.reset();
List<List<Writable>> out2 = new ArrayList<>();
List<Record> out3 = new ArrayList<>();
@@ -90,13 +84,10 @@ public class FileRecordReaderTest extends BaseND4JTest {
out2.add(r.getRecord());
out3.add(r);
meta.add(r.getMetaData());
assertEquals(arr[count++], r.getMetaData().getURI());
}
assertEquals(out, out2);
List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(out3, fromMeta);
}
}


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader;
@@ -29,96 +28,80 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
+@DisplayName("Jackson Line Record Reader Test")
+class JacksonLineRecordReaderTest extends BaseND4JTest {
-public class JacksonLineRecordReaderTest extends BaseND4JTest {
+@TempDir
+public Path testDir;
-@Rule
-public TemporaryFolder testDir = new TemporaryFolder();
-public JacksonLineRecordReaderTest() {
-}
+public JacksonLineRecordReaderTest() {
+}
private static FieldSelection getFieldSelection() {
-return new FieldSelection.Builder().addField("value1").
-addField("value2").
-addField("value3").
-addField("value4").
-addField("value5").
-addField("value6").
-addField("value7").
-addField("value8").
-addField("value9").
-addField("value10").build();
+return new FieldSelection.Builder().addField("value1").addField("value2").addField("value3").addField("value4").addField("value5").addField("value6").addField("value7").addField("value8").addField("value9").addField("value10").build();
}
@Test
-public void testReadJSON() throws Exception {
+@DisplayName("Test Read JSON")
+void testReadJSON() throws Exception {
RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile()));
testJacksonRecordReader(rr);
}
-private static void testJacksonRecordReader(RecordReader rr) {
-while (rr.hasNext()) {
-List<Writable> json0 = rr.next();
-//System.out.println(json0);
-assert(json0.size() > 0);
-}
-}
+private static void testJacksonRecordReader(RecordReader rr) {
+while (rr.hasNext()) {
+List<Writable> json0 = rr.next();
+// System.out.println(json0);
+assert (json0.size() > 0);
+}
+}
@Test
-public void testJacksonLineSequenceRecordReader() throws Exception {
-File dir = testDir.newFolder();
-new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
-FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
-.addField(new Text("MISSING_CX"), "c", "x").build();
-JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
-File[] files = dir.listFiles();
-Arrays.sort(files);
-URI[] u = new URI[files.length];
-for( int i=0; i<files.length; i++ ){
-u[i] = files[i].toURI();
-}
-rr.initialize(new CollectionInputSplit(u));
-List<List<Writable>> expSeq0 = new ArrayList<>();
-expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
-expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
-expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
-List<List<Writable>> expSeq1 = new ArrayList<>();
-expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
-int count = 0;
-while(rr.hasNext()){
-List<List<Writable>> next = rr.sequenceRecord();
-if(count++ == 0){
-assertEquals(expSeq0, next);
-} else {
-assertEquals(expSeq1, next);
-}
-}
-assertEquals(2, count);
-}
+@DisplayName("Test Jackson Line Sequence Record Reader")
+void testJacksonLineSequenceRecordReader(@TempDir Path testDir) throws Exception {
+File dir = testDir.toFile();
+new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
+FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
+JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
+File[] files = dir.listFiles();
+Arrays.sort(files);
+URI[] u = new URI[files.length];
+for (int i = 0; i < files.length; i++) {
+u[i] = files[i].toURI();
+}
+rr.initialize(new CollectionInputSplit(u));
+List<List<Writable>> expSeq0 = new ArrayList<>();
+expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
+expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
+expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
+List<List<Writable>> expSeq1 = new ArrayList<>();
+expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
+int count = 0;
+while (rr.hasNext()) {
+List<List<Writable>> next = rr.sequenceRecord();
+if (count++ == 0) {
+assertEquals(expSeq0, next);
+} else {
+assertEquals(expSeq1, next);
+}
+}
+assertEquals(2, count);
+}
}
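Editorial note: testJacksonRecordReader above relies on Java's assert statement, which only runs when the JVM has assertions enabled (-ea); Surefire enables them by default, but other runners may not, and a Jupiter assertion fails with a clearer message. A hedged sketch of the unconditional form (a suggestion, not part of the commit):

    import java.util.List;

    import org.datavec.api.writable.Writable;

    import static org.junit.jupiter.api.Assertions.assertFalse;

    class NonEmptyRecordSketch {
        static void requireNonEmpty(List<Writable> json0) {
            // unconditional replacement for: assert (json0.size() > 0);
            assertFalse(json0.isEmpty(), "record should contain at least one value");
        }
    }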


@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.io.labels.PathLabelGenerator;
@@ -32,113 +31,94 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
+@DisplayName("Jackson Record Reader Test")
+class JacksonRecordReaderTest extends BaseND4JTest {
-public class JacksonRecordReaderTest extends BaseND4JTest {
-@Rule
-public TemporaryFolder testDir = new TemporaryFolder();
+@TempDir
+public Path testDir;
@Test
-public void testReadingJson() throws Exception {
-//Load 3 values from 3 JSON files
-//stricture: a:value, b:value, c:x:value, c:y:value
-//And we want to load only a:value, b:value and c:x:value
-//For first JSON file: all values are present
-//For second JSON file: b:value is missing
-//For third JSON file: c:x:value is missing
+@DisplayName("Test Reading Json")
+void testReadingJson(@TempDir Path testDir) throws Exception {
+// Load 3 values from 3 JSON files
+// stricture: a:value, b:value, c:x:value, c:y:value
+// And we want to load only a:value, b:value and c:x:value
+// For first JSON file: all values are present
+// For second JSON file: b:value is missing
+// For third JSON file: c:x:value is missing
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
-File f = testDir.newFolder();
+File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(is);
testJacksonRecordReader(rr);
}
@Test
-public void testReadingYaml() throws Exception {
-//Exact same information as JSON format, but in YAML format
+@DisplayName("Test Reading Yaml")
+void testReadingYaml(@TempDir Path testDir) throws Exception {
+// Exact same information as JSON format, but in YAML format
ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/");
-File f = testDir.newFolder();
+File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "yaml_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory()));
rr.initialize(is);
testJacksonRecordReader(rr);
}
@Test
-public void testReadingXml() throws Exception {
-//Exact same information as JSON format, but in XML format
+@DisplayName("Test Reading Xml")
+void testReadingXml(@TempDir Path testDir) throws Exception {
+// Exact same information as JSON format, but in XML format
ClassPathResource cpr = new ClassPathResource("datavec-api/xml/");
-File f = testDir.newFolder();
+File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "xml_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory()));
rr.initialize(is);
testJacksonRecordReader(rr);
}
private static FieldSelection getFieldSelection() {
-return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
-.addField(new Text("MISSING_CX"), "c", "x").build();
+return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
}
private static void testJacksonRecordReader(RecordReader rr) {
List<Writable> json0 = rr.next();
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
assertEquals(exp0, json0);
List<Writable> json1 = rr.next();
-List<Writable> exp1 =
-Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
+List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
assertEquals(exp1, json1);
List<Writable> json2 = rr.next();
-List<Writable> exp2 =
-Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
+List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
assertEquals(exp2, json2);
assertFalse(rr.hasNext());
-//Test reset
+// Test reset
rr.reset();
assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next());
@@ -147,72 +127,50 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
}
@Test
-public void testAppendingLabels() throws Exception {
+@DisplayName("Test Appending Labels")
+void testAppendingLabels(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
-File f = testDir.newFolder();
+File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
-//Insert at the end:
-RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
-new LabelGen());
+// Insert at the end:
+RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
rr.initialize(is);
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"),
new IntWritable(0));
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0));
assertEquals(exp0, rr.next());
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"),
new IntWritable(1));
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1));
assertEquals(exp1, rr.next());
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"),
new IntWritable(2));
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2));
assertEquals(exp2, rr.next());
-//Insert at position 0:
-rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
-new LabelGen(), 0);
+// Insert at position 0:
+rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0);
rr.initialize(is);
-exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"),
-new Text("cxValue0"));
+exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
assertEquals(exp0, rr.next());
-exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"),
-new Text("cxValue1"));
+exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
assertEquals(exp1, rr.next());
-exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"),
-new Text("MISSING_CX"));
+exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
assertEquals(exp2, rr.next());
}
@Test
-public void testAppendingLabelsMetaData() throws Exception {
+@DisplayName("Test Appending Labels Meta Data")
+void testAppendingLabelsMetaData(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
//Insert at the end:
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen());
// Insert at the end:
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
rr.initialize(is);
List<List<Writable>> out = new ArrayList<>();
while (rr.hasNext()) {
out.add(rr.next());
}
assertEquals(3, out.size());
rr.reset();
List<List<Writable>> out2 = new ArrayList<>();
List<Record> outRecord = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>();
@ -222,14 +180,12 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
outRecord.add(r);
meta.add(r.getMetaData());
}
assertEquals(out, out2);
List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(outRecord, fromMeta);
}
@DisplayName("Label Gen")
private static class LabelGen implements PathLabelGenerator {
@Override
@ -252,5 +208,4 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
return true;
}
}
}
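Every temp-directory change in this file follows one mechanical rule: the JUnit 4 TemporaryFolder rule plus newFolder() becomes a @TempDir Path parameter plus toFile(). A minimal sketch of the two styles side by side (the test names and file use here are illustrative, not taken from the commit):

// JUnit 4: rule field; newFolder() creates a fresh subfolder per call
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

@Test
public void oldStyle() throws Exception {
    File f = testDir.newFolder();
}

// JUnit 5: the framework creates one directory per test, cleans it up
// afterwards, and injects it as a parameter
@Test
void newStyle(@TempDir Path testDir) {
    File f = testDir.toFile();
}

One behavioral difference worth noting: newFolder() returns a fresh subdirectory on every call, while @TempDir injects a single directory for the whole test, which is why the migrated tests hand toFile() straight to copyDirectory().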

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration;
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.IOException;
import java.util.*;
import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class LibSvmRecordReaderTest extends BaseND4JTest {
@DisplayName("Lib Svm Record Reader Test")
class LibSvmRecordReaderTest extends BaseND4JTest {
@Test
public void testBasicRecord() throws IOException, InterruptedException {
@DisplayName("Test Basic Record")
void testBasicRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2)));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
// 33
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33)));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -80,27 +66,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNoAppendLabel() throws IOException, InterruptedException {
@DisplayName("Test No Append Label")
void testNoAppendLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 33
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -117,33 +91,17 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNoLabel() throws IOException, InterruptedException {
@DisplayName("Test No Label")
void testNoLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
// 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -160,33 +118,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
public void testMultioutputRecord() throws IOException, InterruptedException {
@DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7), new DoubleWritable(2.45),
new IntWritable(9)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2), new IntWritable(3),
new IntWritable(4)));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
// 33,32.0,31.9
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33), new DoubleWritable(32.0),
new DoubleWritable(31.9)));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -202,51 +142,20 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size());
}
@Test
public void testMultilabelRecord() throws IOException, InterruptedException {
@DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -265,63 +174,24 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
public void testZeroBasedIndexing() throws IOException, InterruptedException {
@DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO,
ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ZERO,
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(ZERO,
new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
correct.put(2, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO,
new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
// Zero-based indexing is default
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
// NOT STANDARD!
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
@ -336,87 +206,107 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size());
}
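// A worked example of the flags above, derived from this test's
// expected values: for the line "1,3 2:1 4:2 6:3 8:4 10:5" with
// NUM_FEATURES = 11, zero-based feature indexing (the default) puts
// feature "2:1" at slot 2 of the 11-wide record, so slots 0 and 1 stay
// ZERO; with the non-standard ZERO_BASED_LABEL_INDEXING = true, the
// labels "1,3" fill slots 1 and 3 of the five-slot label block rather
// than slots 0 and 2.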
@Test(expected = NoSuchElementException.class)
public void testNoSuchElementException() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test No Such Element Exception")
void testNoSuchElementException() {
assertThrows(NoSuchElementException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
rr.next();
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void failedToSetNumFeaturesException() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Failed To Set Num Features Exception")
void failedToSetNumFeaturesException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Inconsistent Num Labels Exception")
void testInconsistentNumLabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Inconsistent Num Multiabels Exception")
void testInconsistentNumMultiabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Feature Index Exceeds Num Features")
void testFeatureIndexExceedsNumFeatures() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void testInconsistentNumLabelsException() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test Label Index Exceeds Num Labels")
void testLabelIndexExceedsNumLabels() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void testInconsistentNumMultilabelsException() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test Zero Index Feature Without Using Zero Indexing")
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
});
}
@Test(expected = IndexOutOfBoundsException.class)
public void testFeatureIndexExceedsNumFeatures() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testLabelIndexExceedsNumLabels() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
@Test
@DisplayName("Test Zero Index Label Without Using Zero Indexing")
void testZeroIndexLabelWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
});
}
}
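All eight exception tests in this file migrate the same way: the expected attribute on JUnit 4's @Test moves into an explicit assertThrows call. A minimal sketch of the pattern with placeholder names (reader, config, and split stand in for the real ones):

// JUnit 4: any statement in the method can satisfy the expectation,
// including the initialize() call
@Test(expected = NoSuchElementException.class)
public void oldStyle() throws Exception {
    reader.initialize(config, split);
    reader.next();
}

// JUnit 5: the expectation is scoped to the lambda, and the thrown
// exception is returned in case further assertions are wanted
@Test
void newStyle() {
    NoSuchElementException e =
            assertThrows(NoSuchElementException.class, () -> {
                reader.initialize(config, split);
                reader.next();
            });
}

Keeping the setup inside the lambda, as this commit does, preserves the old pass conditions exactly; moving initialize() out of the lambda would tighten each test to the single call that is supposed to throw.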

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@ -31,10 +30,9 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
@ -45,34 +43,31 @@ import java.util.Arrays;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Line Reader Test")
class LineReaderTest extends BaseND4JTest {
public class LineReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testLineReader() throws Exception {
File tmpdir = testDir.newFolder();
@DisplayName("Test Line Reader")
void testLineReader(@TempDir Path tmpDir) throws Exception {
File tmpdir = tmpDir.toFile();
if (tmpdir.exists())
tmpdir.delete();
tmpdir.mkdir();
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
InputSplit split = new FileSplit(tmpdir);
RecordReader reader = new LineRecordReader();
reader.initialize(split);
int count = 0;
List<List<Writable>> list = new ArrayList<>();
while (reader.hasNext()) {
@ -81,34 +76,27 @@ public class LineReaderTest extends BaseND4JTest {
list.add(l);
count++;
}
assertEquals(9, count);
}
@Test
public void testLineReaderMetaData() throws Exception {
File tmpdir = testDir.newFolder();
@DisplayName("Test Line Reader Meta Data")
void testLineReaderMetaData(@TempDir Path tmpDir) throws Exception {
File tmpdir = tmpDir.toFile();
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
InputSplit split = new FileSplit(tmpdir);
RecordReader reader = new LineRecordReader();
reader.initialize(split);
List<List<Writable>> list = new ArrayList<>();
while (reader.hasNext()) {
list.add(reader.next());
}
assertEquals(9, list.size());
List<List<Writable>> out2 = new ArrayList<>();
List<Record> out3 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>();
@ -124,13 +112,10 @@ public class LineReaderTest extends BaseND4JTest {
assertEquals(uri, split.locations()[fileIdx]);
count++;
}
assertEquals(list, out2);
List<Record> fromMeta = reader.loadFromMetaData(meta);
assertEquals(out3, fromMeta);
//try: second line of second and third files only...
// try: second line of second and third files only...
List<RecordMetaData> subsetMeta = new ArrayList<>();
subsetMeta.add(meta.get(4));
subsetMeta.add(meta.get(7));
@ -141,27 +126,22 @@ public class LineReaderTest extends BaseND4JTest {
}
@Test
public void testLineReaderWithInputStreamInputSplit() throws Exception {
File tmpdir = testDir.newFolder();
@DisplayName("Test Line Reader With Input Stream Input Split")
void testLineReaderWithInputStreamInputSplit(@TempDir Path testDir) throws Exception {
File tmpdir = testDir.toFile();
File tmp1 = new File(tmpdir, "tmp1.txt.gz");
OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
os.flush();
os.close();
InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
RecordReader reader = new LineRecordReader();
reader.initialize(split);
int count = 0;
while (reader.hasNext()) {
assertEquals(1, reader.next().size());
count++;
}
assertEquals(9, count);
}
}
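The last test above drives LineRecordReader from a compressed stream instead of a FileSplit. A compact sketch of that round trip (the temp-file handling is an assumption for illustration):

File gz = File.createTempFile("lines", ".txt.gz");
try (OutputStream os = new GZIPOutputStream(new FileOutputStream(gz))) {
    IOUtils.writeLines(Arrays.asList("1", "2", "3"), null, os);
}
RecordReader reader = new LineRecordReader();
reader.initialize(new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(gz))));
while (reader.hasNext()) {
    List<Writable> line = reader.next();   // one single-column record per line
}

Unlike the FileSplit-based tests, a stream-backed split is generally single-pass; re-reading means reopening the stream.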

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record;
@ -34,43 +33,40 @@ import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@DisplayName("Regex Record Reader Test")
class RegexRecordReaderTest extends BaseND4JTest {
public class RegexRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testRegexLineRecordReader() throws Exception {
@DisplayName("Test Regex Line Record Reader")
void testRegexLineRecordReader() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"),
new Text("DEBUG"), new Text("First entry message!"));
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"),
new Text("INFO"), new Text("Second entry message!"));
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"),
new Text("WARN"), new Text("Third entry message!"));
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"));
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"));
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"));
assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next());
assertEquals(exp2, rr.next());
assertFalse(rr.hasNext());
//Test reset:
// Test reset:
rr.reset();
assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next());
@ -79,74 +75,57 @@ public class RegexRecordReaderTest extends BaseND4JTest {
}
@Test
public void testRegexLineRecordReaderMeta() throws Exception {
@DisplayName("Test Regex Line Record Reader Meta")
void testRegexLineRecordReaderMeta() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
List<List<Writable>> list = new ArrayList<>();
while (rr.hasNext()) {
list.add(rr.next());
}
assertEquals(3, list.size());
List<Record> list2 = new ArrayList<>();
List<List<Writable>> list3 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>();
rr.reset();
int count = 1; //Start by skipping 1 line
// Start by skipping 1 line
int count = 1;
while (rr.hasNext()) {
Record r = rr.nextRecord();
list2.add(r);
list3.add(r.getRecord());
meta.add(r.getMetaData());
assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber());
}
List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(list, list3);
assertEquals(list2, fromMeta);
}
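// Why count starts at 1: the reader was constructed with
// skipNumLines = 1, so source line 0 is consumed before matching and
// the first record's RecordMetaDataLine reports line number 1.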
@Test
public void testRegexSequenceRecordReader() throws Exception {
@DisplayName("Test Regex Sequence Record Reader")
void testRegexSequenceRecordReader(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is);
List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"),
new Text("First entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"),
new Text("Second entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"),
new Text("Third entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!")));
List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"),
new Text("First entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"),
new Text("Second entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"),
new Text("Third entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!")));
assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord());
assertFalse(rr.hasNext());
//Test resetting:
// Test resetting:
rr.reset();
assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord());
@ -154,24 +133,20 @@ public class RegexRecordReaderTest extends BaseND4JTest {
}
@Test
public void testRegexSequenceRecordReaderMeta() throws Exception {
@DisplayName("Test Regex Sequence Record Reader Meta")
void testRegexSequenceRecordReaderMeta(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is);
List<List<List<Writable>>> out = new ArrayList<>();
while (rr.hasNext()) {
out.add(rr.sequenceRecord());
}
assertEquals(2, out.size());
List<List<List<Writable>>> out2 = new ArrayList<>();
List<SequenceRecord> out3 = new ArrayList<>();
@ -183,11 +158,8 @@ public class RegexRecordReaderTest extends BaseND4JTest {
out3.add(seqr);
meta.add(seqr.getMetaData());
}
List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);
assertEquals(out, out2);
assertEquals(out3, fromMeta);
}
}
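The mapping rule exercised in this file: each regex capture group becomes one Writable column, and the second constructor argument skips header lines before matching starts. A minimal sketch (logFile stands in for a real log file):

// Groups 1..4 -> timestamp, id, level, message
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
RecordReader rr = new RegexLineRecordReader(regex, 1);   // skip 1 header line
rr.initialize(new FileSplit(logFile));
List<Writable> row = rr.next();   // [Text, Text, Text, Text], one per group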

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration;
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.IOException;
import java.util.*;
import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class SVMLightRecordReaderTest extends BaseND4JTest {
@DisplayName("Svm Light Record Reader Test")
class SVMLightRecordReaderTest extends BaseND4JTest {
@Test
public void testBasicRecord() throws IOException, InterruptedException {
@DisplayName("Test Basic Record")
void testBasicRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2)));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
// 33
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33)));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -79,27 +65,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNoAppendLabel() throws IOException, InterruptedException {
@DisplayName("Test No Append Label")
void testNoAppendLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 33
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -116,33 +90,17 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNoLabel() throws IOException, InterruptedException {
@DisplayName("Test No Label")
void testNoLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
// 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -159,33 +117,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testMultioutputRecord() throws IOException, InterruptedException {
@DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7), new DoubleWritable(2.45),
new IntWritable(9)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2), new IntWritable(3),
new IntWritable(4)));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
// 33,32.0,31.9
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33), new DoubleWritable(32.0),
new DoubleWritable(31.9)));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -200,51 +140,20 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size());
}
@Test
public void testMultilabelRecord() throws IOException, InterruptedException {
@DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -262,63 +171,24 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testZeroBasedIndexing() throws IOException, InterruptedException {
@DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO,
ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ZERO,
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(ZERO,
new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
correct.put(2, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO,
new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
// Zero-based indexing is default
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
// NOT STANDARD!
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 5);
@ -333,20 +203,19 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNextRecord() throws IOException, InterruptedException {
@DisplayName("Test Next Record")
void testNextRecord() throws IOException, InterruptedException {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
Record record = rr.nextRecord();
List<Writable> recordList = record.getRecord();
assertEquals(new DoubleWritable(1.0), recordList.get(1));
assertEquals(new DoubleWritable(3.0), recordList.get(5));
assertEquals(new DoubleWritable(4.0), recordList.get(7));
record = rr.nextRecord();
recordList = record.getRecord();
assertEquals(new DoubleWritable(0.1), recordList.get(0));
@ -354,82 +223,102 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(new DoubleWritable(80.0), recordList.get(7));
}
@Test(expected = NoSuchElementException.class)
public void testNoSuchElementException() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test No Such Element Exception")
void testNoSuchElementException() {
assertThrows(NoSuchElementException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
rr.next();
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void failedToSetNumFeaturesException() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Failed To Set Num Features Exception")
void failedToSetNumFeaturesException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Inconsistent Num Labels Exception")
void testInconsistentNumLabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Failed To Set Num Multiabels Exception")
void failedToSetNumMultiabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Feature Index Exceeds Num Features")
void testFeatureIndexExceedsNumFeatures() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void testInconsistentNumLabelsException() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test Label Index Exceeds Num Labels")
void testLabelIndexExceedsNumLabels() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void failedToSetNumMultilabelsException() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test Zero Index Feature Without Using Zero Indexing")
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
});
}
@Test(expected = IndexOutOfBoundsException.class)
public void testFeatureIndexExceedsNumFeatures() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testLabelIndexExceedsNumLabels() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
@Test
@DisplayName("Test Zero Index Label Without Using Zero Indexing")
void testZeroIndexLabelWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
});
}
}
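Pulling the configuration keys together: a minimal sketch of reading a single SVMLight line, say "7 2:1 4:2", under the settings these tests use most often (dataFile stands in for a real input; the expected output is derived by hand from the format):

SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration conf = new Configuration();
conf.setInt(SVMLightRecordReader.NUM_FEATURES, 4);
// standard SVMLight files use 1-based feature indices
conf.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(conf, new FileSplit(dataFile));
// "7 2:1 4:2" -> features [0, 1, 0, 2] (only indices 2 and 4 are set),
// then the label appended: [0, 1, 0, 2, 7]
List<Writable> row = rr.next();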

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.writer.impl;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@ -26,44 +25,42 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
public class CSVRecordWriterTest extends BaseND4JTest {
@Before
public void setUp() throws Exception {
@DisplayName("Csv Record Writer Test")
class CSVRecordWriterTest extends BaseND4JTest {
@BeforeEach
void setUp() throws Exception {
}
@Test
public void testWrite() throws Exception {
@DisplayName("Test Write")
void testWrite() throws Exception {
File tempFile = File.createTempFile("datavec", "writer");
tempFile.deleteOnExit();
FileSplit fileSplit = new FileSplit(tempFile);
CSVRecordWriter writer = new CSVRecordWriter();
writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
List<Writable> collection = new ArrayList<>();
collection.add(new Text("12"));
collection.add(new Text("13"));
collection.add(new Text("14"));
writer.write(collection);
CSVRecordReader reader = new CSVRecordReader(0);
reader.initialize(new FileSplit(tempFile));
int cnt = 0;
while (reader.hasNext()) {
List<Writable> line = new ArrayList<>(reader.next());
assertEquals(3, line.size());
assertEquals(12, line.get(0).toInt());
assertEquals(13, line.get(1).toInt());
assertEquals(14, line.get(2).toInt());
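The only lifecycle change in this writer test is @Before becoming @BeforeEach. For reference, the full JUnit 4 to JUnit 5 lifecycle mapping this commit leans on (standard JUnit, not specific to this diff):

// @Before      -> @BeforeEach
// @After       -> @AfterEach
// @BeforeClass -> @BeforeAll  (static by default, as in JUnit 4)
// @AfterClass  -> @AfterAll   (static by default, as in JUnit 4)
// @Ignore      -> @Disabled
@BeforeEach
void setUp() {
}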

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils;
@ -30,93 +29,90 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertEquals;
public class LibSvmRecordWriterTest extends BaseND4JTest {
@DisplayName("Lib Svm Record Writer Test")
class LibSvmRecordWriterTest extends BaseND4JTest {
@Test
public void testBasic() throws Exception {
@DisplayName("Test Basic")
void testBasic() throws Exception {
Configuration configWriter = new Configuration();
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testNoLabel() throws Exception {
@DisplayName("Test No Label")
void testNoLabel() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testMultioutputRecord() throws Exception {
@DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testMultilabelRecord() throws Exception {
@DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testZeroBasedIndexing() throws Exception {
@DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@ -127,10 +123,9 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
LibSvmRecordReader rr = new LibSvmRecordReader();
rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) {
@ -138,7 +133,6 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
writer.write(record);
}
}
Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX));
List<String> linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) {
@ -159,7 +153,8 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
}
@Test
public void testNDArrayWritables() throws Exception {
@DisplayName("Test ND Array Writables")
void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -167,35 +162,28 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13);
arr3.putScalar(1, 14);
arr3.putScalar(2, 15);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new IntWritable(4));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNDArrayWritablesMultilabel() throws Exception {
@DisplayName("Test ND Array Writables Multilabel")
void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -203,36 +191,29 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNDArrayWritablesZeroIndex() throws Exception {
@DisplayName("Test ND Array Writables Zero Index")
void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -240,99 +221,91 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
// NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
// NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNonIntegerButValidMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
new IntWritable(2),
new DoubleWritable(1.0));
@DisplayName("Test Non Integer But Valid Multilabel")
void testNonIntegerButValidMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
}
@Test(expected = NumberFormatException.class)
public void nonIntegerMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
new IntWritable(2),
new DoubleWritable(1.2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.write(record);
}
@Test
@DisplayName("Non Integer Multilabel")
void nonIntegerMultilabel() {
assertThrows(NumberFormatException.class, () -> {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
});
}
@Test(expected = NumberFormatException.class)
public void nonBinaryMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(0),
new IntWritable(1),
new IntWritable(2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN,0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN,1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL,true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.write(record);
}
@Test
@DisplayName("Non Binary Multilabel")
void nonBinaryMultilabel() {
assertThrows(NumberFormatException.class, () -> {
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
});
}
}
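
The @DisplayName annotations added throughout this commit are purely cosmetic: they change the name JUnit 5 reports for a class or method, without affecting test discovery or execution. A minimal sketch:

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Lib Svm Record Writer Test")  // reported instead of the class name
class DisplayNameSketch {

    @Test
    @DisplayName("Test Basic")              // reported instead of testBasic()
    void testBasic() {
        // test body is unchanged by the annotation
    }
}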

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils;
@ -27,93 +26,90 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.*;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertEquals;
public class SVMLightRecordWriterTest extends BaseND4JTest {
@DisplayName("Svm Light Record Writer Test")
class SVMLightRecordWriterTest extends BaseND4JTest {
@Test
public void testBasic() throws Exception {
@DisplayName("Test Basic")
void testBasic() throws Exception {
Configuration configWriter = new Configuration();
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testNoLabel() throws Exception {
@DisplayName("Test No Label")
void testNoLabel() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testMultioutputRecord() throws Exception {
@DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testMultilabelRecord() throws Exception {
@DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testZeroBasedIndexing() throws Exception {
@DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@ -124,10 +120,9 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
SVMLightRecordReader rr = new SVMLightRecordReader();
rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) {
@ -135,7 +130,6 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
writer.write(record);
}
}
Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX));
List<String> linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) {
@ -156,7 +150,8 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
}
@Test
public void testNDArrayWritables() throws Exception {
@DisplayName("Test ND Array Writables")
void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -164,35 +159,28 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13);
arr3.putScalar(1, 14);
arr3.putScalar(2, 15);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new IntWritable(4));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNDArrayWritablesMultilabel() throws Exception {
@DisplayName("Test ND Array Writables Multilabel")
void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -200,36 +188,29 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNDArrayWritablesZeroIndex() throws Exception {
@DisplayName("Test ND Array Writables Zero Index")
void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -237,99 +218,91 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
// NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
// NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNonIntegerButValidMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
new IntWritable(2),
new DoubleWritable(1.0));
@DisplayName("Test Non Integer But Valid Multilabel")
void testNonIntegerButValidMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
}
@Test(expected = NumberFormatException.class)
public void nonIntegerMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
new IntWritable(2),
new DoubleWritable(1.2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.write(record);
}
@Test
@DisplayName("Non Integer Multilabel")
void nonIntegerMultilabel() {
assertThrows(NumberFormatException.class, () -> {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
});
}
@Test(expected = NumberFormatException.class)
public void nonBinaryMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(0),
new IntWritable(1),
new IntWritable(2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.write(record);
}
@Test
@DisplayName("Non Binary Multilabel")
void nonBinaryMultilabel() {
assertThrows(NumberFormatException.class, () -> {
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
});
}
}

View File

@ -17,44 +17,43 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.split;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collection;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
* @author Ede Meijer
*/
public class TransformSplitTest extends BaseND4JTest {
@Test
public void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
@DisplayName("Transform Split Test")
class TransformSplitTest extends BaseND4JTest {
@Test
@DisplayName("Test Transform")
void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() {
@Override
public URI apply(URI uri) throws URISyntaxException {
return uri.normalize();
}
});
assertArrayEquals(new URI[] {new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv")}, SUT.locations());
assertArrayEquals(new URI[] { new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv") }, SUT.locations());
}
@Test
public void testSearchReplace() throws URISyntaxException {
@DisplayName("Test Search Replace")
void testSearchReplace() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv"));
InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv");
assertArrayEquals(new URI[] {new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv")},
SUT.locations());
assertArrayEquals(new URI[] { new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv") }, SUT.locations());
}
}
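
Since the tests now target Java 8+, the anonymous URITransform above could be collapsed to a method reference — a sketch, assuming URITransform is a single-abstract-method interface, as the anonymous class suggests:

@Test
@DisplayName("Test Transform With Method Reference")
void testTransformWithMethodReference() throws URISyntaxException {
    Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
    // URI::normalize stands in for the anonymous URITransform
    InputSplit split = new TransformSplit(new CollectionInputSplit(inputFiles), URI::normalize);
    assertArrayEquals(new URI[] { new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv") }, split.locations());
}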

View File

@ -17,32 +17,25 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.ops;
import com.tngtech.archunit.core.importer.ImportOption;
import com.tngtech.archunit.junit.AnalyzeClasses;
import com.tngtech.archunit.junit.ArchTest;
import com.tngtech.archunit.junit.ArchUnitRunner;
import com.tngtech.archunit.lang.ArchRule;
import com.tngtech.archunit.lang.extension.ArchUnitExtension;
import com.tngtech.archunit.lang.extension.ArchUnitExtensions;
import org.junit.runner.RunWith;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.Serializable;
import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(ArchUnitRunner.class)
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = {ImportOption.DoNotIncludeTests.class})
public class AggregableMultiOpArchTest extends BaseND4JTest {
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = { ImportOption.DoNotIncludeTests.class })
@DisplayName("Aggregable Multi Op Arch Test")
class AggregableMultiOpArchTest extends BaseND4JTest {
@ArchTest
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes()
.that().resideInAPackage("org.datavec.api.transform.ops")
.and().doNotHaveSimpleName("AggregatorImpls")
.and().doNotHaveSimpleName("IAggregableReduceOp")
.and().doNotHaveSimpleName("StringAggregatorImpls")
.and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1")
.should().implement(Serializable.class)
.because("All aggregate ops must be serializable.");
}
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes().that().resideInAPackage("org.datavec.api.transform.ops").and().doNotHaveSimpleName("AggregatorImpls").and().doNotHaveSimpleName("IAggregableReduceOp").and().doNotHaveSimpleName("StringAggregatorImpls").and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1").should().implement(Serializable.class).because("All aggregate ops must be serializable.");
}
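
The @RunWith(ArchUnitRunner.class) line disappears here because ArchUnit's JUnit 5 support uses its own test engine to discover @ArchTest members rather than a runner. A runner-less sketch of the shape this file now has — assuming, since the pom is not shown in this hunk, that the archunit-junit5 artifact is on the test classpath:

import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes;

import com.tngtech.archunit.junit.AnalyzeClasses;
import com.tngtech.archunit.junit.ArchTest;
import com.tngtech.archunit.lang.ArchRule;
import java.io.Serializable;

// No @RunWith needed: the ArchUnit JUnit 5 engine picks up @ArchTest fields itself.
@AnalyzeClasses(packages = "com.example.sketch")
class ArchUnitJUnit5Sketch {

    @ArchTest
    static final ArchRule EVERYTHING_SERIALIZABLE =
            classes().that().resideInAPackage("com.example.sketch")
                    .should().implement(Serializable.class)
                    .because("sketch of a serializability rule");
}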

View File

@ -17,52 +17,46 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertTrue;
public class AggregableMultiOpTest extends BaseND4JTest {
@DisplayName("Aggregable Multi Op Test")
class AggregableMultiOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@Test
public void testMulti() throws Exception {
@DisplayName("Test Multi")
void testMulti() throws Exception {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
assertTrue(multi.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) {
multi.accept(intList.get(i));
}
// mutability
assertTrue(as.get().toDouble() == 45D);
assertTrue(af.get().toInt() == 1);
List<Writable> res = multi.get();
assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1);
AggregatorImpls.AggregableFirst<Integer> rf = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> rs = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs));
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
List<Writable> revRes = reverse.get();
assertTrue(revRes.get(1).toDouble() == 45D);
assertTrue(revRes.get(0).toInt() == 9);
multi.combine(reverse);
List<Writable> combinedRes = multi.get();
assertTrue(combinedRes.get(1).toDouble() == 90D);

View File

@ -17,41 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.ops;
import org.junit.Rule;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class AggregatorImplsTest extends BaseND4JTest {
@DisplayName("Aggregator Impls Test")
class AggregatorImplsTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test
public void aggregableFirstTest() {
@DisplayName("Aggregable First Test")
void aggregableFirstTest() {
AggregatorImpls.AggregableFirst<Integer> first = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) {
first.accept(intList.get(i));
}
assertEquals(1, first.get().toInt());
AggregatorImpls.AggregableFirst<String> firstS = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < stringList.size(); i++) {
firstS.accept(stringList.get(i));
}
assertTrue(firstS.get().toString().equals("arakoa"));
AggregatorImpls.AggregableFirst<Integer> reverse = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -60,22 +58,19 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(1, first.get().toInt());
}
@Test
public void aggregableLastTest() {
@DisplayName("Aggregable Last Test")
void aggregableLastTest() {
AggregatorImpls.AggregableLast<Integer> last = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) {
last.accept(intList.get(i));
}
assertEquals(9, last.get().toInt());
AggregatorImpls.AggregableLast<String> lastS = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i));
}
assertTrue(lastS.get().toString().equals("acceptance"));
AggregatorImpls.AggregableLast<Integer> reverse = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -85,20 +80,18 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableCountTest() {
@DisplayName("Aggregable Count Test")
void aggregableCountTest() {
AggregatorImpls.AggregableCount<Integer> cnt = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) {
cnt.accept(intList.get(i));
}
assertEquals(9, cnt.get().toInt());
AggregatorImpls.AggregableCount<String> lastS = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i));
}
assertEquals(4, lastS.get().toInt());
AggregatorImpls.AggregableCount<Integer> reverse = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -108,14 +101,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableMaxTest() {
@DisplayName("Aggregable Max Test")
void aggregableMaxTest() {
AggregatorImpls.AggregableMax<Integer> mx = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i));
}
assertEquals(9, mx.get().toInt());
AggregatorImpls.AggregableMax<Integer> reverse = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -124,16 +116,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, mx.get().toInt());
}
@Test
public void aggregableRangeTest() {
@DisplayName("Aggregable Range Test")
void aggregableRangeTest() {
AggregatorImpls.AggregableRange<Integer> mx = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i));
}
assertEquals(8, mx.get().toInt());
AggregatorImpls.AggregableRange<Integer> reverse = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1) + 9);
@ -143,14 +133,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableMinTest() {
@DisplayName("Aggregable Min Test")
void aggregableMinTest() {
AggregatorImpls.AggregableMin<Integer> mn = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i));
}
assertEquals(1, mn.get().toInt());
AggregatorImpls.AggregableMin<Integer> reverse = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -160,14 +149,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableSumTest() {
@DisplayName("Aggregable Sum Test")
void aggregableSumTest() {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i));
}
assertEquals(45, sm.get().toInt());
AggregatorImpls.AggregableSum<Integer> reverse = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -176,17 +164,15 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(90, sm.get().toInt());
}
@Test
public void aggregableMeanTest() {
@DisplayName("Aggregable Mean Test")
void aggregableMeanTest() {
AggregatorImpls.AggregableMean<Integer> mn = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i));
}
assertEquals(9L, (long) mn.getCount());
assertEquals(5D, mn.get().toDouble(), 0.001);
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -197,80 +183,73 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableStdDevTest() {
@DisplayName("Aggregable Std Dev Test")
void aggregableStdDevTest() {
AggregatorImpls.AggregableStdDev<Integer> sd = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001);
AggregatorImpls.AggregableStdDev<Integer> reverse = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8787) < 0.0001);
assertTrue(Math.abs(sd.get().toDouble() - 1.8787) < 0.0001, "" + sd.get().toDouble());
}
@Test
public void aggregableVariance() {
@DisplayName("Aggregable Variance")
void aggregableVariance() {
AggregatorImpls.AggregableVariance<Integer> sd = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001);
AggregatorImpls.AggregableVariance<Integer> reverse = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 3.5294) < 0.0001);
assertTrue(Math.abs(sd.get().toDouble() - 3.5294) < 0.0001, "" + sd.get().toDouble());
}
@Test
public void aggregableUncorrectedStdDevTest() {
@DisplayName("Aggregable Uncorrected Std Dev Test")
void aggregableUncorrectedStdDevTest() {
AggregatorImpls.AggregableUncorrectedStdDev<Integer> sd = new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001);
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse =
new AggregatorImpls.AggregableUncorrectedStdDev<>();
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse = new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8257) < 0.0001);
assertTrue(Math.abs(sd.get().toDouble() - 1.8257) < 0.0001, "" + sd.get().toDouble());
}
@Test
public void aggregablePopulationVariance() {
@DisplayName("Aggregable Population Variance")
void aggregablePopulationVariance() {
AggregatorImpls.AggregablePopulationVariance<Integer> sd = new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001);
AggregatorImpls.AggregablePopulationVariance<Integer> reverse =
new AggregatorImpls.AggregablePopulationVariance<>();
AggregatorImpls.AggregablePopulationVariance<Integer> reverse = new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001);
assertTrue(Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001, "" + sd.get().toDouble());
}
@Test
public void aggregableCountUniqueTest() {
@DisplayName("Aggregable Count Unique Test")
void aggregableCountUniqueTest() {
// at this low range, it's linear counting
AggregatorImpls.AggregableCountUnique<Integer> cu = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) {
cu.accept(intList.get(i));
@ -278,7 +257,6 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, cu.get().toInt());
cu.accept(1);
assertEquals(9, cu.get().toInt());
AggregatorImpls.AggregableCountUnique<Integer> reverse = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -290,16 +268,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
@Rule
public final ExpectedException exception = ExpectedException.none();
@Test
public void incompatibleAggregatorTest() {
@DisplayName("Incompatible Aggregator Test")
void incompatibleAggregatorTest() {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i));
}
assertEquals(45, sm.get().toInt());
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -308,5 +284,4 @@ public class AggregatorImplsTest extends BaseND4JTest {
sm.combine(reverse);
assertEquals(45, sm.get().toInt());
}
}
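
Two migration details in this file are worth spelling out. First, the assertion message moves from the first argument in JUnit 4 to the last in JUnit 5, which is why the assertTrue calls above now end with the message. Second, the @Rule ExpectedException field appears to be a JUnit 4 leftover: the Jupiter engine does not run JUnit 4 rules, and the idiomatic replacement is assertThrows. A sketch of both points (names are illustrative):

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

class AssertionMigrationSketch {

    void messageMovesLast(double actual) {
        // JUnit 4: assertTrue("" + actual, Math.abs(actual) < 0.0001);
        // JUnit 5: condition first, message (or message supplier) last
        assertTrue(Math.abs(actual) < 0.0001, () -> "" + actual);
    }

    void ruleBecomesAssertThrows() {
        // Replacement for an ExpectedException rule: expect the throw inline
        assertThrows(UnsupportedOperationException.class, () -> {
            throw new UnsupportedOperationException("combining incompatible aggregators");
        });
    }
}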

View File

@ -17,77 +17,65 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertTrue;
public class DispatchOpTest extends BaseND4JTest {
@DisplayName("Dispatch Op Test")
class DispatchOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test
public void testDispatchSimple() {
@DisplayName("Test Dispatch Simple")
void testDispatchSimple() {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multiaf =
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
AggregableMultiOp<Integer> multias =
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
DispatchOp<Integer, Writable> parallel =
new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
AggregableMultiOp<Integer> multiaf = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
AggregableMultiOp<Integer> multias = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
assertTrue(multiaf.getOperations().size() == 1);
assertTrue(multias.getOperations().size() == 1);
assertTrue(parallel.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) {
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
}
List<Writable> res = parallel.get();
assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1);
}
@Test
public void testDispatchFlatMap() {
@DisplayName("Test Dispatch Flat Map")
void testDispatchFlatMap() {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
AggregatorImpls.AggregableLast<Integer> al = new AggregatorImpls.AggregableLast<>();
AggregatorImpls.AggregableMax<Integer> amax = new AggregatorImpls.AggregableMax<>();
AggregableMultiOp<Integer> otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(
Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
assertTrue(multi.getOperations().size() == 2);
assertTrue(otherMulti.getOperations().size() == 2);
assertTrue(parallel.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) {
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
}
List<Writable> res = parallel.get();
assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1);
assertTrue(res.get(3).toDouble() == 9);
assertTrue(res.get(2).toInt() == 9);
}
}

View File

@ -17,29 +17,29 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.transform.parse;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Parse Double Transform Test")
class ParseDoubleTransformTest extends BaseND4JTest {
public class ParseDoubleTransformTest extends BaseND4JTest {
@Test
public void testDoubleTransform() {
@DisplayName("Test Double Transform")
void testDoubleTransform() {
List<Writable> record = new ArrayList<>();
record.add(new Text("0.0"));
List<Writable> transformed = Arrays.<Writable>asList(new DoubleWritable(0.0));
assertEquals(transformed, new ParseDoubleTransform().map(record));
}
}

View File

@ -17,30 +17,31 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.util;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.BufferedReader;
import java.io.File;
import java.io.InputStream;
import java.io.InputStreamReader;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.AnyOf.anyOf;
import static org.hamcrest.core.IsEqual.equalTo;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class ClassPathResourceTest extends BaseND4JTest {
@DisplayName("Class Path Resource Test")
class ClassPathResourceTest extends BaseND4JTest {
private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows
// File sizes are reported slightly differently on Linux vs. Windows
private boolean isWindows = false;
@Before
public void setUp() throws Exception {
@BeforeEach
void setUp() throws Exception {
String osname = System.getProperty("os.name");
if (osname != null && osname.toLowerCase().contains("win")) {
isWindows = true;
@ -48,9 +49,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
}
@Test
public void testGetFile1() throws Exception {
@DisplayName("Test Get File 1")
void testGetFile1() throws Exception {
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
assertTrue(intFile.exists());
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
@ -60,9 +61,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
}
@Test
public void testGetFileSlash1() throws Exception {
@DisplayName("Test Get File Slash 1")
void testGetFileSlash1() throws Exception {
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
assertTrue(intFile.exists());
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
@ -72,11 +73,10 @@ public class ClassPathResourceTest extends BaseND4JTest {
}
@Test
public void testGetFileWithSpace1() throws Exception {
@DisplayName("Test Get File With Space 1")
void testGetFileWithSpace1() throws Exception {
File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile();
assertTrue(intFile.exists());
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else {
@ -85,16 +85,15 @@ public class ClassPathResourceTest extends BaseND4JTest {
}
@Test
public void testInputStream() throws Exception {
@DisplayName("Test Input Stream")
void testInputStream() throws Exception {
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
File intFile = resource.getFile();
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else {
assertEquals(60, intFile.length());
}
InputStream stream = resource.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = "";
@ -102,21 +101,19 @@ public class ClassPathResourceTest extends BaseND4JTest {
while ((line = reader.readLine()) != null) {
cnt++;
}
assertEquals(5, cnt);
}
@Test
public void testInputStreamSlash() throws Exception {
@DisplayName("Test Input Stream Slash")
void testInputStreamSlash() throws Exception {
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
File intFile = resource.getFile();
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else {
assertEquals(60, intFile.length());
}
InputStream stream = resource.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = "";
@@ -124,7 +121,6 @@ public class ClassPathResourceTest extends BaseND4JTest {
while ((line = reader.readLine()) != null) {
cnt++;
}
assertEquals(5, cnt);
}
}
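The file above illustrates the core lifecycle migration this commit applies everywhere: org.junit.Before becomes org.junit.jupiter.api.BeforeEach, and test classes and methods may drop the public modifier. A minimal sketch of the pattern, with class and method names that are hypothetical rather than taken from the repository:

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertTrue;

@DisplayName("Lifecycle Migration Example")
class LifecycleMigrationExample {

    private boolean isWindows = false;

    @BeforeEach // replaces JUnit 4's @Before; runs before every @Test method
    void setUp() {
        String osName = System.getProperty("os.name");
        isWindows = osName != null && osName.toLowerCase().contains("win");
    }

    @Test // org.junit.jupiter.api.Test; package-private visibility is allowed in JUnit 5
    void detectsSomeOs() {
        assertTrue(isWindows || !isWindows); // trivially true; placeholder assertion
    }
}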
View File
@@ -17,44 +17,41 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.util;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertArrayEquals;
public class TimeSeriesUtilsTest extends BaseND4JTest {
@DisplayName("Time Series Utils Test")
class TimeSeriesUtilsTest extends BaseND4JTest {
@Test
public void testTimeSeriesCreation() {
@DisplayName("Test Time Series Creation")
void testTimeSeriesCreation() {
List<List<List<Writable>>> test = new ArrayList<>();
List<List<Writable>> timeStep = new ArrayList<>();
for(int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) {
timeStep.add(getRecord(5));
}
test.add(timeStep);
INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst();
assertArrayEquals(new long[]{1,5,5},arr.shape());
}
assertArrayEquals(new long[] { 1, 5, 5 }, arr.shape());
}
private List<Writable> getRecord(int length) {
private List<Writable> getRecord(int length) {
List<Writable> ret = new ArrayList<>();
for(int i = 0; i < length; i++) {
for (int i = 0; i < length; i++) {
ret.add(new DoubleWritable(1.0));
}
return ret;
}
}
}
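Beyond the import swap from org.junit.Assert to org.junit.jupiter.api.Assertions shown above, one behavioral detail matters for assertions that carry a failure message: JUnit 4 took the message first, JUnit 5 takes it last. A hedged sketch (class name and message text are illustrative only):

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;

class AssertionMessageOrderExample {

    @Test
    void shapeHasExpectedDims() {
        long[] shape = { 1, 5, 5 };
        // JUnit 4: assertArrayEquals("unexpected shape", new long[] { 1, 5, 5 }, shape);
        // JUnit 5: the optional message moves to the last position
        assertArrayEquals(new long[] { 1, 5, 5 }, shape, "unexpected shape");
    }
}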
View File
@@ -17,52 +17,50 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.writable;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import java.util.List;
import java.util.TimeZone;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Record Converter Test")
class RecordConverterTest extends BaseND4JTest {
public class RecordConverterTest extends BaseND4JTest {
@Test
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT);
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT);
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT);
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)),
Nd4j.vstack(Lists.newArrayList(label1, label2)));
@DisplayName("To Records _ Pass In Classification Data Set _ Expect ND Array And Int Writables")
void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
INDArray feature1 = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray feature2 = Nd4j.create(new double[] { 11, .7, -1.3, 4 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label1 = Nd4j.create(new double[] { 0, 0, 1, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label2 = Nd4j.create(new double[] { 0, 1, 0, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), Nd4j.vstack(Lists.newArrayList(label1, label2)));
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
assertEquals(2, writableList.size());
testClassificationWritables(feature1, 2, writableList.get(0));
testClassificationWritables(feature2, 1, writableList.get(1));
}
@Test
public void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
INDArray feature = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
INDArray label = Nd4j.create(new double[]{.5, 2, 3, .5}, new long[]{1, 4}, DataType.FLOAT);
@DisplayName("To Records _ Pass In Regression Data Set _ Expect ND Array And Double Writables")
void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
INDArray feature = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label = Nd4j.create(new double[] { .5, 2, 3, .5 }, new long[] { 1, 4 }, DataType.FLOAT);
DataSet dataSet = new DataSet(feature, label);
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
List<Writable> results = writableList.get(0);
NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0);
assertEquals(1, writableList.size());
assertEquals(5, results.size());
assertEquals(feature, ndArrayWritable.get());
@@ -72,62 +70,39 @@ public class RecordConverterTest extends BaseND4JTest {
}
}
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex,
List<Writable> writables) {
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, List<Writable> writables) {
NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0);
IntWritable intWritable = (IntWritable) writables.get(1);
assertEquals(2, writables.size());
assertEquals(expectedFeatureVector, ndArrayWritable.get());
assertEquals(expectLabelIndex, intWritable.get());
}
@Test
public void testNDArrayWritableConcat() {
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1),
new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1, 3}, DataType.FLOAT)), new DoubleWritable(5),
new NDArrayWritable(Nd4j.create(new double[]{6, 7, 8}, new long[]{1, 3}, DataType.FLOAT)), new IntWritable(9),
new IntWritable(1));
INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT);
@DisplayName("Test ND Array Writable Concat")
void testNDArrayWritableConcat() {
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 }, new long[] { 1, 3 }, DataType.FLOAT)), new IntWritable(9), new IntWritable(1));
INDArray exp = Nd4j.create(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1 }, new long[] { 1, 10 }, DataType.FLOAT);
INDArray act = RecordConverter.toArray(DataType.FLOAT, l);
assertEquals(exp, act);
}
@Test
public void testNDArrayWritableConcatToMatrix(){
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5));
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[]{7, 8, 9}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(10));
INDArray exp = Nd4j.create(new double[][]{
{1,2,3,4,5},
{6,7,8,9,10}}).castTo(DataType.FLOAT);
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2));
@DisplayName("Test ND Array Writable Concat To Matrix")
void testNDArrayWritableConcatToMatrix() {
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5));
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(10));
INDArray exp = Nd4j.create(new double[][] { { 1, 2, 3, 4, 5 }, { 6, 7, 8, 9, 10 } }).castTo(DataType.FLOAT);
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1, l2));
assertEquals(exp, act);
}
@Test
public void testToRecordWithListOfObject(){
final List<Object> list = Arrays.asList((Object)3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
final Schema schema = new Schema.Builder()
.addColumnInteger("a")
.addColumnFloat("b")
.addColumnString("c")
.addColumnCategorical("d", "Bar", "Baz")
.addColumnDouble("e")
.addColumnFloat("f")
.addColumnLong("g")
.addColumnInteger("h")
.addColumnTime("i", TimeZone.getDefault())
.build();
@DisplayName("Test To Record With List Of Object")
void testToRecordWithListOfObject() {
final List<Object> list = Arrays.asList((Object) 3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
final Schema schema = new Schema.Builder().addColumnInteger("a").addColumnFloat("b").addColumnString("c").addColumnCategorical("d", "Bar", "Baz").addColumnDouble("e").addColumnFloat("f").addColumnLong("g").addColumnInteger("h").addColumnTime("i", TimeZone.getDefault()).build();
final List<Writable> record = RecordConverter.toRecord(schema, list);
assertEquals(record.get(0).toInt(), 3);
assertEquals(record.get(1).toFloat(), 7f, 1e-6);
assertEquals(record.get(2).toString(), "Foo");
@@ -137,7 +112,5 @@ public class RecordConverterTest extends BaseND4JTest {
assertEquals(record.get(6).toLong(), 3L);
assertEquals(record.get(7).toInt(), 7);
assertEquals(record.get(8).toLong(), 0);
}
}
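As in the file above, every migrated class and test method gains a @DisplayName generated from its camel-case or underscore name, and Jupiter reports that string verbatim. A small sketch of the convention (the class and values are hypothetical):

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;

@DisplayName("Display Name Example")
class DisplayNameExample {

    @Test
    @DisplayName("To Records _ Converts Without Loss") // shown as-is by the test runner
    void toRecords_ConvertsWithoutLoss() {
        assertEquals(4, 2 + 2);
    }
}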
View File
@@ -17,38 +17,38 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.junit.jupiter.api.DisplayName;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class WritableTest extends BaseND4JTest {
@DisplayName("Writable Test")
class WritableTest extends BaseND4JTest {
@Test
public void testWritableEqualityReflexive() {
@DisplayName("Test Writable Equality Reflexive")
void testWritableEqualityReflexive() {
assertEquals(new IntWritable(1), new IntWritable(1));
assertEquals(new LongWritable(1), new LongWritable(1));
assertEquals(new DoubleWritable(1), new DoubleWritable(1));
assertEquals(new FloatWritable(1), new FloatWritable(1));
assertEquals(new Text("Hello"), new Text("Hello"));
assertEquals(new BytesWritable("Hello".getBytes()),new BytesWritable("Hello".getBytes()));
INDArray ndArray = Nd4j.rand(new int[]{1, 100});
assertEquals(new BytesWritable("Hello".getBytes()), new BytesWritable("Hello".getBytes()));
INDArray ndArray = Nd4j.rand(new int[] { 1, 100 });
assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray));
assertEquals(new NullWritable(), new NullWritable());
assertEquals(new BooleanWritable(true), new BooleanWritable(true));
@@ -56,9 +56,9 @@ public class WritableTest extends BaseND4JTest {
assertEquals(new ByteWritable(b), new ByteWritable(b));
}
@Test
public void testBytesWritableIndexing() {
@DisplayName("Test Bytes Writable Indexing")
void testBytesWritableIndexing() {
byte[] doubleWrite = new byte[16];
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
Buffer buffer = (Buffer) wrapped;
@@ -66,53 +66,51 @@ public class WritableTest extends BaseND4JTest {
wrapped.putDouble(2.0);
buffer.rewind();
BytesWritable byteWritable = new BytesWritable(doubleWrite);
assertEquals(2,byteWritable.getDouble(1),1e-1);
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2});
assertEquals(2, byteWritable.getDouble(1), 1e-1);
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] { 1, 2 });
double[] d1 = dataBuffer.asDouble();
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE,8).asDouble();
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE, 8).asDouble();
assertArrayEquals(d1, d2, 0.0);
}
@Test
public void testByteWritable() {
@DisplayName("Test Byte Writable")
void testByteWritable() {
byte b = 0xfffffffe;
assertEquals(new IntWritable(-2), new ByteWritable(b));
assertEquals(new LongWritable(-2), new ByteWritable(b));
assertEquals(new ByteWritable(b), new IntWritable(-2));
assertEquals(new ByteWritable(b), new LongWritable(-2));
// those would cast to the same Int
byte minus126 = 0xffffff82;
assertNotEquals(new ByteWritable(minus126), new IntWritable(130));
}
@Test
public void testIntLongWritable() {
@DisplayName("Test Int Long Writable")
void testIntLongWritable() {
assertEquals(new IntWritable(1), new LongWritable(1l));
assertEquals(new LongWritable(2l), new IntWritable(2));
long l = 1L << 34;
// those would cast to the same Int
assertNotEquals(new LongWritable(l), new IntWritable(4));
}
@Test
public void testDoubleFloatWritable() {
@DisplayName("Test Double Float Writable")
void testDoubleFloatWritable() {
assertEquals(new DoubleWritable(1d), new FloatWritable(1f));
assertEquals(new FloatWritable(2f), new DoubleWritable(2d));
// we defer to Java equality for Floats
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f));
// same idea as above
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float)1.1d));
assertNotEquals(new DoubleWritable((double)Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float) 1.1d));
assertNotEquals(new DoubleWritable((double) Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
}
@Test
public void testFuzzies() {
@DisplayName("Test Fuzzies")
void testFuzzies() {
assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d));
assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d));
byte b = 0xfffffffe;
@@ -122,62 +120,57 @@ public class WritableTest extends BaseND4JTest {
assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d));
}
@Test
public void testNDArrayRecordBatch(){
@DisplayName("Test ND Array Record Batch")
void testNDArrayRecordBatch() {
Nd4j.getRandom().setSeed(12345);
List<List<INDArray>> orig = new ArrayList<>(); //Outer list over writables/columns, inner list over examples
for( int i=0; i<3; i++ ){
// Outer list over writables/columns, inner list over examples
List<List<INDArray>> orig = new ArrayList<>();
for (int i = 0; i < 3; i++) {
orig.add(new ArrayList<INDArray>());
}
for( int i=0; i<5; i++ ){
orig.get(0).add(Nd4j.rand(1,10));
orig.get(1).add(Nd4j.rand(new int[]{1,5,6}));
orig.get(2).add(Nd4j.rand(new int[]{1,3,4,5}));
for (int i = 0; i < 5; i++) {
orig.get(0).add(Nd4j.rand(1, 10));
orig.get(1).add(Nd4j.rand(new int[] { 1, 5, 6 }));
orig.get(2).add(Nd4j.rand(new int[] { 1, 3, 4, 5 }));
}
List<List<INDArray>> origByExample = new ArrayList<>(); //Outer list over examples, inner list over writables
for( int i=0; i<5; i++ ){
// Outer list over examples, inner list over writables
List<List<INDArray>> origByExample = new ArrayList<>();
for (int i = 0; i < 5; i++) {
origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i)));
}
List<INDArray> batched = new ArrayList<>();
for(List<INDArray> l : orig){
for (List<INDArray> l : orig) {
batched.add(Nd4j.concat(0, l.toArray(new INDArray[5])));
}
NDArrayRecordBatch batch = new NDArrayRecordBatch(batched);
assertEquals(5, batch.size());
for( int i=0; i<5; i++ ){
for (int i = 0; i < 5; i++) {
List<Writable> act = batch.get(i);
List<INDArray> unboxed = new ArrayList<>();
for(Writable w : act){
unboxed.add(((NDArrayWritable)w).get());
for (Writable w : act) {
unboxed.add(((NDArrayWritable) w).get());
}
List<INDArray> exp = origByExample.get(i);
assertEquals(exp.size(), unboxed.size());
for( int j=0; j<exp.size(); j++ ){
for (int j = 0; j < exp.size(); j++) {
assertEquals(exp.get(j), unboxed.get(j));
}
}
Iterator<List<Writable>> iter = batch.iterator();
int count = 0;
while(iter.hasNext()){
while (iter.hasNext()) {
List<Writable> next = iter.next();
List<INDArray> unboxed = new ArrayList<>();
for(Writable w : next){
unboxed.add(((NDArrayWritable)w).get());
for (Writable w : next) {
unboxed.add(((NDArrayWritable) w).get());
}
List<INDArray> exp = origByExample.get(count++);
assertEquals(exp.size(), unboxed.size());
for( int j=0; j<exp.size(); j++ ){
for (int j = 0; j < exp.size(); j++) {
assertEquals(exp.get(j), unboxed.get(j));
}
}
assertEquals(5, count);
}
}
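The fuzzy comparisons above rely on delta overloads, which keep the same positional form in JUnit 5 as in JUnit 4 (expected, actual, delta); they migrate with only the import change. A minimal sketch, assuming a hypothetical test class:

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

class DeltaAssertionExample {

    @Test
    void deltasSurviveTheMigrationUnchanged() {
        assertEquals(2.0, 1.95, 1e-1);                                        // tolerance of 0.1
        assertArrayEquals(new double[] { 1, 2 }, new double[] { 1, 2 }, 0.0); // exact match
    }
}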
View File
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
@@ -43,460 +42,398 @@ import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.*;
import static java.nio.channels.Channels.newChannel;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
public class ArrowConverterTest extends BaseND4JTest {
@DisplayName("Arrow Converter Test")
class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testToArrayFromINDArray() {
@DisplayName("Test To Array From IND Array")
void testToArrayFromINDArray() {
Schema.Builder schemaBuilder = new Schema.Builder();
schemaBuilder.addColumnNDArray("outputArray",new long[]{1,4});
schemaBuilder.addColumnNDArray("outputArray", new long[] { 1, 4 });
Schema schema = schemaBuilder.build();
int numRows = 4;
List<List<Writable>> ret = new ArrayList<>(numRows);
for(int i = 0; i < numRows; i++) {
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1,4,4).reshape(1, 4))));
for (int i = 0; i < numRows; i++) {
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4))));
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret);
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors,schema);
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors, schema);
INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch);
assertArrayEquals(new long[]{4,4},array.shape());
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1,4,4),4).reshape(4,4);
assertEquals(assertion,array);
assertArrayEquals(new long[] { 4, 4 }, array.shape());
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1, 4, 4), 4).reshape(4, 4);
assertEquals(assertion, array);
}
@Test
public void testArrowColumnINDArray() {
@DisplayName("Test Arrow Column IND Array")
void testArrowColumnINDArray() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
int numCols = 2;
INDArray arr = Nd4j.linspace(1,4,4);
for(int i = 0; i < numCols; i++) {
schema.addColumnNDArray(String.valueOf(i),new long[]{1,4});
INDArray arr = Nd4j.linspace(1, 4, 4);
for (int i = 0; i < numCols; i++) {
schema.addColumnNDArray(String.valueOf(i), new long[] { 1, 4 });
single.add(String.valueOf(i));
}
Schema buildSchema = schema.build();
List<List<Writable>> list = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
for(int i = 0 ; i < numCols; i++) {
for (int i = 0; i < numCols; i++) {
firstRow.add(new NDArrayWritable(arr));
}
list.add(firstRow);
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list);
assertEquals(numCols,fieldVectors.size());
assertEquals(1,fieldVectors.get(0).getValueCount());
assertEquals(numCols, fieldVectors.size());
assertEquals(1, fieldVectors.get(0).getValueCount());
assertFalse(fieldVectors.get(0).isNull(0));
ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema);
assertEquals(1,arrowWritableRecordBatch.size());
assertEquals(1, arrowWritableRecordBatch.size());
Writable writable = arrowWritableRecordBatch.get(0).get(0);
assertTrue(writable instanceof NDArrayWritable);
NDArrayWritable ndArrayWritable = (NDArrayWritable) writable;
assertEquals(arr,ndArrayWritable.get());
assertEquals(arr, ndArrayWritable.get());
Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray);
NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1;
System.out.println(ndArrayWritablewritable1.get());
}
@Test
public void testArrowColumnString() {
@DisplayName("Test Arrow Column String")
void testArrowColumnString() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) {
for (int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i));
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single);
List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
List<List<Writable>> assertion = new ArrayList<>();
assertion.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)));
assertEquals(assertion,records);
assertion.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)));
assertEquals(assertion, records);
List<List<String>> batch = new ArrayList<>();
for(int i = 0; i < 2; i++) {
batch.add(Arrays.asList(String.valueOf(i),String.valueOf(i)));
for (int i = 0; i < 2; i++) {
batch.add(Arrays.asList(String.valueOf(i), String.valueOf(i)));
}
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
List<List<Writable>> assertionBatch = new ArrayList<>();
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(0)));
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1),new IntWritable(1)));
assertEquals(assertionBatch,batchRecords);
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0)));
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1)));
assertEquals(assertionBatch, batchRecords);
}
@Test
public void testArrowBatchSetTime() {
@DisplayName("Test Arrow Batch Set Time")
void testArrowBatchSetTime() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) {
schema.addColumnTime(String.valueOf(i),TimeZone.getDefault());
for (int i = 0; i < 2; i++) {
schema.addColumnTime(String.valueOf(i), TimeZone.getDefault());
single.add(String.valueOf(i));
}
List<List<Writable>> input = Arrays.asList(
Arrays.<Writable>asList(new LongWritable(0),new LongWritable(1)),
Arrays.<Writable>asList(new LongWritable(2),new LongWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4),new LongWritable(5)));
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest);
assertEquals(assertion, recordTest);
}
@Test
public void testArrowBatchSet() {
@DisplayName("Test Arrow Batch Set")
void testArrowBatchSet() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) {
for (int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i));
}
List<List<Writable>> input = Arrays.asList(
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)),
Arrays.<Writable>asList(new IntWritable(2),new IntWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5)));
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest);
assertEquals(assertion, recordTest);
}
@Test
public void testArrowColumnsStringTimeSeries() {
@DisplayName("Test Arrow Columns String Time Series")
void testArrowColumnsStringTimeSeries() {
Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) {
for (int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i));
}
for(int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) {
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr);
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
assertEquals(3,fieldVectors.size());
assertEquals(5,fieldVectors.get(0).getValueCount());
assertEquals(3, fieldVectors.size());
assertEquals(5, fieldVectors.get(0).getValueCount());
INDArray exp = Nd4j.create(5, 3);
for( int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) {
exp.getRow(i).assign(i);
}
//Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
// Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
INDArray arr = ArrowConverter.toArray(wri);
assertArrayEquals(new long[] {5,3}, arr.shape());
assertArrayEquals(new long[] { 5, 3 }, arr.shape());
assertEquals(exp, arr);
}
@Test
public void testConvertVector() {
@DisplayName("Test Convert Vector")
void testConvertVector() {
Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) {
for (int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i));
}
for(int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) {
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr);
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0),schema.build().getType(0));
assertEquals(5,arr.length());
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0), schema.build().getType(0));
assertEquals(5, arr.length());
}
@Test
public void testCreateNDArray() throws Exception {
@DisplayName("Test Create ND Array")
void testCreateNDArray() throws Exception {
val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
File f = testDir.newFolder();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
File f = testDir.toFile();
File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrow");
FileOutputStream outputStream = new FileOutputStream(tmpFile);
tmpFile.deleteOnExit();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),outputStream);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), outputStream);
outputStream.flush();
outputStream.close();
Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile);
assertEquals(recordsToWrite.getFirst(),schemaArrowWritableRecordBatchPair.getFirst());
assertEquals(recordsToWrite.getRight(),schemaArrowWritableRecordBatchPair.getRight().toArrayList());
assertEquals(recordsToWrite.getFirst(), schemaArrowWritableRecordBatchPair.getFirst());
assertEquals(recordsToWrite.getRight(), schemaArrowWritableRecordBatchPair.getRight().toArrayList());
byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read);
//send file
File tmp = tmpDataFile(recordsToWrite);
assertEquals(recordsToWrite, read);
// send file
File tmp = tmpDataFile(recordsToWrite);
ArrowRecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp));
recordReader.next();
ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch();
INDArray arr2 = ArrowConverter.toArray(currentBatch);
assertEquals(2,arr2.rows());
assertEquals(2,arr2.columns());
}
@Test
public void testConvertToArrowVectors() {
INDArray matrix = Nd4j.linspace(1,4,4).reshape(2,2);
val vectors = ArrowConverter.convertToArrowVector(matrix,Arrays.asList("test","test2"), ColumnType.Double,bufferAllocator);
assertEquals(matrix.rows(),vectors.size());
INDArray vector = Nd4j.linspace(1,4,4);
val vectors2 = ArrowConverter.convertToArrowVector(vector,Arrays.asList("test"), ColumnType.Double,bufferAllocator);
assertEquals(1,vectors2.size());
assertEquals(matrix.length(),vectors2.get(0).getValueCount());
assertEquals(2, arr2.rows());
assertEquals(2, arr2.columns());
}
@Test
public void testSchemaConversionBasic() {
@DisplayName("Test Convert To Arrow Vectors")
void testConvertToArrowVectors() {
INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
val vectors = ArrowConverter.convertToArrowVector(matrix, Arrays.asList("test", "test2"), ColumnType.Double, bufferAllocator);
assertEquals(matrix.rows(), vectors.size());
INDArray vector = Nd4j.linspace(1, 4, 4);
val vectors2 = ArrowConverter.convertToArrowVector(vector, Arrays.asList("test"), ColumnType.Double, bufferAllocator);
assertEquals(1, vectors2.size());
assertEquals(matrix.length(), vectors2.get(0).getValueCount());
}
@Test
@DisplayName("Test Schema Conversion Basic")
void testSchemaConversionBasic() {
Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) {
for (int i = 0; i < 2; i++) {
schemaBuilder.addColumnDouble("test-" + i);
schemaBuilder.addColumnInteger("testi-" + i);
schemaBuilder.addColumnLong("testl-" + i);
schemaBuilder.addColumnFloat("testf-" + i);
}
Schema schema = schemaBuilder.build();
val schema2 = ArrowConverter.toArrowSchema(schema);
assertEquals(8,schema2.getFields().size());
assertEquals(8, schema2.getFields().size());
val convertedSchema = ArrowConverter.toDatavecSchema(schema2);
assertEquals(schema,convertedSchema);
assertEquals(schema, convertedSchema);
}
@Test
public void testReadSchemaAndRecordsFromByteArray() throws Exception {
@DisplayName("Test Read Schema And Records From Byte Array")
void testReadSchemaAndRecordsFromByteArray() throws Exception {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
int valueCount = 3;
List<Field> fields = new ArrayList<>();
fields.add(ArrowConverter.field("field1",new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
fields.add(ArrowConverter.field("field1", new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
fields.add(ArrowConverter.intField("field2"));
List<FieldVector> fieldVectors = new ArrayList<>();
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field1",new float[] {1,2,3}));
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field2",new int[] {1,2,3}));
fieldVectors.add(ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 }));
fieldVectors.add(ArrowConverter.vectorFor(allocator, "field2", new int[] { 1, 2, 3 }));
org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields);
VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount);
VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
vectorUnloader.getRecordBatch();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
try (ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1, null, newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch();
} catch (IOException e) {
log.error("",e);
log.error("", e);
}
byte[] arr = byteArrayOutputStream.toByteArray();
val arr2 = ArrowConverter.readFromBytes(arr);
assertEquals(2,arr2.getFirst().numColumns());
assertEquals(3,arr2.getRight().size());
val arrowCols = ArrowConverter.toArrowColumns(allocator,arr2.getFirst(),arr2.getRight());
assertEquals(2,arrowCols.size());
assertEquals(valueCount,arrowCols.get(0).getValueCount());
assertEquals(2, arr2.getFirst().numColumns());
assertEquals(3, arr2.getRight().size());
val arrowCols = ArrowConverter.toArrowColumns(allocator, arr2.getFirst(), arr2.getRight());
assertEquals(2, arrowCols.size());
assertEquals(valueCount, arrowCols.get(0).getValueCount());
}
@Test
public void testVectorForEdgeCases() {
@DisplayName("Test Vector For Edge Cases")
void testVectorForEdgeCases() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{Float.MIN_VALUE,Float.MAX_VALUE});
assertEquals(Float.MIN_VALUE,vector.get(0),1e-2);
assertEquals(Float.MAX_VALUE,vector.get(1),1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{Integer.MIN_VALUE,Integer.MAX_VALUE});
assertEquals(Integer.MIN_VALUE,vectorInt.get(0),1e-2);
assertEquals(Integer.MAX_VALUE,vectorInt.get(1),1e-2);
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { Float.MIN_VALUE, Float.MAX_VALUE });
assertEquals(Float.MIN_VALUE, vector.get(0), 1e-2);
assertEquals(Float.MAX_VALUE, vector.get(1), 1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { Integer.MIN_VALUE, Integer.MAX_VALUE });
assertEquals(Integer.MIN_VALUE, vectorInt.get(0), 1e-2);
assertEquals(Integer.MAX_VALUE, vectorInt.get(1), 1e-2);
}
@Test
public void testVectorFor() {
@DisplayName("Test Vector For")
void testVectorFor() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{1,2,3});
assertEquals(3,vector.getValueCount());
assertEquals(1,vector.get(0),1e-2);
assertEquals(2,vector.get(1),1e-2);
assertEquals(3,vector.get(2),1e-2);
val vectorLong = ArrowConverter.vectorFor(allocator,"field1",new long[]{1,2,3});
assertEquals(3,vectorLong.getValueCount());
assertEquals(1,vectorLong.get(0),1e-2);
assertEquals(2,vectorLong.get(1),1e-2);
assertEquals(3,vectorLong.get(2),1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{1,2,3});
assertEquals(3,vectorInt.getValueCount());
assertEquals(1,vectorInt.get(0),1e-2);
assertEquals(2,vectorInt.get(1),1e-2);
assertEquals(3,vectorInt.get(2),1e-2);
val vectorDouble = ArrowConverter.vectorFor(allocator,"field1",new double[]{1,2,3});
assertEquals(3,vectorDouble.getValueCount());
assertEquals(1,vectorDouble.get(0),1e-2);
assertEquals(2,vectorDouble.get(1),1e-2);
assertEquals(3,vectorDouble.get(2),1e-2);
val vectorBool = ArrowConverter.vectorFor(allocator,"field1",new boolean[]{true,true,false});
assertEquals(3,vectorBool.getValueCount());
assertEquals(1,vectorBool.get(0),1e-2);
assertEquals(1,vectorBool.get(1),1e-2);
assertEquals(0,vectorBool.get(2),1e-2);
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 });
assertEquals(3, vector.getValueCount());
assertEquals(1, vector.get(0), 1e-2);
assertEquals(2, vector.get(1), 1e-2);
assertEquals(3, vector.get(2), 1e-2);
val vectorLong = ArrowConverter.vectorFor(allocator, "field1", new long[] { 1, 2, 3 });
assertEquals(3, vectorLong.getValueCount());
assertEquals(1, vectorLong.get(0), 1e-2);
assertEquals(2, vectorLong.get(1), 1e-2);
assertEquals(3, vectorLong.get(2), 1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { 1, 2, 3 });
assertEquals(3, vectorInt.getValueCount());
assertEquals(1, vectorInt.get(0), 1e-2);
assertEquals(2, vectorInt.get(1), 1e-2);
assertEquals(3, vectorInt.get(2), 1e-2);
val vectorDouble = ArrowConverter.vectorFor(allocator, "field1", new double[] { 1, 2, 3 });
assertEquals(3, vectorDouble.getValueCount());
assertEquals(1, vectorDouble.get(0), 1e-2);
assertEquals(2, vectorDouble.get(1), 1e-2);
assertEquals(3, vectorDouble.get(2), 1e-2);
val vectorBool = ArrowConverter.vectorFor(allocator, "field1", new boolean[] { true, true, false });
assertEquals(3, vectorBool.getValueCount());
assertEquals(1, vectorBool.get(0), 1e-2);
assertEquals(1, vectorBool.get(1), 1e-2);
assertEquals(0, vectorBool.get(2), 1e-2);
}
@Test
public void testRecordReaderAndWriteFile() throws Exception {
@DisplayName("Test Record Reader And Write File")
void testRecordReaderAndWriteFile() throws Exception {
val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read);
//send file
File tmp = tmpDataFile(recordsToWrite);
assertEquals(recordsToWrite, read);
// send file
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp));
List<Writable> record = recordReader.next();
assertEquals(2,record.size());
assertEquals(2, record.size());
}
@Test
public void testRecordReaderMetaDataList() throws Exception {
@DisplayName("Test Record Reader Meta Data List")
void testRecordReaderMetaDataList() throws Exception {
val recordsToWrite = recordToWrite();
//send file
File tmp = tmpDataFile(recordsToWrite);
// send file
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
recordReader.loadFromMetaData(Arrays.<RecordMetaData>asList(recordMetaDataIndex));
Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size());
assertEquals(2, record.getRecord().size());
}
@Test
public void testDates() {
@DisplayName("Test Dates")
void testDates() {
Date now = new Date();
BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[]{now});
assertEquals(now.getTime(),timeStampMilliVector.get(0));
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[] { now });
assertEquals(now.getTime(), timeStampMilliVector.get(0));
}
@Test
public void testRecordReaderMetaData() throws Exception {
@DisplayName("Test Record Reader Meta Data")
void testRecordReaderMetaData() throws Exception {
val recordsToWrite = recordToWrite();
//send file
File tmp = tmpDataFile(recordsToWrite);
// send file
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
recordReader.loadFromMetaData(recordMetaDataIndex);
Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size());
assertEquals(2, record.getRecord().size());
}
private File tmpDataFile(Pair<Schema,List<List<Writable>>> recordsToWrite) throws IOException {
File f = testDir.newFolder();
//send file
File tmp = new File(f,"tmp-file-" + UUID.randomUUID().toString());
private File tmpDataFile(Pair<Schema, List<List<Writable>>> recordsToWrite) throws IOException {
File f = testDir.toFile();
// send file
File tmp = new File(f, "tmp-file-" + UUID.randomUUID().toString());
tmp.mkdirs();
File tmpFile = new File(tmp,"data.arrow");
File tmpFile = new File(tmp, "data.arrow");
tmpFile.deleteOnExit();
FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),bufferedOutputStream);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), bufferedOutputStream);
bufferedOutputStream.flush();
bufferedOutputStream.close();
return tmp;
}
private Pair<Schema,List<List<Writable>>> recordToWrite() {
private Pair<Schema, List<List<Writable>>> recordToWrite() {
List<List<Writable>> records = new ArrayList<>();
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) {
for (int i = 0; i < 2; i++) {
schemaBuilder.addColumnFloat("col-" + i);
}
return Pair.of(schemaBuilder.build(),records);
return Pair.of(schemaBuilder.build(), records);
}
}
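The heaviest change in the file above is the rule migration: the JUnit 4 @Rule/TemporaryFolder pair becomes a @TempDir-annotated java.nio.file.Path, and testDir.newFolder() calls become testDir.toFile(). A minimal sketch of the pattern, with a hypothetical class name:

import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirFieldExample {

    // JUnit 4 equivalent:
    // @Rule
    // public TemporaryFolder testDir = new TemporaryFolder();

    @TempDir
    Path testDir; // Jupiter creates and cleans up a fresh directory for each test

    @Test
    void writesScratchFiles() {
        File f = testDir.toFile(); // was: testDir.newFolder()
        // ... create temporary files under f
    }
}

One subtlety of this swap: newFolder() returned a new subdirectory on every call, while the injected Path is a single directory per test, so tests that need several sibling folders now inject several parameters, as later files in this commit do.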
View File
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.arrow;
import lombok.val;
@@ -34,132 +33,98 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowRecordWriter;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.primitives.Triple;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
public class RecordMapperTest extends BaseND4JTest {
@DisplayName("Record Mapper Test")
class RecordMapperTest extends BaseND4JTest {
@Test
public void testMultiWrite() throws Exception {
@DisplayName("Test Multi Write")
void testMultiWrite() throws Exception {
val recordsPair = records();
Path p = Files.createTempFile("arrowwritetest", ".arrow");
FileUtils.write(p.toFile(),recordsPair.getFirst());
FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
int numReaders = 2;
RecordReader[] readers = new RecordReader[numReaders];
InputSplit[] splits = new InputSplit[numReaders];
for(int i = 0; i < readers.length; i++) {
for (int i = 0; i < readers.length; i++) {
FileSplit split = new FileSplit(p.toFile());
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
readers[i] = arrowRecordReader;
splits[i] = split;
}
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
FileSplit split = new FileSplit(p.toFile());
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner());
arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
arrowRecordWriter.writeBatch(recordsPair.getRight());
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
FileUtils.write(p2.toFile(),recordsPair.getFirst());
FileUtils.write(p2.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
FileSplit outputCsv = new FileSplit(p2.toFile());
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
.outputUrl(outputCsv)
.partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers)
.splitPerReader(splits)
.recordWriter(csvRecordWriter)
.build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers).splitPerReader(splits).recordWriter(csvRecordWriter).build();
mapper.copy();
}
@Test
public void testCopyFromArrowToCsv() throws Exception {
@DisplayName("Test Copy From Arrow To Csv")
void testCopyFromArrowToCsv() throws Exception {
val recordsPair = records();
Path p = Files.createTempFile("arrowwritetest", ".arrow");
FileUtils.write(p.toFile(),recordsPair.getFirst());
FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
FileSplit split = new FileSplit(p.toFile());
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner());
arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
arrowRecordWriter.writeBatch(recordsPair.getRight());
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
arrowRecordReader.initialize(split);
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
FileUtils.write(p2.toFile(),recordsPair.getFirst());
FileUtils.write(p2.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
FileSplit outputCsv = new FileSplit(p2.toFile());
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
.outputUrl(outputCsv)
.partitioner(new NumberOfRecordsPartitioner())
.recordReader(arrowRecordReader).recordWriter(csvRecordWriter)
.build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).recordReader(arrowRecordReader).recordWriter(csvRecordWriter).build();
mapper.copy();
CSVRecordReader recordReader = new CSVRecordReader();
recordReader.initialize(outputCsv);
List<List<Writable>> loadedCSvRecords = recordReader.next(10);
assertEquals(10,loadedCSvRecords.size());
assertEquals(10, loadedCSvRecords.size());
}
@Test
public void testCopyFromCsvToArrow() throws Exception {
@DisplayName("Test Copy From Csv To Arrow")
void testCopyFromCsvToArrow() throws Exception {
val recordsPair = records();
Path p = Files.createTempFile("csvwritetest", ".csv");
FileUtils.write(p.toFile(),recordsPair.getFirst());
FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
CSVRecordReader recordReader = new CSVRecordReader();
FileSplit fileSplit = new FileSplit(p.toFile());
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
File outputFile = Files.createTempFile("outputarrow","arrow").toFile();
File outputFile = Files.createTempFile("outputarrow", "arrow").toFile();
FileSplit outputFileSplit = new FileSplit(outputFile);
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit)
.outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner())
.recordReader(recordReader).recordWriter(arrowRecordWriter)
.build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit).outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner()).recordReader(recordReader).recordWriter(arrowRecordWriter).build();
mapper.copy();
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
arrowRecordReader.initialize(outputFileSplit);
List<List<Writable>> next = arrowRecordReader.next(10);
System.out.println(next);
assertEquals(10,next.size());
assertEquals(10, next.size());
}
private Triple<String,Schema,List<List<Writable>>> records() {
private Triple<String, Schema, List<List<Writable>>> records() {
List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder();
int numColumns = 3;
@@ -176,15 +141,10 @@ public class RecordMapperTest extends BaseND4JTest {
}
list.add(temp);
}
Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < numColumns; i++) {
for (int i = 0; i < numColumns; i++) {
schemaBuilder.addColumnInteger(String.valueOf(i));
}
return Triple.of(sb.toString(),schemaBuilder.build(),list);
return Triple.of(sb.toString(), schemaBuilder.build(), list);
}
}
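RecordMapperTest above needs no rule migration at all: it already manages scratch files through java.nio, which behaves identically under both JUnit versions. A sketch of that pattern (the class name is hypothetical; the prefix and suffix mirror the test above):

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertTrue;

class NioTempFileExample {

    @Test
    void createsScratchFile() throws IOException {
        Path p = Files.createTempFile("arrowwritetest", ".arrow");
        p.toFile().deleteOnExit(); // cleanup stays the test's own responsibility, not a rule's
        assertTrue(Files.exists(p));
    }
}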
View File
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image;
import org.apache.commons.io.FileUtils;
@@ -25,33 +24,32 @@ import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@DisplayName("Label Generator Test")
class LabelGeneratorTest {
public class LabelGeneratorTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testParentPathLabelGenerator() throws Exception {
//https://github.com/deeplearning4j/DataVec/issues/273
@DisplayName("Test Parent Path Label Generator")
void testParentPathLabelGenerator(@TempDir Path testDir) throws Exception {
File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
for(String dirPrefix : new String[]{"m.", "m"}) {
File f = testDir.newFolder();
for (String dirPrefix : new String[] { "m.", "m" }) {
File f = testDir.toFile();
int numDirs = 3;
int filesPerDir = 4;
for (int i = 0; i < numDirs; i++) {
File currentLabelDir = new File(f, dirPrefix + i);
currentLabelDir.mkdirs();
@@ -61,14 +59,11 @@ public class LabelGeneratorTest {
assertTrue(f3.exists());
}
}
ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
rr.initialize(new FileSplit(f));
List<String> labelsAct = rr.getLabels();
List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2");
assertEquals(labelsExp, labelsAct);
int expCount = numDirs * filesPerDir;
int actCount = 0;
while (rr.hasNext()) {
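LabelGeneratorTest shows the second @TempDir style: instead of a field, the directory is injected as a test-method parameter. Both forms are valid; a sketch contrasting them, with hypothetical names:

import java.nio.file.Files;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import static org.junit.jupiter.api.Assertions.assertTrue;

class TempDirInjectionStylesExample {

    @TempDir
    Path fieldDir; // field injection: available to every test in the class

    @Test
    void usesParameterInjection(@TempDir Path methodDir) {
        // parameter injection: resolved fresh when this method runs
        assertTrue(Files.isDirectory(fieldDir));
        assertTrue(Files.isDirectory(methodDir));
    }
}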
View File
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image.recordreader;
import org.apache.commons.io.FileUtils;
@@ -29,60 +28,55 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.loader.FileBatch;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*;
@DisplayName("File Batch Record Reader Test")
class FileBatchRecordReaderTest {
public class FileBatchRecordReaderTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testCsv() throws Exception {
File extractedSourceDir = testDir.newFolder();
@DisplayName("Test Csv")
void testCsv(@TempDir Path testDir,@TempDir Path baseDirPath) throws Exception {
File extractedSourceDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir);
File baseDir = testDir.newFolder();
File baseDir = baseDirPath.toFile();
List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true));
assertEquals(6, c.size());
Collections.sort(c, new Comparator<File>() {
@Override
public int compare(File o1, File o2) {
return o1.getPath().compareTo(o2.getPath());
}
});
FileBatch fb = FileBatch.forFiles(c);
File saveFile = new File(baseDir, "saved.zip");
fb.writeAsZip(saveFile);
fb = FileBatch.readFromZip(saveFile);
PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker);
rr.setLabels(Arrays.asList("class0", "class1"));
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
NativeImageLoader il = new NativeImageLoader(32, 32, 1);
for( int test=0; test<3; test++) {
for (int test = 0; test < 3; test++) {
for (int i = 0; i < 6; i++) {
assertTrue(fbrr.hasNext());
List<Writable> next = fbrr.next();
assertEquals(2, next.size());
INDArray exp;
switch (i){
switch(i) {
case 0:
exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg"));
break;
@@ -105,8 +99,7 @@ public class FileBatchRecordReaderTest {
throw new RuntimeException();
}
Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1));
assertEquals(((NDArrayWritable)next.get(0)).get(), exp);
assertEquals(((NDArrayWritable) next.get(0)).get(), exp);
assertEquals(expLabel, next.get(1));
}
assertFalse(fbrr.hasNext());
@@ -114,5 +107,4 @@ public class FileBatchRecordReaderTest {
fbrr.reset();
}
}
}
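FileBatchRecordReaderTest replaces two testDir.newFolder() calls with two @TempDir parameters. A sketch of that shape, with hypothetical names; note one assumption worth verifying against the Jupiter version in use: before 5.8 all @TempDir declarations in a test resolved to the same directory, while from 5.8 on each declaration gets its own.

import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import static org.junit.jupiter.api.Assertions.assertTrue;

class MultipleTempDirsExample {

    @Test
    void copiesBetweenScratchDirs(@TempDir Path sourceDir, @TempDir Path targetDir) {
        File source = sourceDir.toFile(); // was: testDir.newFolder()
        File target = targetDir.toFile(); // was: a second testDir.newFolder()
        assertTrue(source.isDirectory() && target.isDirectory());
        // ... read inputs from source, write outputs under target
    }
}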
View File
@@ -17,106 +17,70 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image.transform;
import org.datavec.image.data.ImageWritable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@DisplayName("Json Yaml Test")
class JsonYamlTest {
public class JsonYamlTest {
@Test
@DisplayName("Test Json Yaml Image Transform Process")
void testJsonYamlImageTransformProcess() throws IOException {
int seed = 12345;
Random random = new Random(seed);
// from org.bytedeco.javacpp.opencv_imgproc
int COLOR_BGR2Luv = 50;
int CV_BGR2GRAY = 6;
// Note: randomCropTransform is omitted because it uses random values, so the JSON, YAML
// and in-memory results could differ and the assertions below would fail.
// filterImageTransform is omitted because it needs the org.bytedeco:ffmpeg-platform
// dependency (and FFmpeg has license caveats).
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv).cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0).resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3).warpImageTransform((float) 0.5).build();
String asJson = itp.toJson();
String asYaml = itp.toYaml();
// System.out.println(asJson);
// System.out.println("\n\n\n");
// System.out.println(asYaml);
ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
ImageWritable imgJson = new ImageWritable(img.getFrame().clone());
ImageWritable imgYaml = new ImageWritable(img.getFrame().clone());
ImageWritable imgAll = new ImageWritable(img.getFrame().clone());
ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson);
ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml);
List<ImageTransform> transformList = itp.getTransformList();
List<ImageTransform> transformListJson = itpFromJson.getTransformList();
List<ImageTransform> transformListYaml = itpFromYaml.getTransformList();
for (int i = 0; i < transformList.size(); i++) {
ImageTransform it = transformList.get(i);
ImageTransform itJson = transformListJson.get(i);
ImageTransform itYaml = transformListYaml.get(i);
System.out.println(i + "\t" + it);
img = it.transform(img);
imgJson = itJson.transform(imgJson);
imgYaml = itYaml.transform(imgYaml);
if (it instanceof RandomCropTransform) {
assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight);
assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth);
assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight);
assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth);
} else if (it instanceof FilterImageTransform) {
assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight);
assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth);
assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels);
assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight);
assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth);
assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels);
} else {
assertEquals(img, imgJson);
assertEquals(img, imgYaml);
}
}
imgAll = itp.execute(imgAll);
assertEquals(imgAll, img);
}
}
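For reference, the round-trip contract this test exercises is compact; a minimal sketch, assuming only deterministic transforms (see the note above about randomCropTransform and filterImageTransform):

ImageTransformProcess itp = new ImageTransformProcess.Builder()
        .resizeImageTransform(300, 300)
        .rotateImageTransform(30)
        .build();
ImageTransformProcess fromJson = ImageTransformProcess.fromJson(itp.toJson());
ImageTransformProcess fromYaml = ImageTransformProcess.fromYaml(itp.toYaml());
// Both restored processes carry the same transform list as itp, so applying
// them to clones of one image must produce equal frames.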

View File

@ -17,56 +17,50 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image.transform;
import org.bytedeco.javacv.Frame;
import org.datavec.image.data.ImageWritable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Resize Image Transform Test")
class ResizeImageTransformTest {
@BeforeEach
void setUp() throws Exception {
}
@Test
@DisplayName("Test Resize Upscale 1")
void testResizeUpscale1() throws Exception {
ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3);
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
ImageWritable dstImg = transform.transform(srcImg);
Frame f = dstImg.getFrame();
assertEquals(f.imageWidth, 200);
assertEquals(f.imageHeight, 200);
float[] coordinates = { 100, 200 };
float[] transformed = transform.query(coordinates);
assertEquals(200f * 100 / 32, transformed[0], 0);
assertEquals(200f * 200 / 32, transformed[1], 0);
}
@Test
@DisplayName("Test Resize Downscale")
void testResizeDownscale() throws Exception {
ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3);
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
ImageWritable dstImg = transform.transform(srcImg);
Frame f = dstImg.getFrame();
assertEquals(f.imageWidth, 200);
assertEquals(f.imageHeight, 200);
float[] coordinates = { 300, 400 };
float[] transformed = transform.query(coordinates);
assertEquals(200f * 300 / 443, transformed[0], 0);
assertEquals(200f * 400 / 571, transformed[1], 0);
}
}
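The query() assertions above encode a plain linear rescaling of coordinates from the source frame into the 200x200 target: x scales with the source width, y with the source height. A sketch of that mapping (the helper is illustrative, not part of the DataVec API):

// xNew = dstW * x / srcW, yNew = dstH * y / srcH
static float[] rescale(float x, float y, float srcW, float srcH, float dstW, float dstH) {
    return new float[] { dstW * x / srcW, dstH * y / srcH };
}
// e.g. rescale(300, 400, 443, 571, 200, 200) reproduces the downscale test's expectations.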

View File

@ -17,37 +17,34 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.poi.excel;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Excel Record Reader Test")
class ExcelRecordReaderTest {
@Test
@DisplayName("Test Simple")
void testSimple() throws Exception {
RecordReader excel = new ExcelRecordReader();
excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile()));
assertTrue(excel.hasNext());
List<Writable> next = excel.next();
assertEquals(3, next.size());
RecordReader headerReader = new ExcelRecordReader(1);
headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile()));
assertTrue(headerReader.hasNext());
List<Writable> next2 = headerReader.next();
assertEquals(3, next2.size());
}
}
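As the two readers above suggest, the ExcelRecordReader constructor argument is the number of leading rows to skip (one header row for testsheetheader.xlsx). A minimal usage sketch, assuming a hypothetical local workbook path:

RecordReader reader = new ExcelRecordReader(1); // skip one header row
reader.initialize(new FileSplit(new File("/tmp/data.xlsx"))); // path is illustrative
while (reader.hasNext()) {
    List<Writable> row = reader.next(); // one Writable per cell in the row
}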

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.poi.excel;
import lombok.val;
@ -27,43 +26,44 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.primitives.Triple;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Excel Record Writer Test")
class ExcelRecordWriterTest {
@TempDir
public Path testDir;
@Test
@DisplayName("Test Writer")
void testWriter() throws Exception {
ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter();
val records = records();
File tmpDir = testDir.toFile();
File outputFile = new File(tmpDir, "testexcel.xlsx");
outputFile.deleteOnExit();
FileSplit fileSplit = new FileSplit(outputFile);
excelRecordWriter.initialize(fileSplit, new NumberOfRecordsPartitioner());
excelRecordWriter.writeBatch(records.getRight());
excelRecordWriter.close();
File parentFile = outputFile.getParentFile();
assertEquals(1, parentFile.list().length);
ExcelRecordReader excelRecordReader = new ExcelRecordReader();
excelRecordReader.initialize(fileSplit);
List<List<Writable>> next = excelRecordReader.next(10);
assertEquals(10, next.size());
}
private Triple<String, Schema, List<List<Writable>>> records() {
List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder();
int numColumns = 3;
@ -80,13 +80,10 @@ public class ExcelRecordWriterTest {
}
list.add(temp);
}
Schema.Builder schemaBuilder = new Schema.Builder();
for (int i = 0; i < numColumns; i++) {
schemaBuilder.addColumnInteger(String.valueOf(i));
}
return Triple.of(sb.toString(), schemaBuilder.build(), list);
}
}

View File

@ -17,14 +17,12 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.File;
import java.net.URI;
import java.sql.Connection;
@ -49,53 +47,57 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Jdbc Record Reader Test")
class JDBCRecordReaderTest {
@TempDir
public Path testDir;
Connection conn;
EmbeddedDataSource dataSource;
private final String dbName = "datavecTests";
private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver";
@BeforeEach
void setUp() throws Exception {
File f = testDir.toFile();
System.setProperty("derby.system.home", f.getAbsolutePath());
dataSource = new EmbeddedDataSource();
dataSource.setDatabaseName(dbName);
dataSource.setCreateDatabase("create");
conn = dataSource.getConnection();
TestDb.dropTables(conn);
TestDb.buildCoffeeTable(conn);
}
@AfterEach
void tearDown() throws Exception {
DbUtils.closeQuietly(conn);
}
@Test
@DisplayName("Test Simple Iter")
void testSimpleIter() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
List<List<Writable>> records = new ArrayList<>();
while (reader.hasNext()) {
List<Writable> values = reader.next();
records.add(values);
}
assertFalse(records.isEmpty());
List<Writable> first = records.get(0);
assertEquals(new Text("Bolivian Dark"), first.get(0));
assertEquals(new Text("14-001"), first.get(1));
@ -104,39 +106,43 @@ public class JDBCRecordReaderTest {
}
@Test
@DisplayName("Test Simple With Listener")
void testSimpleWithListener() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordListener recordListener = new LogRecordListener();
reader.setListeners(recordListener);
reader.next();
assertTrue(recordListener.invoked());
}
}
@Test
@DisplayName("Test Reset")
void testReset() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
List<List<Writable>> records = new ArrayList<>();
records.add(reader.next());
reader.reset();
records.add(reader.next());
assertEquals(2, records.size());
assertEquals(new Text("Bolivian Dark"), records.get(0).get(0));
assertEquals(new Text("Bolivian Dark"), records.get(1).get(0));
}
}
@Test
@DisplayName("Test Lacking Data Source Should Fail")
void testLackingDataSourceShouldFail() {
assertThrows(IllegalStateException.class, () -> {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
reader.initialize(null);
}
});
}
@Test
@DisplayName("Test Configuration Data Source Initialization")
void testConfigurationDataSourceInitialization() throws Exception {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
Configuration conf = new Configuration();
conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true");
@ -146,28 +152,33 @@ public class JDBCRecordReaderTest {
}
}
@Test
@DisplayName("Test Init Configuration Missing Parameters Should Fail")
void testInitConfigurationMissingParametersShouldFail() {
assertThrows(IllegalArgumentException.class, () -> {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
Configuration conf = new Configuration();
conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway");
reader.initialize(conf, null);
}
});
}
@Test
@DisplayName("Test Record Data Input Stream Should Fail")
void testRecordDataInputStreamShouldFail() {
assertThrows(UnsupportedOperationException.class, () -> {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
reader.record(null, null);
}
});
}
@Test
@DisplayName("Test Load From Meta Data")
void testLoadFromMetaData() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), "SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass());
Record res = reader.loadFromMetaData(rmd);
assertNotNull(res);
assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0));
@ -177,7 +188,8 @@ public class JDBCRecordReaderTest {
}
@Test
@DisplayName("Test Next Record")
void testNextRecord() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
Record r = reader.nextRecord();
List<Writable> fields = r.getRecord();
@ -193,7 +205,8 @@ public class JDBCRecordReaderTest {
}
@Test
@DisplayName("Test Next Record And Recover")
void testNextRecordAndRecover() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
Record r = reader.nextRecord();
List<Writable> fields = r.getRecord();
@ -208,69 +221,91 @@ public class JDBCRecordReaderTest {
}
// Resetting the record reader when initialized as forward only should fail
@Test
@DisplayName("Test Reset Forward Only Should Fail")
void testResetForwardOnlyShouldFail() {
assertThrows(RuntimeException.class, () -> {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) {
Configuration conf = new Configuration();
conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY);
reader.initialize(conf, null);
reader.next();
reader.reset();
}
});
}
@Test
@DisplayName("Test Read All Types")
void testReadAllTypes() throws Exception {
TestDb.buildAllTypesTable(conn);
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM AllTypes", dataSource)) {
reader.initialize(null);
List<Writable> item = reader.next();
assertEquals(item.size(), 15);
// boolean to boolean
assertEquals(BooleanWritable.class, item.get(0).getClass());
// date to text
assertEquals(Text.class, item.get(1).getClass());
// time to text
assertEquals(Text.class, item.get(2).getClass());
// timestamp to text
assertEquals(Text.class, item.get(3).getClass());
// char to text
assertEquals(Text.class, item.get(4).getClass());
// long varchar to text
assertEquals(Text.class, item.get(5).getClass());
// varchar to text
assertEquals(Text.class, item.get(6).getClass());
// float to double (derby's float is an alias of double by default)
assertEquals(DoubleWritable.class, item.get(7).getClass());
// real to float
assertEquals(FloatWritable.class, item.get(8).getClass());
// decimal to double
assertEquals(DoubleWritable.class, item.get(9).getClass());
// numeric to double
assertEquals(DoubleWritable.class, item.get(10).getClass());
// double to double
assertEquals(DoubleWritable.class, item.get(11).getClass());
// integer to integer
assertEquals(IntWritable.class, item.get(12).getClass());
// small int to integer
assertEquals(IntWritable.class, item.get(13).getClass());
// bigint to long
assertEquals(LongWritable.class, item.get(14).getClass());
}
}
@Test
@DisplayName("Test Next No More Should Fail")
void testNextNoMoreShouldFail() {
assertThrows(RuntimeException.class, () -> {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
while (reader.hasNext()) {
reader.next();
}
reader.next();
}
});
}
@Test
@DisplayName("Test Invalid Metadata Should Fail")
void testInvalidMetadataShouldFail() {
assertThrows(IllegalArgumentException.class, () -> {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class);
reader.loadFromMetaData(md);
}
});
}
private JDBCRecordReader getInitializedReader(String query) throws Exception {
// ProdNum column
int[] indices = { 1 };
JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?", indices);
reader.setTrimStrings(true);
reader.initialize(null);
return reader;
}
}
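All of the exception tests above follow the same JUnit 4 to JUnit 5 mapping: @Test(expected = X.class) becomes a plain @Test whose body wraps the failing call in assertThrows. A condensed sketch of the pattern, mirroring testLackingDataSourceShouldFail:

// JUnit 4 style (removed):
// @Test(expected = IllegalStateException.class)
// public void failsWithoutDataSource() throws Exception { ... }

// JUnit 5 style:
@Test
void failsWithoutDataSource() {
    assertThrows(IllegalStateException.class, () -> {
        try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
            reader.initialize(null); // no DataSource configured, so this must throw
        }
    });
}

One benefit of the new form is scope: assertThrows pins the expected exception to the lambda body alone, whereas the JUnit 4 annotation accepted the exception from anywhere in the test method.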

View File

@ -17,10 +17,8 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.MathFunction;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.ReduceOp;
@ -32,107 +30,86 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.*;
import org.datavec.python.PythonTransform;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;
@DisplayName("Execution Test")
class ExecutionTest {
@Test
@DisplayName("Test Execution Ndarray")
void testExecutionNdarray() {
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
TransformProcess transformProcess = new TransformProcess.Builder(schema).ndArrayMathFunctionTransform("first", MathFunction.SIN).ndArrayMathFunctionTransform("second", MathFunction.COS).build();
List<List<Writable>> functions = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
INDArray firstArr = Nd4j.linspace(1, 4, 4);
INDArray secondArr = Nd4j.linspace(1, 4, 4);
firstRow.add(new NDArrayWritable(firstArr));
firstRow.add(new NDArrayWritable(secondArr));
functions.add(firstRow);
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
INDArray expected = Transforms.sin(firstArr);
INDArray secondExpected = Transforms.cos(secondArr);
assertEquals(expected, firstResult);
assertEquals(secondExpected, secondResult);
}
@Test
@DisplayName("Test Execution Simple")
void testExecutionSimple() {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
List<List<Writable>> rdd = (inputData);
List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f)));
assertEquals(expected, out);
}
@Test
@DisplayName("Test Filter")
void testFilter() {
Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
assertEquals(2, execute.size());
}
@Test
@DisplayName("Test Execution Sequence")
void testExecutionSequence() {
Schema schema = new SequenceSchema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<List<Writable>>> inputSequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
@ -141,21 +118,17 @@ public class ExecutionTest {
List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
inputSequences.add(seq1);
inputSequences.add(seq2);
List<List<List<Writable>>> rdd = (inputSequences);
List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
Collections.sort(out, new Comparator<List<List<Writable>>>() {
@Override
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
return -Integer.compare(o1.size(), o2.size());
}
});
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
List<List<Writable>> seq1e = new ArrayList<>();
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
@ -164,121 +137,66 @@ public class ExecutionTest {
List<List<Writable>> seq2e = new ArrayList<>();
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
expectedSequence.add(seq1e);
expectedSequence.add(seq2e);
assertEquals(expectedSequence, out);
}
@Test
@DisplayName("Test Reduction Global")
void testReductionGlobal() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
List<List<Writable>> inData = in;
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
assertEquals(expOut, out);
}
@Test
@DisplayName("Test Reduction By Key")
void testReductionByKey() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
List<List<Writable>> inData = in;
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out);
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
assertEquals(expOut, out);
}
@Test
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
@DisplayName("Test Python Execution Ndarray")
void testPythonExecutionNdarray() {
assertTimeout(ofMillis(60000), () -> {
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
TransformProcess transformProcess = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
List<List<Writable>> functions = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
INDArray firstArr = Nd4j.linspace(1, 4, 4);
INDArray secondArr = Nd4j.linspace(1, 4, 4);
firstRow.add(new NDArrayWritable(firstArr));
firstRow.add(new NDArrayWritable(secondArr));
functions.add(firstRow);
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
INDArray expected = Transforms.sin(firstArr);
INDArray secondExpected = Transforms.cos(secondArr);
assertEquals(expected, firstResult);
assertEquals(secondExpected, secondResult);
});
}
}
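Timeouts migrate the same way exceptions do: @Test(timeout = 60000L) becomes a body wrapped in assertTimeout(ofMillis(60000), ...), and @Ignore becomes @Disabled. A condensed sketch of the pattern used above:

import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;

@Test
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
void boundedWork() {
    assertTimeout(ofMillis(60000), () -> {
        // body that must complete within 60 seconds
    });
}

Note that assertTimeout runs the body to completion on the calling thread and only fails afterwards if it overran; assertTimeoutPreemptively, which aborts the body, is the closer analogue of JUnit 4's thread-based timeout.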

View File

@ -17,36 +17,38 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import java.io.Serializable;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
@DisplayName("Base Spark Test")
public abstract class BaseSparkTest implements Serializable {
protected static JavaSparkContext sc;
@BeforeEach
void before() {
sc = getContext();
}
@AfterEach
synchronized void after() {
sc.close();
// Wait until it's stopped, to avoid race conditions during tests
for (int i = 0; i < 100; i++) {
if (!sc.sc().stopped().get()) {
try {
Thread.sleep(100L);
} catch (InterruptedException e) {
log.error("",e);
log.error("", e);
}
} else {
break;
@ -55,29 +57,21 @@ public abstract class BaseSparkTest implements Serializable {
if (!sc.sc().stopped().get()) {
throw new RuntimeException("Spark context is not stopped after 10s");
}
sc = null;
}
public synchronized JavaSparkContext getContext() {
if (sc != null)
return sc;
SparkConf sparkConf = new SparkConf().setMaster("local[*]").set("spark.driver.host", "localhost")
.set("spark.driverEnv.SPARK_LOCAL_IP", "127.0.0.1")
.set("spark.executorEnv.SPARK_LOCAL_IP", "127.0.0.1").setAppName("sparktest");
SparkConf sparkConf = new SparkConf().setMaster("local[*]").set("spark.driver.host", "localhost").set("spark.driverEnv.SPARK_LOCAL_IP", "127.0.0.1").set("spark.executorEnv.SPARK_LOCAL_IP", "127.0.0.1").setAppName("sparktest");
if (useKryo()) {
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
}
sc = new JavaSparkContext(sparkConf);
return sc;
}
public boolean useKryo() {
return false;
}
}
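Subclasses inherit this per-test context lifecycle unchanged under JUnit 5, since @BeforeEach/@AfterEach methods declared in a superclass run around each test in the subclass. A hypothetical subclass, purely for illustration:

@DisplayName("Example Spark Test")
class ExampleSparkTest extends BaseSparkTest { // illustrative, not part of the codebase

    @Test
    void contextIsAvailable() {
        // sc is created in before() and stopped in after()
        long n = sc.parallelize(java.util.Arrays.asList(1, 2, 3)).count();
        org.junit.jupiter.api.Assertions.assertEquals(3L, n);
    }
}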

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import org.apache.spark.api.java.JavaRDD;
@ -35,59 +34,51 @@ import org.datavec.api.writable.Writable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.python.PythonTransform;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;
@DisplayName("Execution Test")
class ExecutionTest extends BaseSparkTest {
@Test
@DisplayName("Test Execution Simple")
void testExecutionSimple() {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
assertEquals(expected, out);
}
@Test
@DisplayName("Test Execution Sequence")
void testExecutionSequence() {
Schema schema = new SequenceSchema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<List<Writable>>> inputSequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
@ -96,22 +87,17 @@ public class ExecutionTest extends BaseSparkTest {
List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
inputSequences.add(seq1);
inputSequences.add(seq2);
JavaRDD<List<List<Writable>>> rdd = sc.parallelize(inputSequences);
List<List<List<Writable>>> out = new ArrayList<>(SparkTransformExecutor.executeSequenceToSequence(rdd, tp).collect());
Collections.sort(out, new Comparator<List<List<Writable>>>() {
@Override
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
return -Integer.compare(o1.size(), o2.size());
}
});
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
List<List<Writable>> seq1e = new ArrayList<>();
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
@ -120,99 +106,49 @@ public class ExecutionTest extends BaseSparkTest {
List<List<Writable>> seq2e = new ArrayList<>();
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
expectedSequence.add(seq1e);
expectedSequence.add(seq2e);
assertEquals(expectedSequence, out);
}
@Test
@DisplayName("Test Reduction Global")
void testReductionGlobal() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
JavaRDD<List<Writable>> inData = sc.parallelize(in);
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd.collect();
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
assertEquals(expOut, out);
}
@Test
@DisplayName("Test Reduction By Key")
void testReductionByKey() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
JavaRDD<List<Writable>> inData = sc.parallelize(in);
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd.collect();
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out);
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
assertEquals(expOut, out);
}
@Test
@DisplayName("Test Unique Multi Col")
void testUniqueMultiCol() {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
@ -223,149 +159,103 @@ public class ExecutionTest extends BaseSparkTest {
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
Map<String,List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd);
Map<String, List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd);
assertEquals(2, l.size());
List<Writable> c0 = l.get("col0");
assertEquals(3, c0.size());
assertTrue(c0.contains(new IntWritable(0)) && c0.contains(new IntWritable(1)) && c0.contains(new IntWritable(2)));
List<Writable> c1 = l.get("col1");
assertEquals(3, c1.size());
assertTrue(c1.contains(new Text("state0")) && c1.contains(new Text("state1")) && c1.contains(new Text("state2")));
}
@Test
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
@DisplayName("Test Python Execution")
void testPythonExecution() {
assertTimeout(ofMillis(60000), () -> {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnString("col1").addColumnDouble("col2").build();
Schema finalSchema = new Schema.Builder().addColumnInteger("col0").addColumnInteger("col1").addColumnDouble("col2").build();
String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code(pythonCode).outputSchema(finalSchema).build()).build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() {
Schema finalSchema = new Schema.Builder().addColumnInteger("col0")
.addColumnInteger("col1").addColumnDouble("col2").build();
String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
TransformProcess tp = new TransformProcess.Builder(schema).transform(
PythonTransform.builder().code(
"first = np.sin(first)\nsecond = np.cos(second)")
.outputSchema(finalSchema).build()
).build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
assertEquals(expected, out);
});
}
@Test
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
@DisplayName("Test Python Execution With ND Arrays")
void testPythonExecutionWithNDArrays() {
assertTimeout(ofMillis(60000), () -> {
long[] shape = new long[] { 3, 2 };
Schema schema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).build();
Schema finalSchema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).addColumnNDArray("col3", shape).build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
INDArray zeros = Nd4j.zeros(shape);
INDArray ones = Nd4j.ones(shape);
INDArray twos = ones.add(ones);
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() {
List<List<Writable>> in = Arrays.asList(
Arrays.<Writable>asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")),
Arrays.<Writable>asList(new Text("a2"), new DoubleWritable(3.14159), new Text("7e-4")),
Arrays.<Writable>asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")),
Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")),
Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.0")),
Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.1")),
Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.2")),
Arrays.<Writable>asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical")));
//Test Benfords law use case:
TransformProcess tp = new TransformProcess.Builder(s)
.firstDigitTransform("double", "fdDouble", FirstDigitTransform.Mode.EXCEPTION_ON_INVALID)
.firstDigitTransform("stringNumber", "stringNumber", FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY)
.removeAllColumnsExceptFor("stringNumber")
.categoricalToOneHot("stringNumber")
.reduce(new Reducer.Builder(ReduceOp.Sum).build())
.build();
JavaRDD<List<Writable>> rdd = sc.parallelize(in);
List<List<Writable>> out = SparkTransformExecutor.execute(rdd, tp).collect();
assertEquals(1, out.size());
List<Writable> l = out.get(0);
List<Writable> exp = Arrays.<Writable>asList(
new IntWritable(0), //0
new IntWritable(0), //1
new IntWritable(3), //2
new IntWritable(0), //3
new IntWritable(0), //4
new IntWritable(0), //5
new IntWritable(1), //6
new IntWritable(2), //7
new IntWritable(1), //8
new IntWritable(0), //9
new IntWritable(1)); //Other
assertEquals(exp, l);
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos)));
});
}
@Test
@DisplayName("Test First Digit Transform Benfords Law")
void testFirstDigitTransformBenfordsLaw() {
Schema s = new Schema.Builder().addColumnString("data").addColumnDouble("double").addColumnString("stringNumber").build();
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")), Arrays.<Writable>asList(new Text("a2"), new DoubleWritable(3.14159), new Text("7e-4")), Arrays.<Writable>asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")), Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")), Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.0")), Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.1")), Arrays.<Writable>asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.2")), Arrays.<Writable>asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical")));
// Test Benford's law use case:
TransformProcess tp = new TransformProcess.Builder(s).firstDigitTransform("double", "fdDouble", FirstDigitTransform.Mode.EXCEPTION_ON_INVALID).firstDigitTransform("stringNumber", "stringNumber", FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY).removeAllColumnsExceptFor("stringNumber").categoricalToOneHot("stringNumber").reduce(new Reducer.Builder(ReduceOp.Sum).build()).build();
JavaRDD<List<Writable>> rdd = sc.parallelize(in);
List<List<Writable>> out = SparkTransformExecutor.execute(rdd, tp).collect();
assertEquals(1, out.size());
List<Writable> l = out.get(0);
List<Writable> exp = Arrays.<Writable>asList(
        new IntWritable(0), // 0
        new IntWritable(0), // 1
        new IntWritable(3), // 2
        new IntWritable(0), // 3
        new IntWritable(0), // 4
        new IntWritable(0), // 5
        new IntWritable(1), // 6
        new IntWritable(2), // 7
        new IntWritable(1), // 8
        new IntWritable(0), // 9
        new IntWritable(1)); // Other
assertEquals(exp, l);
}
}
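As a sanity check on the expected row in testFirstDigitTransformBenfordsLaw, this standalone sketch re-derives the first-digit counts for the stringNumber column (plain Java, not the DataVec API):

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class FirstDigitCountSketch {
        public static void main(String[] args) {
            String[] values = {"8e-4", "7e-4", "7e2", "6e8", "2.0", "2.1", "2.2", "non numerical"};
            Map<String, Integer> counts = new LinkedHashMap<>();
            for (String v : values) {
                char c = v.charAt(0);
                // Non-digit leading characters fall into the "Other" bucket,
                // mirroring FirstDigitTransform.Mode.INCLUDE_OTHER_CATEGORY
                String bucket = Character.isDigit(c) ? String.valueOf(c) : "Other";
                counts.merge(bucket, 1, Integer::sum);
            }
            // Prints {8=1, 7=2, 6=1, 2=3, Other=1} - the non-zero entries of
            // the one-hot Sum reduction asserted in exp
            System.out.println(counts);
        }
    }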

View File

@ -89,14 +89,22 @@
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
</dependency>
<dependency>
<groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId>
</dependency>
<dependency>
<groupId>com.tngtech.archunit</groupId>
<artifactId>archunit-junit5-engine</artifactId>
<version>${archunit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.tngtech.archunit</groupId>
<artifactId>archunit-junit4</artifactId>
<artifactId>archunit-junit5-api</artifactId>
<version>${archunit.version}</version>
<scope>test</scope>
</dependency>
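The junit-vintage-engine dependency added here lets not-yet-migrated JUnit 4 test classes keep running on the JUnit Platform alongside Jupiter tests. A minimal sketch of such a legacy class (illustrative, not from this commit):

    import static org.junit.Assert.assertEquals;

    import org.junit.Test; // JUnit 4 API, discovered and run by the vintage engine

    public class LegacyVintageSketch {

        @Test
        public void additionStillWorks() {
            assertEquals(4, 2 + 2);
        }
    }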

View File

@ -34,10 +34,18 @@
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>${junit.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${junit.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>

View File

@ -17,17 +17,13 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j;
import ch.qos.logback.classic.LoggerContext;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.junit.jupiter.api.*;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
@ -37,23 +33,22 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory;
import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.junit.Assume.assumeTrue;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
@DisplayName("Base DL 4 J Test")
public abstract class BaseDL4JTest {
@Rule
public TestName name = new TestName();
@Rule
public Timeout timeout = Timeout.millis(getTimeoutMilliseconds());
protected long startTime;
protected int threadCountBefore;
private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
@ -63,32 +58,32 @@ public abstract class BaseDL4JTest {
* {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
* @return Number of threads to use for C++ op execution
*/
public int numThreads(){
public int numThreads() {
return DEFAULT_THREADS;
}
/**
* Override this method to set the default timeout for methods in the test class
*/
public long getTimeoutMilliseconds(){
public long getTimeoutMilliseconds() {
return 90_000;
}
/**
* Override this to set the profiling mode for the tests defined in the child class
*/
public OpExecutioner.ProfilingMode getProfilingMode(){
public OpExecutioner.ProfilingMode getProfilingMode() {
return OpExecutioner.ProfilingMode.SCOPE_PANIC;
}
/**
* Override this to set the datatype of the tests defined in the child class
*/
public DataType getDataType(){
public DataType getDataType() {
return DataType.DOUBLE;
}
public DataType getDefaultFPDataType(){
public DataType getDefaultFPDataType() {
return getDataType();
}
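These hooks are what concrete test classes override to tune execution per class. A minimal sketch of a subclass using them (class name and values are illustrative, not from this commit):

    public class MyLayerTests extends BaseDL4JTest {

        // Give every test in this class a 3 minute budget instead of the 90s default
        @Override
        public long getTimeoutMilliseconds() {
            return 180_000;
        }

        // Run the C++ ops single-threaded so failures are deterministic
        @Override
        public int numThreads() {
            return 1;
        }

        // Run the tests in FLOAT rather than the DOUBLE default
        @Override
        public DataType getDataType() {
            return DataType.FLOAT;
        }
    }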
@ -97,8 +92,8 @@ public abstract class BaseDL4JTest {
/**
* @return True if integration tests maven profile is enabled, false otherwise.
*/
public static boolean isIntegrationTests(){
if(integrationTest == null){
public static boolean isIntegrationTests() {
if (integrationTest == null) {
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
integrationTest = Boolean.parseBoolean(prop);
}
@ -110,14 +105,15 @@ public abstract class BaseDL4JTest {
* This can be used to dynamically skip integration tests when the integration test profile is not enabled.
* Note that the integration test profile is not enabled by default - "integration-tests" profile
*/
public static void skipUnlessIntegrationTests(){
assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
public static void skipUnlessIntegrationTests() {
assumeTrue(isIntegrationTests(), "Skipping integration test - integration profile is not enabled");
}
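Note the argument-order flip in that one-line change: JUnit 4's Assume.assumeTrue(message, condition) becomes Jupiter's Assumptions.assumeTrue(condition, message). A typical caller looks like this (illustrative test name):

    @Test
    void testEndToEndPipeline() {
        // Aborts (rather than fails) this test unless the
        // DL4J_INTEGRATION_TESTS environment variable enables it
        skipUnlessIntegrationTests();
        // ... expensive integration assertions ...
    }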
@Before
public void beforeTest(){
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
//Suppress ND4J initialization - don't need this logged for every test...
@BeforeEach
@Timeout(90000L)
void beforeTest(TestInfo testInfo) {
log.info("{}.{}", getClass().getSimpleName(), testInfo.getTestMethod().get().getName());
// Suppress ND4J initialization - don't need this logged for every test...
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
@ -128,83 +124,71 @@ public abstract class BaseDL4JTest {
Nd4j.getExecutioner().enableVerboseMode(false);
int numThreads = numThreads();
Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
if (numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
}
startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
}
@After
public void afterTest(){
//Attempt to keep workspaces isolated between tests
@AfterEach
void afterTest(TestInfo testInfo) {
// Attempt to keep workspaces isolated between tests
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
if(currWS != null){
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
if (currWS != null) {
// Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
System.out.flush();
//Try to flush logs also:
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
ILoggerFactory lf = LoggerFactory.getILoggerFactory();
if( lf instanceof LoggerContext){
((LoggerContext)lf).stop();
// Try to flush logs also:
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
}
ILoggerFactory lf = LoggerFactory.getILoggerFactory();
if (lf instanceof LoggerContext) {
((LoggerContext) lf).stop();
}
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
}
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
System.exit(1);
}
StringBuilder sb = new StringBuilder();
long maxPhys = Pointer.maxPhysicalBytes();
long maxBytes = Pointer.maxBytes();
long currPhys = Pointer.physicalBytes();
long currBytes = Pointer.totalBytes();
long jvmTotal = Runtime.getRuntime().totalMemory();
long jvmMax = Runtime.getRuntime().maxMemory();
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
long duration = System.currentTimeMillis() - startTime;
sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName())
.append(": ").append(duration).append(" ms")
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
.append(", jvmTotal=").append(jvmTotal)
.append(", jvmMax=").append(jvmMax)
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
sb.append(getClass().getSimpleName()).append(".").append(testInfo.getTestMethod().get().getName()).append(": ").append(duration).append(" ms").append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")").append(", jvmTotal=").append(jvmTotal).append(", jvmMax=").append(jvmMax).append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes).append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
if(ws != null && ws.size() > 0){
if (ws != null && ws.size() > 0) {
long currSize = 0;
for(MemoryWorkspace w : ws){
for (MemoryWorkspace w : ws) {
currSize += w.getCurrentSize();
}
if(currSize > 0){
sb.append(", threadWSSize=").append(currSize)
.append(" (").append(ws.size()).append(" WSs)");
if (currSize > 0) {
sb.append(", threadWSSize=").append(currSize).append(" (").append(ws.size()).append(" WSs)");
}
}
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
Object o = p.get("cuda.devicesInformation");
if(o instanceof List){
List<Map<String,Object>> l = (List<Map<String, Object>>) o;
if(l.size() > 0) {
sb.append(" [").append(l.size())
.append(" GPUs: ");
if (o instanceof List) {
List<Map<String, Object>> l = (List<Map<String, Object>>) o;
if (l.size() > 0) {
sb.append(" [").append(l.size()).append(" GPUs: ");
for (int i = 0; i < l.size(); i++) {
Map<String,Object> m = l.get(i);
if(i > 0)
Map<String, Object> m = l.get(i);
if (i > 0)
sb.append(",");
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
.append(m.get("cuda.totalMemory")).append(" total)");
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ").append(m.get("cuda.totalMemory")).append(" total)");
}
sb.append("]");
}

View File

@ -41,8 +41,15 @@
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

View File

@ -17,70 +17,56 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.common.config;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import org.deeplearning4j.common.config.dummies.TestAbstract;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Dl 4 J Class Loading Test")
class DL4JClassLoadingTest {
public class DL4JClassLoadingTest {
private static final String PACKAGE_PREFIX = "org.deeplearning4j.common.config.dummies.";
@Test
public void testCreateNewInstance_constructorWithoutArguments() {
@DisplayName("Test Create New Instance _ constructor Without Arguments")
void testCreateNewInstance_constructorWithoutArguments() {
/* Given */
String className = PACKAGE_PREFIX + "TestDummy";
/* When */
Object instance = DL4JClassLoading.createNewInstance(className);
/* Then */
assertNotNull(instance);
assertEquals(className, instance.getClass().getName());
}
@Test
public void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() {
@DisplayName("Test Create New Instance _ constructor With Argument _ implicit Argument Types")
void testCreateNewInstance_constructorWithArgument_implicitArgumentTypes() {
/* Given */
String className = PACKAGE_PREFIX + "TestColor";
/* When */
TestAbstract instance = DL4JClassLoading.createNewInstance(className, TestAbstract.class, "white");
/* Then */
assertNotNull(instance);
assertEquals(className, instance.getClass().getName());
}
@Test
public void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() {
@DisplayName("Test Create New Instance _ constructor With Argument _ explicit Argument Types")
void testCreateNewInstance_constructorWithArgument_explicitArgumentTypes() {
/* Given */
String colorClassName = PACKAGE_PREFIX + "TestColor";
String rectangleClassName = PACKAGE_PREFIX + "TestRectangle";
/* When */
TestAbstract color = DL4JClassLoading.createNewInstance(
colorClassName,
Object.class,
new Class<?>[]{ int.class, int.class, int.class },
45, 175, 200);
TestAbstract rectangle = DL4JClassLoading.createNewInstance(
rectangleClassName,
Object.class,
new Class<?>[]{ int.class, int.class, TestAbstract.class },
10, 15, color);
TestAbstract color = DL4JClassLoading.createNewInstance(colorClassName, Object.class, new Class<?>[] { int.class, int.class, int.class }, 45, 175, 200);
TestAbstract rectangle = DL4JClassLoading.createNewInstance(rectangleClassName, Object.class, new Class<?>[] { int.class, int.class, TestAbstract.class }, 10, 15, color);
/* Then */
assertNotNull(color);
assertEquals(colorClassName, color.getClass().getName());
assertNotNull(rectangle);
assertEquals(rectangleClassName, rectangle.getClass().getName());
}
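The explicit Class<?>[] overload exercised above matters when constructor parameters are primitives: the varargs values arrive boxed (Integer rather than int), so a lookup derived from arg.getClass() would presumably miss TestColor(int, int, int). Passing the parameter types directly avoids that; the call below re-uses the test's own arguments, and the rationale is our reading rather than anything stated in the commit:

    // int.class != Integer.class for reflective constructor lookup,
    // so the primitive types are spelled out explicitly
    TestAbstract color = DL4JClassLoading.createNewInstance(
            "org.deeplearning4j.common.config.dummies.TestColor",
            Object.class,
            new Class<?>[] { int.class, int.class, int.class },
            45, 175, 200);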

View File

@ -49,11 +49,6 @@
</dependencyManagement>
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-tsne</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-datasets</artifactId>
@ -99,8 +94,12 @@
<version>${commons-compress.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
</dependency>
<dependency>
<groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>

View File

@ -17,15 +17,16 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.*;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
@ -33,69 +34,67 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import java.io.File;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
@DisplayName("Mnist Fetcher Test")
class MnistFetcherTest extends BaseDL4JTest {
public class MnistFetcherTest extends BaseDL4JTest {
@ClassRule
public static TemporaryFolder testDir = new TemporaryFolder();
@Rule
public Timeout timeout = Timeout.seconds(300);
@BeforeClass
public static void setup() throws Exception {
DL4JResources.setBaseDirectory(testDir.newFolder());
@BeforeAll
static void setup(@TempDir Path tempPath) throws Exception {
DL4JResources.setBaseDirectory(tempPath.toFile());
}
@AfterClass
public static void after() {
@AfterAll
static void after() {
DL4JResources.resetBaseDirectoryLocation();
}
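The hunk above is the commit's standard replacement for a @ClassRule TemporaryFolder: Jupiter injects a @TempDir path into the @BeforeAll method and deletes the directory when the class finishes. A minimal sketch of the pattern (illustrative names):

    import java.nio.file.Path;

    import org.junit.jupiter.api.BeforeAll;
    import org.junit.jupiter.api.io.TempDir;

    class TempDirMigrationSketch {

        static Path sharedDir;

        // No TemporaryFolder rule: Jupiter creates the directory,
        // injects it here, and cleans it up after the last test
        @BeforeAll
        static void setup(@TempDir Path tempPath) {
            sharedDir = tempPath;
        }
    }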
@Test
public void testMnist() throws Exception {
@DisplayName("Test Mnist")
void testMnist() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
int count = 0;
while(iter.hasNext()){
while (iter.hasNext()) {
DataSet ds = iter.next();
INDArray arr = ds.getFeatures().sum(1);
int countMatch = Nd4j.getExecutioner().execAndReturn(new MatchCondition(arr, Conditions.equals(0))).z().getInt(0);
assertEquals(0, countMatch);
count++;
}
assertEquals(60000/32, count);
assertEquals(60000 / 32, count);
count = 0;
iter = new MnistDataSetIterator(32, false, 12345);
while(iter.hasNext()){
while (iter.hasNext()) {
DataSet ds = iter.next();
INDArray arr = ds.getFeatures().sum(1);
int countMatch = Nd4j.getExecutioner().execAndReturn(new MatchCondition(arr, Conditions.equals(0))).z().getInt(0);
assertEquals(0, countMatch);
count++;
}
assertEquals((int)Math.ceil(10000/32.0), count);
assertEquals((int) Math.ceil(10000 / 32.0), count);
}
@Test
public void testMnistDataFetcher() throws Exception {
@DisplayName("Test Mnist Data Fetcher")
void testMnistDataFetcher() throws Exception {
MnistFetcher mnistFetcher = new MnistFetcher();
File mnistDir = mnistFetcher.downloadAndUntar();
assertTrue(mnistDir.isDirectory());
}
// @Test
// @Test
public void testMnistSubset() throws Exception {
final int numExamples = 100;
MnistDataSetIterator iter1 = new MnistDataSetIterator(10, numExamples, false, true, true, 123);
int examples1 = 0;
int itCount1 = 0;
@ -105,7 +104,6 @@ public class MnistFetcherTest extends BaseDL4JTest {
}
assertEquals(10, itCount1);
assertEquals(100, examples1);
MnistDataSetIterator iter2 = new MnistDataSetIterator(10, numExamples, false, true, true, 123);
int examples2 = 0;
int itCount2 = 0;
@ -116,7 +114,6 @@ public class MnistFetcherTest extends BaseDL4JTest {
assertFalse(iter2.hasNext());
assertEquals(10, itCount2);
assertEquals(100, examples2);
MnistDataSetIterator iter3 = new MnistDataSetIterator(19, numExamples, false, true, true, 123);
int examples3 = 0;
int itCount3 = 0;
@ -125,51 +122,45 @@ public class MnistFetcherTest extends BaseDL4JTest {
examples3 += iter3.next().numExamples();
}
assertEquals(100, examples3);
assertEquals((int)Math.ceil(100/19.0), itCount3);
assertEquals((int) Math.ceil(100 / 19.0), itCount3);
MnistDataSetIterator iter4 = new MnistDataSetIterator(32, true, 12345);
int count4 = 0;
while(iter4.hasNext()){
while (iter4.hasNext()) {
count4 += iter4.next().numExamples();
}
assertEquals(60000, count4);
}
@Test
public void testSubsetRepeatability() throws Exception {
@DisplayName("Test Subset Repeatability")
void testSubsetRepeatability() throws Exception {
DataSetIterator it = new MnistDataSetIterator(1, 1, false, false, true, 0);
DataSet d1 = it.next();
for( int i=0; i<10; i++ ) {
for (int i = 0; i < 10; i++) {
it.reset();
DataSet d2 = it.next();
assertEquals(d1.get(0).getFeatures(), d2.get(0).getFeatures());
}
//Check larger number:
// Check larger number:
it = new MnistDataSetIterator(8, 32, false, false, true, 12345);
Set<String> featureLabelSet = new HashSet<>();
while(it.hasNext()){
while (it.hasNext()) {
DataSet ds = it.next();
INDArray f = ds.getFeatures();
INDArray l = ds.getLabels();
for( int i=0; i<f.size(0); i++ ){
for (int i = 0; i < f.size(0); i++) {
featureLabelSet.add(f.getRow(i).toString() + "\t" + l.getRow(i).toString());
}
}
assertEquals(32, featureLabelSet.size());
for( int i=0; i<3; i++ ){
for (int i = 0; i < 3; i++) {
it.reset();
Set<String> flSet2 = new HashSet<>();
while(it.hasNext()){
while (it.hasNext()) {
DataSet ds = it.next();
INDArray f = ds.getFeatures();
INDArray l = ds.getLabels();
for( int j=0; j<f.size(0); j++ ){
for (int j = 0; j < f.size(0); j++) {
flSet2.add(f.getRow(j).toString() + "\t" + l.getRow(j).toString());
}
}

View File

@ -17,10 +17,8 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.datavec;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.io.Files;
import org.apache.commons.io.FileUtils;
@ -47,8 +45,8 @@ import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
@ -58,42 +56,40 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.resources.Resources;
import java.io.*;
import java.net.URI;
import java.util.*;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@DisplayName("Record Reader Multi Data Set Iterator Test")
class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@TempDir
public Path temporaryFolder;
@Rule
public Timeout timeout = Timeout.seconds(300);
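The @Rule Timeout above has no one-line counterpart in the migrated class; where the same budget is wanted in Jupiter, the @Timeout annotation at class or method level provides it. Sketch (illustrative):

    import java.util.concurrent.TimeUnit;

    import org.junit.jupiter.api.Test;
    import org.junit.jupiter.api.Timeout;

    // Rough equivalent of "public Timeout timeout = Timeout.seconds(300);"
    @Timeout(value = 300, unit = TimeUnit.SECONDS)
    class TimeoutAnnotationSketch {

        @Test
        void coveredByClassLevelTimeout() {
            // fails if it runs for longer than 300 seconds
        }
    }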
@Test
public void testsBasic() throws Exception {
//Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator
@DisplayName("Tests Basic")
void testsBasic() throws Exception {
// Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator
RecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
RecordReader rr2 = new CSVRecordReader(0, ',');
rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2)
.addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();
MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();
while (rrdsi.hasNext()) {
DataSet ds = rrdsi.next();
INDArray fds = ds.getFeatures();
INDArray lds = ds.getLabels();
MultiDataSet mds = rrmdsi.next();
assertEquals(1, mds.getFeatures().length);
assertEquals(1, mds.getLabels().length);
@ -101,49 +97,36 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertNull(mds.getLabelsMaskArrays());
INDArray fmds = mds.getFeatures(0);
INDArray lmds = mds.getLabels(0);
assertNotNull(fmds);
assertNotNull(lmds);
assertEquals(fds, fmds);
assertEquals(lds, lmds);
}
assertFalse(rrmdsi.hasNext());
//need to manually extract
File rootDir = temporaryFolder.newFolder();
// need to manually extract
File rootDir = temporaryFolder.toFile();
for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
}
//Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
// Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReaderDataSetIterator iter =
new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("in", featureReader2).addSequenceReader("out", labelReader2).addInput("in")
.addOutputOneHot("out", 0, 4).build();
MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader2).addSequenceReader("out", labelReader2).addInput("in").addOutputOneHot("out", 0, 4).build();
while (iter.hasNext()) {
DataSet ds = iter.next();
INDArray fds = ds.getFeatures();
INDArray lds = ds.getLabels();
MultiDataSet mds = srrmdsi.next();
assertEquals(1, mds.getFeatures().length);
assertEquals(1, mds.getLabels().length);
@ -151,10 +134,8 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertNull(mds.getLabelsMaskArrays());
INDArray fmds = mds.getFeatures(0);
INDArray lmds = mds.getLabels(0);
assertNotNull(fmds);
assertNotNull(lmds);
assertEquals(fds, fmds);
assertEquals(lds, lmds);
}
@ -162,16 +143,13 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testsBasicMeta() throws Exception {
//As per testBasic - but also loading metadata
@DisplayName("Tests Basic Meta")
void testsBasicMeta() throws Exception {
// As per testBasic - but also loading metadata
RecordReader rr2 = new CSVRecordReader(0, ',');
rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
.addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();
rrmdsi.setCollectMetaData(true);
int count = 0;
while (rrmdsi.hasNext()) {
MultiDataSet mds = rrmdsi.next();
@ -183,27 +161,22 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testSplittingCSV() throws Exception {
//Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
//Inputs: columns 0 and 1-2
//Outputs: columns 3, and 4->OneHot
//need to manually extract
@DisplayName("Test Splitting CSV")
void testSplittingCSV() throws Exception {
// Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
// Inputs: columns 0 and 1-2
// Outputs: columns 3, and 4->OneHot
// need to manually extract
RecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
RecordReader rr2 = new CSVRecordReader(0, ',');
rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2)
.addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3)
.addOutputOneHot("reader", 4, 3).build();
MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
while (rrdsi.hasNext()) {
DataSet ds = rrdsi.next();
INDArray fds = ds.getFeatures();
INDArray lds = ds.getLabels();
MultiDataSet mds = rrmdsi.next();
assertEquals(2, mds.getFeatures().length);
assertEquals(2, mds.getLabels().length);
@ -211,20 +184,15 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertNull(mds.getLabelsMaskArrays());
INDArray[] fmds = mds.getFeatures();
INDArray[] lmds = mds.getLabels();
assertNotNull(fmds);
assertNotNull(lmds);
for (int i = 0; i < fmds.length; i++)
assertNotNull(fmds[i]);
for (int i = 0; i < lmds.length; i++)
assertNotNull(lmds[i]);
//Get the subsets of the original iris data
INDArray expIn1 = fds.get(all(), interval(0,0,true));
for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
// Get the subsets of the original iris data
INDArray expIn1 = fds.get(all(), interval(0, 0, true));
INDArray expIn2 = fds.get(all(), interval(1, 2, true));
INDArray expOut1 = fds.get(all(), interval(3,3,true));
INDArray expOut1 = fds.get(all(), interval(3, 3, true));
INDArray expOut2 = lds;
assertEquals(expIn1, fmds[0]);
assertEquals(expIn2, fmds[1]);
assertEquals(expOut1, lmds[0]);
@ -234,18 +202,15 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testSplittingCSVMeta() throws Exception {
//Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
//Inputs: columns 0 and 1-2
//Outputs: columns 3, and 4->OneHot
@DisplayName("Test Splitting CSV Meta")
void testSplittingCSVMeta() throws Exception {
// Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
// Inputs: columns 0 and 1-2
// Outputs: columns 3, and 4->OneHot
RecordReader rr2 = new CSVRecordReader(0, ',');
rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
.addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2)
.addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
rrmdsi.setCollectMetaData(true);
int count = 0;
while (rrmdsi.hasNext()) {
MultiDataSet mds = rrmdsi.next();
@ -257,42 +222,33 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testSplittingCSVSequence() throws Exception {
//Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt"
@DisplayName("Test Splitting CSV Sequence")
void testSplittingCSVSequence() throws Exception {
// Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvsequencelabels_i.txt"
// as standard one-hot output
//need to manually extract
File rootDir = temporaryFolder.newFolder();
// need to manually extract
File rootDir = temporaryFolder.toFile();
for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
}
String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReaderDataSetIterator iter =
new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2)
.addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();
MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2).addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();
while (iter.hasNext()) {
DataSet ds = iter.next();
INDArray fds = ds.getFeatures();
INDArray lds = ds.getLabels();
MultiDataSet mds = srrmdsi.next();
assertEquals(2, mds.getFeatures().length);
assertEquals(1, mds.getLabels().length);
@ -300,17 +256,12 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertNull(mds.getLabelsMaskArrays());
INDArray[] fmds = mds.getFeatures();
INDArray[] lmds = mds.getLabels();
assertNotNull(fmds);
assertNotNull(lmds);
for (int i = 0; i < fmds.length; i++)
assertNotNull(fmds[i]);
for (int i = 0; i < lmds.length; i++)
assertNotNull(lmds[i]);
for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
INDArray expIn1 = fds.get(all(), NDArrayIndex.interval(0, 1, true), all());
INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(2, 2, true), all());
assertEquals(expIn1, fmds[0]);
assertEquals(expIn2, fmds[1]);
assertEquals(lds, lmds[0]);
@ -319,36 +270,29 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testSplittingCSVSequenceMeta() throws Exception {
//Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvSequencelables_i.txt"
@DisplayName("Test Splitting CSV Sequence Meta")
void testSplittingCSVSequenceMeta() throws Exception {
// Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep "csvsequencelabels_i.txt"
// as standard one-hot output
//need to manually extract
File rootDir = temporaryFolder.newFolder();
// need to manually extract
File rootDir = temporaryFolder.toFile();
for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
}
String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2)
.addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();
RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2).addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();
srrmdsi.setCollectMetaData(true);
int count = 0;
while (srrmdsi.hasNext()) {
MultiDataSet mds = srrmdsi.next();
@ -359,34 +303,27 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertEquals(3, count);
}
@Test
public void testInputValidation() {
//Test: no readers
@DisplayName("Test Input Validation")
void testInputValidation() {
// Test: no readers
try {
MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addInput("something")
.addOutput("something").build();
MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addInput("something").addOutput("something").build();
fail("Should have thrown exception");
} catch (Exception e) {
}
//Test: reference to reader that doesn't exist
// Test: reference to reader that doesn't exist
try {
RecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr)
.addInput("thisDoesntExist", 0, 3).addOutputOneHot("iris", 4, 3).build();
MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).addInput("thisDoesntExist", 0, 3).addOutputOneHot("iris", 4, 3).build();
fail("Should have thrown exception");
} catch (Exception e) {
}
//Test: no inputs or outputs
// Test: no inputs or outputs
try {
RecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).build();
fail("Should have thrown exception");
} catch (Exception e) {
@ -394,81 +331,55 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testVariableLengthTS() throws Exception {
//need to manually extract
File rootDir = temporaryFolder.newFolder();
@DisplayName("Test Variable Length TS")
void testVariableLengthTS() throws Exception {
// need to manually extract
File rootDir = temporaryFolder.toFile();
for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
}
String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");
//Set up SequenceRecordReaderDataSetIterators for comparison
// Set up SequenceRecordReaderDataSetIterators for comparison
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader,
labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);
SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2,
labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
//Set up
SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);
SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2, labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
// Set up
SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in")
.addOutputOneHot("out", 0, 4)
.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();
RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in")
.addOutputOneHot("out", 0, 4)
.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();
RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();
RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();
while (iterAlignStart.hasNext()) {
DataSet dsStart = iterAlignStart.next();
DataSet dsEnd = iterAlignEnd.next();
MultiDataSet mdsStart = rrmdsiStart.next();
MultiDataSet mdsEnd = rrmdsiEnd.next();
assertEquals(1, mdsStart.getFeatures().length);
assertEquals(1, mdsStart.getLabels().length);
//assertEquals(1, mdsStart.getFeaturesMaskArrays().length); //Features data is always longer -> don't need mask arrays for it
// assertEquals(1, mdsStart.getFeaturesMaskArrays().length); //Features data is always longer -> don't need mask arrays for it
assertEquals(1, mdsStart.getLabelsMaskArrays().length);
assertEquals(1, mdsEnd.getFeatures().length);
assertEquals(1, mdsEnd.getLabels().length);
//assertEquals(1, mdsEnd.getFeaturesMaskArrays().length);
// assertEquals(1, mdsEnd.getFeaturesMaskArrays().length);
assertEquals(1, mdsEnd.getLabelsMaskArrays().length);
assertEquals(dsStart.getFeatures(), mdsStart.getFeatures(0));
assertEquals(dsStart.getLabels(), mdsStart.getLabels(0));
assertEquals(dsStart.getLabelsMaskArray(), mdsStart.getLabelsMaskArray(0));
assertEquals(dsEnd.getFeatures(), mdsEnd.getFeatures(0));
assertEquals(dsEnd.getLabels(), mdsEnd.getLabels(0));
assertEquals(dsEnd.getLabelsMaskArray(), mdsEnd.getLabelsMaskArray(0));
@ -477,57 +388,40 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
assertFalse(rrmdsiEnd.hasNext());
}
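For readers new to the alignment modes compared above: with ALIGN_START the shorter label sequence sits at the start of the padded window, with ALIGN_END at the end, and the labels mask marks the real time steps. Illustrative mask layout for a 3-step label sequence inside a 5-step feature window (re-derived from DL4J's documented behavior, not taken from this commit):

    // 1 = real label step, 0 = padding
    INDArray alignStartMask = Nd4j.create(new double[] {1, 1, 1, 0, 0});
    INDArray alignEndMask   = Nd4j.create(new double[] {0, 0, 1, 1, 1});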
@Test
public void testVariableLengthTSMeta() throws Exception {
//need to manually extract
File rootDir = temporaryFolder.newFolder();
@DisplayName("Test Variable Length TS Meta")
void testVariableLengthTSMeta() throws Exception {
// need to manually extract
File rootDir = temporaryFolder.toFile();
for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
}
//Set up SequenceRecordReaderDataSetIterators for comparison
// Set up SequenceRecordReaderDataSetIterators for comparison
String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");
//Set up
// Set up
SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in")
.addOutputOneHot("out", 0, 4)
.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();
RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
.addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in")
.addOutputOneHot("out", 0, 4)
.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();
RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader3).addSequenceReader("out", labelReader3).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START).build();
RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader4).addSequenceReader("out", labelReader4).addInput("in").addOutputOneHot("out", 0, 4).sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END).build();
rrmdsiStart.setCollectMetaData(true);
rrmdsiEnd.setCollectMetaData(true);
int count = 0;
while (rrmdsiStart.hasNext()) {
MultiDataSet mdsStart = rrmdsiStart.next();
MultiDataSet mdsEnd = rrmdsiEnd.next();
MultiDataSet mdsStartFromMeta =
rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class));
MultiDataSet mdsStartFromMeta = rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class));
MultiDataSet mdsEndFromMeta = rrmdsiEnd.loadFromMetaData(mdsEnd.getExampleMetaData(RecordMetaData.class));
assertEquals(mdsStart, mdsStartFromMeta);
assertEquals(mdsEnd, mdsEndFromMeta);
count++;
}
assertFalse(rrmdsiStart.hasNext());
@ -536,53 +430,37 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testImagesRRDMSI() throws Exception {
File parentDir = temporaryFolder.newFolder();
@DisplayName("Test Images RRDMSI")
void testImagesRRDMSI() throws Exception {
File parentDir = temporaryFolder.toFile();
parentDir.deleteOnExit();
String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
File f1 = new File(str1);
File f2 = new File(str2);
f1.mkdirs();
f2.mkdirs();
TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")),
new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")),
new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
int outputNum = 2;
Random r = new Random(12345);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
rr1.initialize(new FileSplit(parentDir));
rr1s.initialize(new FileSplit(parentDir));
MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1)
.addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0)
.addOutputOneHot("rr1s", 1, outputNum).build();
//Now, do the same thing with ImageRecordReader, and check we get the same results:
MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
// Now, do the same thing with ImageRecordReader, and check we get the same results:
ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
rr1_b.initialize(new FileSplit(parentDir));
rr1s_b.initialize(new FileSplit(parentDir));
DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);
for (int i = 0; i < 2; i++) {
MultiDataSet mds = trainDataIterator.next();
DataSet d1 = dsi1.next();
DataSet d2 = dsi2.next();
assertEquals(d1.getFeatures(), mds.getFeatures(0));
assertEquals(d2.getFeatures(), mds.getFeatures(1));
assertEquals(d1.getLabels(), mds.getLabels(0));
@ -590,261 +468,180 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testImagesRRDMSI_Batched() throws Exception {
File parentDir = temporaryFolder.newFolder();
@DisplayName("Test Images RRDMSI _ Batched")
void testImagesRRDMSI_Batched() throws Exception {
File parentDir = temporaryFolder.toFile();
parentDir.deleteOnExit();
String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
File f1 = new File(str1);
File f2 = new File(str2);
f1.mkdirs();
f2.mkdirs();
TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")),
new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")),
new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
TestUtils.writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
int outputNum = 2;
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
URI[] uris = new FileSplit(parentDir).locations();
rr1.initialize(new CollectionInputSplit(uris));
rr1s.initialize(new CollectionInputSplit(uris));
MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
// Now, do the same thing with ImageRecordReader, and check we get the same results:
ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
rr1_b.initialize(new FileSplit(parentDir));
rr1s_b.initialize(new FileSplit(parentDir));
DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 2, 1, 2);
DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 2, 1, 2);
MultiDataSet mds = trainDataIterator.next();
DataSet d1 = dsi1.next();
DataSet d2 = dsi2.next();
assertEquals(d1.getFeatures(), mds.getFeatures(0));
assertEquals(d2.getFeatures(), mds.getFeatures(1));
assertEquals(d1.getLabels(), mds.getLabels(0));
// Check label assignment:
File currentFile = rr1_b.getCurrentFile();
INDArray expLabels;
if (currentFile.getAbsolutePath().contains("Zico")) {
expLabels = Nd4j.create(new double[][] { { 0, 1 }, { 1, 0 } });
} else {
expLabels = Nd4j.create(new double[][] { { 1, 0 }, { 0, 1 } });
}
assertEquals(expLabels, d1.getLabels());
assertEquals(expLabels, d2.getLabels());
}
@Test
@DisplayName("Test Time Series Random Offset")
void testTimeSeriesRandomOffset() {
// 2 in, 2 out, 3 total sequences of length [1,3,5]
List<List<Writable>> seq1 = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(2.0)));
List<List<Writable>> seq2 = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(10.0), new DoubleWritable(11.0)), Arrays.<Writable>asList(new DoubleWritable(20.0), new DoubleWritable(21.0)), Arrays.<Writable>asList(new DoubleWritable(30.0), new DoubleWritable(31.0)));
List<List<Writable>> seq3 = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(100.0), new DoubleWritable(101.0)), Arrays.<Writable>asList(new DoubleWritable(200.0), new DoubleWritable(201.0)), Arrays.<Writable>asList(new DoubleWritable(300.0), new DoubleWritable(301.0)), Arrays.<Writable>asList(new DoubleWritable(400.0), new DoubleWritable(401.0)), Arrays.<Writable>asList(new DoubleWritable(500.0), new DoubleWritable(501.0)));
Collection<List<List<Writable>>> seqs = Arrays.asList(seq1, seq2, seq3);
SequenceRecordReader rr = new CollectionSequenceRecordReader(seqs);
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(3).addSequenceReader("rr", rr).addInput("rr", 0, 0).addOutput("rr", 1, 1).timeSeriesRandomOffset(true, 1234L).build();
// Provides seed for each minibatch
Random r = new Random(1234);
long seed = r.nextLong();
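// The iterator was seeded via timeSeriesRandomOffset(true, 1234L); replaying an identically
// seeded Random here lets the test predict the exact offsets the iterator will draw for each
// minibatch before it produces one.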
// Use same RNG seed in new RNG for each minibatch
Random r2 = new Random(seed);
// 0 to 4 inclusive
int expOffsetSeq1 = r2.nextInt(5 - 1 + 1);
int expOffsetSeq2 = r2.nextInt(5 - 3 + 1);
// Longest TS, always 0
int expOffsetSeq3 = 0;
// With current seed: 3, 1, 0
// System.out.println(expOffsetSeq1 + "\t" + expOffsetSeq2 + "\t" + expOffsetSeq3);
MultiDataSet mds = rrmdsi.next();
INDArray expMask = Nd4j.create(new double[][] { { 0, 0, 0, 1, 0 }, { 0, 1, 1, 1, 0 }, { 1, 1, 1, 1, 1 } });
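// Row k of the mask marks where sequence k landed after offsetting: the length-1 sequence at
// offset 3, the length-3 sequence at offset 1, and the length-5 sequence (the longest, never
// shifted) covering all five steps.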
assertEquals(expMask, mds.getFeaturesMaskArray(0));
assertEquals(expMask, mds.getLabelsMaskArray(0));
INDArray f = mds.getFeatures(0);
INDArray l = mds.getLabels(0);
INDArray expF1 = Nd4j.create(new double[] { 1.0 }, new int[] { 1, 1 });
INDArray expL1 = Nd4j.create(new double[] { 2.0 }, new int[] { 1, 1 });
INDArray expF2 = Nd4j.create(new double[] { 10, 20, 30 }, new int[] { 1, 3 });
INDArray expL2 = Nd4j.create(new double[] { 11, 21, 31 }, new int[] { 1, 3 });
INDArray expF3 = Nd4j.create(new double[] { 100, 200, 300, 400, 500 }, new int[] { 1, 5 });
INDArray expL3 = Nd4j.create(new double[] { 101, 201, 301, 401, 501 }, new int[] { 1, 5 });
assertEquals(expF1, f.get(point(0), all(), NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1)));
assertEquals(expL1, l.get(point(0), all(), NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1)));
assertEquals(expF2, f.get(point(1), all(), NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3)));
assertEquals(expL2, l.get(point(1), all(), NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3)));
assertEquals(expF3, f.get(point(2), all(), NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5)));
assertEquals(expL3, l.get(point(2), all(), NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5)));
}
@Test
@DisplayName("Test Seq RRDSI Masking")
void testSeqRRDSIMasking() {
// This also tests RecordReaderMultiDataSetIterator, by virtue of SequenceRecordReaderDataSetIterator being built on top of it
List<List<List<Writable>>> features = new ArrayList<>();
List<List<List<Writable>>> labels = new ArrayList<>();
features.add(Arrays.asList(l(new DoubleWritable(1)), l(new DoubleWritable(2)), l(new DoubleWritable(3))));
features.add(Arrays.asList(l(new DoubleWritable(4)), l(new DoubleWritable(5))));
labels.add(Arrays.asList(l(new IntWritable(0))));
labels.add(Arrays.asList(l(new IntWritable(1))));
CollectionSequenceRecordReader fR = new CollectionSequenceRecordReader(features);
CollectionSequenceRecordReader lR = new CollectionSequenceRecordReader(labels);
SequenceRecordReaderDataSetIterator seqRRDSI = new SequenceRecordReaderDataSetIterator(fR, lR, 2, 2, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
DataSet ds = seqRRDSI.next();
INDArray fMask = Nd4j.create(new double[][] { { 1, 1, 1 }, { 1, 1, 0 } });
INDArray lMask = Nd4j.create(new double[][] { { 0, 0, 1 }, { 0, 1, 0 } });
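// With AlignmentMode.ALIGN_END the single label of each sequence sits at that sequence's last
// feature step (index 2 for the length-3 sequence, index 1 for the length-2 one); the feature
// mask simply zeroes the padded tail of the shorter sequence.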
assertEquals(fMask, ds.getFeaturesMaskArray());
assertEquals(lMask, ds.getLabelsMaskArray());
INDArray f = Nd4j.create(new double[][] { { 1, 2, 3 }, { 4, 5, 0 } });
INDArray l = Nd4j.create(2, 2, 3);
l.putScalar(0, 0, 2, 1.0);
l.putScalar(1, 1, 1, 1.0);
assertEquals(f, ds.getFeatures().get(all(), point(0), all()));
assertEquals(l, ds.getLabels());
}
private static List<Writable> l(Writable... in) {
return Arrays.asList(in);
}
@Test
@DisplayName("Test Exclude String Col CSV")
void testExcludeStringColCSV() throws Exception {
File csvFile = temporaryFolder.resolve("test.csv").toFile();
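// NOTE (hedged): with JUnit 5's @TempDir the injected Path is the temp directory itself, so
// the JUnit 4 TemporaryFolder.newFile() being replaced here must become a file resolved inside
// it ("test.csv" above is an arbitrary name); writing the CSV to the bare directory path would fail.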
StringBuilder sb = new StringBuilder();
for (int i = 1; i <= 10; i++) {
if (i > 1) {
sb.append("\n");
}
sb.append("skip_").append(i).append(",").append(i).append(",").append(i + 0.5);
}
FileUtils.writeStringToFile(csvFile, sb.toString());
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(csvFile));
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("rr", rr).addInput("rr", 1, 1).addOutput("rr", 2, 2).build();
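// The string column (index 0, the "skip_i" values) is excluded simply by never mapping it:
// the input takes column 1 and the output takes column 2.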
INDArray expFeatures = Nd4j.linspace(1, 10, 10).reshape(1, 10).transpose();
INDArray expLabels = Nd4j.linspace(1, 10, 10).addi(0.5).reshape(1, 10).transpose();
MultiDataSet mds = rrmdsi.next();
assertFalse(rrmdsi.hasNext());
assertEquals(expFeatures, mds.getFeatures(0).castTo(expFeatures.dataType()));
assertEquals(expLabels, mds.getLabels(0).castTo(expLabels.dataType()));
}
private static final int nX = 32;
private static final int nY = 32;
private static final int nZ = 28;
@Test
@DisplayName("Test RRMDSI 5 D")
void testRRMDSI5D() {
int batchSize = 5;
CustomRecordReader recordReader = new CustomRecordReader();
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, /* Index of label in records */
2 /* number of different labels */);
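// Each record is a rank-5 volume of shape [1, 1, nZ, nY, nX] filled with the record's running
// index, so example i of batch b is expected to contain the constant 5 * b + i throughout.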
int count = 0;
while (dataIter.hasNext()) {
DataSet ds = dataIter.next();
int offset = 5 * count;
for (int i = 0; i < 5; i++) {
INDArray act = ds.getFeatures().get(interval(i, i, true), all(), all(), all(), all());
INDArray exp = Nd4j.valueArrayOf(new int[] { 1, 1, nZ, nX, nY }, i + offset);
assertEquals(exp, act);
}
count++;
}
assertEquals(2, count);
}
@DisplayName("Custom Record Reader")
static class CustomRecordReader extends BaseRecordReader {
int n = 0;
CustomRecordReader() {
}
@Override
public boolean batchesSupported() {
@ -858,8 +655,8 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@Override
public List<Writable> next() {
INDArray nd = Nd4j.create(new float[nZ * nY * nX], new int[] { 1, 1, nZ, nY, nX }, 'c').assign(n);
final List<Writable> res = RecordConverter.toRecord(nd);
res.add(new IntWritable(0));
n++;
return res;
@ -867,14 +664,16 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@Override
public boolean hasNext() {
return n < 10;
}
final static ArrayList<String> labels = new ArrayList<>(2);
static {
labels.add("lbl0");
labels.add("lbl1");
}
@Override
public List<String> getLabels() {
return labels;
@ -928,6 +727,7 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
public void initialize(InputSplit split) {
n = 0;
}
@Override
public void initialize(Configuration conf, InputSplit split) {
n = 0;

View File

@ -17,38 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.fetchers;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.jupiter.api.Test;
import java.io.File;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
* @author saudet
*/
@DisplayName("Svhn Data Fetcher Test")
class SvhnDataFetcherTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
// Shouldn't take this long but slow download or drive access on CI machines may need extra time.
return 480_000_000L;
}
@Test
@DisplayName("Test Svhn Data Fetcher")
void testSvhnDataFetcher() throws Exception {
// Ignore unless integration tests - CI can get caught up on slow disk access
assumeTrue(isIntegrationTests());
SvhnDataFetcher fetch = new SvhnDataFetcher();
File path = fetch.getDataSetPath(DataSetType.TRAIN);
File path2 = fetch.getDataSetPath(DataSetType.TEST);
File path3 = fetch.getDataSetPath(DataSetType.VALIDATION);
assertTrue(path.isDirectory());
assertTrue(path2.isDirectory());
assertTrue(path3.isDirectory());

View File

@ -17,52 +17,50 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.common.primitives.Pair;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Abstract Data Set Iterator Test")
class AbstractDataSetIteratorTest extends BaseDL4JTest {
@Test
@DisplayName("Next")
void next() throws Exception {
int numFeatures = 128;
int batchSize = 10;
int numRows = 1000;
AtomicInteger cnt = new AtomicInteger(0);
FloatsDataSetIterator iterator = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize);
assertTrue(iterator.hasNext());
while (iterator.hasNext()) {
DataSet dataSet = iterator.next();
INDArray features = dataSet.getFeatures();
assertEquals(batchSize, features.rows());
assertEquals(numFeatures, features.columns());
cnt.incrementAndGet();
}
assertEquals(numRows / batchSize, cnt.get());
}
protected static Iterable<Pair<float[], float[]>> floatIterable(final int totalRows, final int numColumns) {
return new Iterable<Pair<float[], float[]>>() {
@Override
public Iterator<Pair<float[], float[]>> iterator() {
return new Iterator<Pair<float[], float[]>>() {
private AtomicInteger cnt = new AtomicInteger(0);
@Override
@ -72,8 +70,8 @@ public class AbstractDataSetIteratorTest extends BaseDL4JTest {
@Override
public Pair<float[], float[]> next() {
float[] features = new float[numColumns];
float[] labels = new float[numColumns];
for (int i = 0; i < numColumns; i++) {
features[i] = (float) i;
labels[i] = RandomUtils.nextFloat(0, 5);

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j;
@ -25,117 +24,118 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback;
import org.deeplearning4j.datasets.iterator.tools.VariableTimeseriesGenerator;
import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
@Slf4j
@DisplayName("Async Data Set Iterator Test")
class AsyncDataSetIteratorTest extends BaseDL4JTest {
private ExistingDataSetIterator backIterator;
private static final int TEST_SIZE = 100;
private static final int ITERATIONS = 10;
// time spent in consumer thread, milliseconds
private static final long EXECUTION_TIME = 5;
private static final long EXECUTION_SMALL = 1;
@BeforeEach
void setUp() throws Exception {
List<DataSet> iterable = new ArrayList<>();
for (int i = 0; i < TEST_SIZE; i++) {
iterable.add(new DataSet(Nd4j.create(new float[100]), Nd4j.create(new float[10])));
}
backIterator = new ExistingDataSetIterator(iterable);
}
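// NOTE: the tests below sweep prefetch depths 2..8. AsyncDataSetIterator fills a bounded queue
// from the backing iterator on a background thread, and every element must still be delivered
// exactly once regardless of the queue size.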
@Test
@DisplayName("Has Next 1")
void hasNext1() throws Exception {
for (int iter = 0; iter < ITERATIONS; iter++) {
for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) {
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
int cnt = 0;
while (iterator.hasNext()) {
DataSet ds = iterator.next();
assertNotEquals(null, ds);
cnt++;
}
assertEquals(TEST_SIZE, cnt, "Failed on iteration: " + iter + ", prefetchSize: " + prefetchSize);
iterator.shutdown();
}
}
}
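// NOTE (migration): JUnit 4's org.junit.Assert.assertEquals(String message, expected, actual)
// takes the failure message first; the Jupiter equivalent is
// Assertions.assertEquals(expected, actual, message) - hence the argument shuffle above and
// throughout this commit.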
@Test
@DisplayName("Has Next With Reset And Load")
void hasNextWithResetAndLoad() throws Exception {
int[] prefetchSizes;
if (isIntegrationTests()) {
prefetchSizes = new int[] { 2, 3, 4, 5, 6, 7, 8 };
} else {
prefetchSizes = new int[] { 2, 3, 8 };
}
for (int iter = 0; iter < ITERATIONS; iter++) {
for (int prefetchSize : prefetchSizes) {
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL);
int cnt = 0;
while (iterator.hasNext()) {
DataSet ds = iterator.next();
consumer.consumeOnce(ds, false);
cnt++;
if (cnt == TEST_SIZE / 2)
iterator.reset();
}
assertEquals(TEST_SIZE + (TEST_SIZE / 2), cnt);
iterator.shutdown();
}
}
}
@Test
@DisplayName("Test With Load")
void testWithLoad() {
for (int iter = 0; iter < ITERATIONS; iter++) {
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, 8);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, EXECUTION_TIME);
consumer.consumeWhileHasNext(true);
assertEquals(TEST_SIZE, consumer.getCount());
iterator.shutdown();
}
}
@Test
@DisplayName("Test With Exception")
void testWithException() {
assertThrows(ArrayIndexOutOfBoundsException.class, () -> {
ExistingDataSetIterator crashingIterator = new ExistingDataSetIterator(new IterableWithException(100));
AsyncDataSetIterator iterator = new AsyncDataSetIterator(crashingIterator, 8);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, EXECUTION_SMALL);
consumer.consumeWhileHasNext(true);
iterator.shutdown();
});
}
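// NOTE (migration): Jupiter has no @Test(expected = ...) attribute; the replacement wraps the
// failing code in assertThrows, which also returns the thrown exception should further checks
// be wanted, e.g. ArrayIndexOutOfBoundsException e = assertThrows(ArrayIndexOutOfBoundsException.class, () -> iterator.next());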
@DisplayName("Iterable With Exception")
private class IterableWithException implements Iterable<DataSet> {
private final AtomicLong counter = new AtomicLong(0);
private final int crashIteration;
public IterableWithException(int iteration) {
@ -146,6 +146,7 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
public Iterator<DataSet> iterator() {
counter.set(0);
return new Iterator<DataSet>() {
@Override
public boolean hasNext() {
return true;
@ -155,82 +156,59 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
public DataSet next() {
if (counter.incrementAndGet() >= crashIteration)
throw new ArrayIndexOutOfBoundsException("Thrown as expected");
return new DataSet(Nd4j.create(10), Nd4j.create(10));
}
@Override
public void remove() {
}
};
}
}
@Test
@DisplayName("Test Variable Time Series 1")
void testVariableTimeSeries1() throws Exception {
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
AsyncDataSetIterator adsi = new AsyncDataSetIterator(new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true);
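// VariableTimeseriesGenerator encodes the batch index into the data: batch n has feature mean
// n, label mean n + 0.25 and mask means n + 0.5 / n + 0.75, so any reordering or duplication
// in the async queue shows up below as a mean mismatch.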
for (int e = 0; e < 10; e++) {
int cnt = 0;
while (adsi.hasNext()) {
DataSet ds = adsi.next();
// log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals((double) cnt, ds.getFeatures().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.25, ds.getLabels().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.5, ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.75, ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
cnt++;
}
adsi.reset();
// log.info("Epoch {} finished...", e);
}
}
@Test
@DisplayName("Test Variable Time Series 2")
void testVariableTimeSeries2() throws Exception {
AsyncDataSetIterator adsi = new AsyncDataSetIterator(new VariableTimeseriesGenerator(1192, 100, 32, 128, 100, 100, 100), 2, true, new InterleavedDataSetCallback(2 * 2));
for (int e = 0; e < 5; e++) {
int cnt = 0;
while (adsi.hasNext()) {
DataSet ds = adsi.next();
ds.detach();
// log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals((double) cnt, ds.getFeatures().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.25, ds.getLabels().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.5, ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.75, ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
cnt++;
}
adsi.reset();
// log.info("Epoch {} finished...", e);
}
}
}

View File

@ -17,98 +17,19 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
@DisplayName("Async Multi Data Set Iterator Test")
class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
/**
* THIS TEST SHOULD ALWAYS BE RUN WITH DOUBLE PRECISION, WITHOUT ANY EXCLUSIONS
*
* @throws Exception
*/
@Test
@DisplayName("Test Variable Time Series 1")
void testVariableTimeSeries1() throws Exception {
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
iterator.reset();
iterator.hasNext();
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
for (int e = 0; e < 10; e++) {
int cnt = 0;
while (amdsi.hasNext()) {
MultiDataSet mds = amdsi.next();
// log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals((double) cnt, mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.25, mds.getLabels()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.5, mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.75, mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
cnt++;
}
amdsi.reset();
log.info("Epoch {} finished...", e);
}
}
@Test
@DisplayName("Test Variable Time Series 2")
void testVariableTimeSeries2() throws Exception {
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
for (int e = 0; e < 10; e++) {
iterator.reset();
iterator.hasNext();
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
int cnt = 0;
while (amdsi.hasNext()) {
MultiDataSet mds = amdsi.next();
// log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals((double) cnt, mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.25, mds.getLabels()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.5, mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
assertEquals((double) cnt + 0.75, mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";");
cnt++;
}
}
}
/*
@Test
public void testResetBug() throws Exception {
// /home/raver119/develop/dl4j-examples/src/main/resources/uci/train/features
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
trainFeatures.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/train/features" + "/%d.csv", 0, 449));
RecordReader trainLabels = new CSVRecordReader();
trainLabels.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/train/labels" + "/%d.csv", 0, 449));
int miniBatchSize = 10;
int numLabelClasses = 6;
MultiDataSetIterator trainData = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize)
.addSequenceReader("features", trainFeatures)
.addReader("labels", trainLabels)
.addInput("features")
.addOutputOneHot("labels", 0, numLabelClasses)
.build();
//Normalize the training data
MultiDataNormalization normalizer = new MultiNormalizerStandardize();
normalizer.fit(trainData); //Collect training data statistics
trainData.reset();
SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
testFeatures.initialize(new NumberedFileInputSplit("/home/raver119/develop/dl4j-examples/src/main/resources/uci/test/features" + "/%d.csv", 0, 149));
RecordReader testLabels = new CSVRecordReader();

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@ -41,8 +40,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -50,26 +49,28 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Data Set Iterator Test")
class DataSetIteratorTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
// Should run quickly; increased to large timeout due to occasional slow CI downloads
return 360000;
}
@Test
@DisplayName("Test Batch Size Of One Iris")
void testBatchSizeOfOneIris() throws Exception {
// Test for (a) iterators returning correct number of examples, and
// (b) Labels are a proper one-hot vector (i.e., sum is 1.0)
// Iris:
DataSetIterator iris = new IrisDataSetIterator(1, 5);
int irisC = 0;
while (iris.hasNext()) {
@ -81,9 +82,9 @@ public class DataSetIteratorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Batch Size Of One Mnist")
void testBatchSizeOfOneMnist() throws Exception {
// MNIST:
DataSetIterator mnist = new MnistDataSetIterator(1, 5);
int mnistC = 0;
while (mnist.hasNext()) {
@ -95,25 +96,21 @@ public class DataSetIteratorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Mnist")
void testMnist() throws Exception {
ClassPathResource cpr = new ClassPathResource("mnist_first_200.txt");
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(cpr.getTempFileFromArchive()));
RecordReaderDataSetIterator dsi = new RecordReaderDataSetIterator(rr, 10, 0, 10);
MnistDataSetIterator iter = new MnistDataSetIterator(10, 200, false, true, false, 0);
while (dsi.hasNext()) {
DataSet dsExp = dsi.next();
DataSet dsAct = iter.next();
INDArray fExp = dsExp.getFeatures();
fExp.divi(255);
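// The CSV holds raw 0-255 pixel values while the MNIST iterator (binarize = false) returns
// pixels already scaled into [0, 1], so dividing the CSV features by 255 makes the two
// sources directly comparable.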
INDArray lExp = dsExp.getLabels();
INDArray fAct = dsAct.getFeatures();
INDArray lAct = dsAct.getLabels();
assertEquals(fExp, fAct.castTo(fExp.dataType()));
assertEquals(lExp, lAct.castTo(lExp.dataType()));
}
@ -121,12 +118,13 @@ public class DataSetIteratorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Lfw Iterator")
void testLfwIterator() throws Exception {
int numExamples = 1;
int row = 28;
int col = 28;
int channels = 1;
LFWDataSetIterator iter = new LFWDataSetIterator(numExamples, new int[] { row, col, channels }, true);
assertTrue(iter.hasNext());
DataSet data = iter.next();
assertEquals(numExamples, data.getLabels().size(0));
@ -134,7 +132,8 @@ public class DataSetIteratorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Tiny Image Net Iterator")
void testTinyImageNetIterator() throws Exception {
int numClasses = 200;
int row = 64;
int col = 64;
@ -143,24 +142,26 @@ public class DataSetIteratorTest extends BaseDL4JTest {
assertTrue(iter.hasNext());
DataSet data = iter.next();
assertEquals(numClasses, data.getLabels().size(1));
assertArrayEquals(new long[] { 1, channels, row, col }, data.getFeatures().shape());
}
@Test
@DisplayName("Test Tiny Image Net Iterator 2")
void testTinyImageNetIterator2() throws Exception {
int numClasses = 200;
int row = 224;
int col = 224;
int channels = 3;
TinyImageNetDataSetIterator iter = new TinyImageNetDataSetIterator(1, new int[] { row, col }, DataSetType.TEST);
assertTrue(iter.hasNext());
DataSet data = iter.next();
assertEquals(numClasses, data.getLabels().size(1));
assertArrayEquals(new long[] { 1, channels, row, col }, data.getFeatures().shape());
}
@Test
@DisplayName("Test Lfw Model")
void testLfwModel() throws Exception {
final int numRows = 28;
final int numColumns = 28;
int numChannels = 3;
@ -169,39 +170,22 @@ public class DataSetIteratorTest extends BaseDL4JTest {
int batchSize = 2;
int seed = 123;
int listenerFreq = 1;
LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples, new int[] { numRows, numColumns, numChannels }, outputNum, false, true, 1.0, new Random(seed));
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list().layer(0, new ConvolutionLayer.Builder(5, 5).nIn(numChannels).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).stride(1, 1).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels));
MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
model.init();
model.setListeners(new ScoreIterationListener(listenerFreq));
model.fit(lfw.next());
DataSet dataTest = lfw.next();
INDArray output = model.output(dataTest.getFeatures());
Evaluation eval = new Evaluation(outputNum);
eval.eval(dataTest.getLabels(), output);
// System.out.println(eval.stats());
}
@Test
@DisplayName("Test Cifar 10 Iterator")
void testCifar10Iterator() throws Exception {
int numExamples = 1;
int row = 32;
int col = 32;
@ -213,12 +197,13 @@ public class DataSetIteratorTest extends BaseDL4JTest {
assertEquals(channels * row * col, data.getFeatures().ravel().length());
}
// Ignored for now - CIFAR iterator needs work - https://github.com/eclipse/deeplearning4j/issues/4673
@Test
@Disabled
@DisplayName("Test Cifar Model")
void testCifarModel() throws Exception {
// Streaming
runCifar(false);
// Preprocess
runCifar(true);
}
@ -231,32 +216,14 @@ public class DataSetIteratorTest extends BaseDL4JTest {
int batchSize = 5;
int seed = 123;
int listenerFreq = 1;
Cifar10DataSetIterator cifar = new Cifar10DataSetIterator(batchSize);
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list().layer(0, new ConvolutionLayer.Builder(5, 5).nIn(channels).nOut(6).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(height, width, channels));
MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
model.init();
// model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
model.setListeners(listener);
model.fit(cifar);
cifar = new Cifar10DataSetIterator(batchSize);
Evaluation eval = new Evaluation(cifar.getLabels());
while (cifar.hasNext()) {
@ -264,37 +231,31 @@ public class DataSetIteratorTest extends BaseDL4JTest {
INDArray output = model.output(testDS.getFeatures());
eval.eval(testDS.getLabels(), output);
}
// System.out.println(eval.stats(true));
listener.exportScores(System.out);
}
@Test
@DisplayName("Test Iterator Data Set Iterator Combining")
void testIteratorDataSetIteratorCombining() {
// Test combining of a bunch of small (size 1) data sets together
int batchSize = 3;
int numBatches = 4;
int featureSize = 5;
int labelSize = 6;
Nd4j.getRandom().setSeed(12345);
List<DataSet> orig = new ArrayList<>();
for (int i = 0; i < batchSize * numBatches; i++) {
INDArray features = Nd4j.rand(1, featureSize);
INDArray labels = Nd4j.rand(1, labelSize);
orig.add(new DataSet(features, labels));
}
DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize);
int count = 0;
while (iter.hasNext()) {
DataSet ds = iter.next();
assertArrayEquals(new long[] { batchSize, featureSize }, ds.getFeatures().shape());
assertArrayEquals(new long[] { batchSize, labelSize }, ds.getLabels().shape());
List<INDArray> fList = new ArrayList<>();
List<INDArray> lList = new ArrayList<>();
for (int i = 0; i < batchSize; i++) {
@ -302,66 +263,44 @@ public class DataSetIteratorTest extends BaseDL4JTest {
fList.add(dsOrig.getFeatures());
lList.add(dsOrig.getLabels());
}
INDArray fExp = Nd4j.vstack(fList);
INDArray lExp = Nd4j.vstack(lList);
assertEquals(fExp, ds.getFeatures());
assertEquals(lExp, ds.getLabels());
count++;
}
assertEquals(count, numBatches);
}
@Test
@DisplayName("Test Iterator Data Set Iterator Splitting")
void testIteratorDataSetIteratorSplitting() {
// Test splitting large data sets into smaller ones
int origBatchSize = 4;
int origNumDSs = 3;
int batchSize = 3;
int numBatches = 4;
int featureSize = 5;
int labelSize = 6;
Nd4j.getRandom().setSeed(12345);
List<DataSet> orig = new ArrayList<>();
for (int i = 0; i < origNumDSs; i++) {
INDArray features = Nd4j.rand(origBatchSize, featureSize);
INDArray labels = Nd4j.rand(origBatchSize, labelSize);
orig.add(new DataSet(features, labels));
}
List<DataSet> expected = new ArrayList<>();
expected.add(new DataSet(orig.get(0).getFeatures().getRows(0, 1, 2), orig.get(0).getLabels().getRows(0, 1, 2)));
expected.add(new DataSet(Nd4j.vstack(orig.get(0).getFeatures().getRows(3), orig.get(1).getFeatures().getRows(0, 1)), Nd4j.vstack(orig.get(0).getLabels().getRows(3), orig.get(1).getLabels().getRows(0, 1))));
expected.add(new DataSet(Nd4j.vstack(orig.get(1).getFeatures().getRows(2, 3), orig.get(2).getFeatures().getRows(0)), Nd4j.vstack(orig.get(1).getLabels().getRows(2, 3), orig.get(2).getLabels().getRows(0))));
expected.add(new DataSet(orig.get(2).getFeatures().getRows(1, 2, 3), orig.get(2).getLabels().getRows(1, 2, 3)));
DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize);
int count = 0;
while (iter.hasNext()) {
DataSet ds = iter.next();
assertEquals(expected.get(count), ds);
count++;
}
assertEquals(count, numBatches);
}
}

View File

@ -17,13 +17,12 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -32,23 +31,27 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Early Termination Data Set Iterator Test")
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
int minibatchSize = 10;
int numExamples = 105;
@Rule
public final ExpectedException exception = ExpectedException.none();
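// NOTE (hedged): org.junit.rules.ExpectedException is a JUnit 4 @Rule and is not run by the
// Jupiter engine, so the exception.expect(...) call further down no longer enforces anything
// under @org.junit.jupiter.api.Test; the Jupiter-native form would be
// assertThrows(RuntimeException.class, () -> earlyEndIter.next(10)).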
@Test
@DisplayName("Test Next And Reset")
void testNextAndReset() throws Exception {
int terminateAfter = 2;
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
assertTrue(earlyEndIter.hasNext());
int batchesSeen = 0;
List<DataSet> seenData = new ArrayList<>();
@ -59,8 +62,7 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
batchesSeen++;
}
assertEquals(batchesSeen, terminateAfter);
// check data is repeated after reset
earlyEndIter.reset();
batchesSeen = 0;
while (earlyEndIter.hasNext()) {
@ -72,27 +74,23 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Next Num")
void testNextNum() throws IOException {
int terminateAfter = 1;
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
assertEquals(false, earlyEndIter.hasNext());
earlyEndIter.reset();
assertEquals(true, earlyEndIter.hasNext());
}
@Test
@DisplayName("Test Calls To Next Not Allowed")
void testCallstoNextNotAllowed() throws IOException {
int terminateAfter = 1;
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
iter.reset();
exception.expect(RuntimeException.class);

View File

@ -17,40 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Early Termination Multi Data Set Iterator Test")
class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
int minibatchSize = 5;
int numExamples = 105;
@Rule
public final ExpectedException exception = ExpectedException.none();
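// NOTE (hedged): as in EarlyTerminationDataSetIteratorTest above, this JUnit 4 rule is inert
// under the Jupiter engine; assertThrows would be the Jupiter-native replacement.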
@Test
@DisplayName("Test Next And Reset")
void testNextAndReset() throws Exception {
int terminateAfter = 2;
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
int count = 0;
List<MultiDataSet> seenMDS = new ArrayList<>();
while (count < terminateAfter) {
@ -58,10 +57,7 @@ public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
count++;
}
iter.reset();
EarlyTerminationMultiDataSetIterator earlyEndIter =
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
assertTrue(earlyEndIter.hasNext());
count = 0;
while (earlyEndIter.hasNext()) {
@ -71,8 +67,7 @@ public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
count++;
}
assertEquals(count, terminateAfter);
//check data is repeated
// check data is repeated
earlyEndIter.reset();
count = 0;
while (earlyEndIter.hasNext()) {
@ -84,34 +79,26 @@ public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testNextNum() throws IOException {
@DisplayName("Test Next Num")
void testNextNum() throws IOException {
int terminateAfter = 1;
MultiDataSetIterator iter =
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter =
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
assertEquals(false, earlyEndIter.hasNext());
earlyEndIter.reset();
assertEquals(true, earlyEndIter.hasNext());
}
@Test
public void testCallstoNextNotAllowed() throws IOException {
@DisplayName("Test Callsto Next Not Allowed")
void testCallstoNextNotAllowed() throws IOException {
int terminateAfter = 1;
MultiDataSetIterator iter =
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter =
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
iter.reset();
exception.expect(RuntimeException.class);
earlyEndIter.next(10);
}
}
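The MultiDataSet variant keeps the same ExpectedException rule and would migrate to assertThrows in the same way. Two smaller polish points in these tests, sketched under Jupiter (illustrative only): assertEquals(false, ...) and assertEquals(true, ...) in testNextNum read better as assertFalse/assertTrue, and assertions such as assertEquals(batchesSeen, terminateAfter) have expected and actual swapped relative to JUnit's (expected, actual) convention.

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

earlyEndIter.next(10);
assertFalse(earlyEndIter.hasNext());   // instead of assertEquals(false, earlyEndIter.hasNext())
earlyEndIter.reset();
assertTrue(earlyEndIter.hasNext());    // instead of assertEquals(true, earlyEndIter.hasNext())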

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j;
@ -25,90 +24,75 @@ import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.parallel.JointParallelDataSetIterator;
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.enums.InequalityHandling;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
public class JointParallelDataSetIteratorTest extends BaseDL4JTest {
@DisplayName("Joint Parallel Data Set Iterator Test")
class JointParallelDataSetIteratorTest extends BaseDL4JTest {
/**
* Simple test checking dataset alignment: all source iterators should return the same data for the same cycle.
*
*
* @throws Exception
*/
@Test
public void testJointIterator1() throws Exception {
@DisplayName("Test Joint Iterator 1")
void testJointIterator1() throws Exception {
DataSetIterator iteratorA = new SimpleVariableGenerator(119, 100, 32, 100, 10);
DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.STOP_EVERYONE)
.addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.STOP_EVERYONE).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
int cnt = 0;
int example = 0;
while (jpdsi.hasNext()) {
DataSet ds = jpdsi.next();
assertNotNull("Failed on iteration " + cnt, ds);
// ds.detach();
//ds.migrate();
assertEquals("Failed on iteration " + cnt, (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001);
assertNotNull(ds,"Failed on iteration " + cnt);
// ds.detach();
// ds.migrate();
assertEquals( (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
cnt++;
if (cnt % 2 == 0)
example++;
}
assertEquals(100, example);
assertEquals(200, cnt);
}
/**
* This test checks the PASS_NULL scenario: in total we should see 300 real datasets plus 100 nulls
* @throws Exception
*/
@Test
public void testJointIterator2() throws Exception {
@DisplayName("Test Joint Iterator 2")
void testJointIterator2() throws Exception {
DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10);
DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.PASS_NULL)
.addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.PASS_NULL).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
int cnt = 0;
int example = 0;
int nulls = 0;
while (jpdsi.hasNext()) {
DataSet ds = jpdsi.next();
if (cnt < 200)
assertNotNull("Failed on iteration " + cnt, ds);
assertNotNull(ds,"Failed on iteration " + cnt);
if (ds == null)
nulls++;
if (cnt % 2 == 0) {
assertEquals("Failed on iteration " + cnt, (double) example,
ds.getFeatures().meanNumber().doubleValue(), 0.001);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5,
ds.getLabels().meanNumber().doubleValue(), 0.001);
assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt);
assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt);
}
cnt++;
if (cnt % 2 == 0)
example++;
}
assertEquals(100, nulls);
assertEquals(200, example);
assertEquals(400, cnt);
@ -120,25 +104,18 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest {
* @throws Exception
*/
@Test
public void testJointIterator3() throws Exception {
@DisplayName("Test Joint Iterator 3")
void testJointIterator3() throws Exception {
DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10);
DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RELOCATE)
.addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RELOCATE).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
int cnt = 0;
int example = 0;
while (jpdsi.hasNext()) {
DataSet ds = jpdsi.next();
assertNotNull("Failed on iteration " + cnt, ds);
assertEquals("Failed on iteration " + cnt, (double) example, ds.getFeatures().meanNumber().doubleValue(),
0.001);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5,
ds.getLabels().meanNumber().doubleValue(), 0.001);
assertNotNull(ds,"Failed on iteration " + cnt);
assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001,"Failed on iteration " + cnt);
cnt++;
if (cnt < 200) {
if (cnt % 2 == 0)
@ -146,8 +123,6 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest {
} else
example++;
}
assertEquals(300, cnt);
assertEquals(200, example);
}
@ -158,52 +133,38 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest {
* @throws Exception
*/
@Test
public void testJointIterator4() throws Exception {
@DisplayName("Test Joint Iterator 4")
void testJointIterator4() throws Exception {
DataSetIterator iteratorA = new SimpleVariableGenerator(119, 200, 32, 100, 10);
DataSetIterator iteratorB = new SimpleVariableGenerator(119, 100, 32, 100, 10);
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RESET)
.addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
JointParallelDataSetIterator jpdsi = new JointParallelDataSetIterator.Builder(InequalityHandling.RESET).addSourceIterator(iteratorA).addSourceIterator(iteratorB).build();
int cnt = 0;
int cnt_sec = 0;
int example_sec = 0;
int example = 0;
while (jpdsi.hasNext()) {
DataSet ds = jpdsi.next();
assertNotNull("Failed on iteration " + cnt, ds);
assertNotNull(ds,"Failed on iteration " + cnt);
if (cnt % 2 == 0) {
assertEquals("Failed on iteration " + cnt, (double) example,
ds.getFeatures().meanNumber().doubleValue(), 0.001);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5,
ds.getLabels().meanNumber().doubleValue(), 0.001);
assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt);
assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt);
} else {
if (cnt <= 200) {
assertEquals("Failed on iteration " + cnt, (double) example,
ds.getFeatures().meanNumber().doubleValue(), 0.001);
assertEquals("Failed on iteration " + cnt, (double) example + 0.5,
ds.getLabels().meanNumber().doubleValue(), 0.001);
assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt);
assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt);
} else {
assertEquals("Failed on iteration " + cnt + ", second iteration " + cnt_sec, (double) example_sec,
ds.getFeatures().meanNumber().doubleValue(), 0.001);
assertEquals("Failed on iteration " + cnt + ", second iteration " + cnt_sec,
(double) example_sec + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001);
assertEquals((double) example_sec, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt + ", second iteration " + cnt_sec);
assertEquals((double) example_sec + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt + ", second iteration " + cnt_sec);
}
}
cnt++;
if (cnt % 2 == 0)
example++;
if (cnt > 201 && cnt % 2 == 1) {
cnt_sec++;
example_sec++;
}
}
assertEquals(400, cnt);
assertEquals(200, example);
}
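A pattern worth calling out in the conversions above: JUnit 4 took the failure message as the first assertEquals argument, while Jupiter moves it to the last position, and also accepts a Supplier&lt;String&gt; so the message is only built when the assertion actually fails. A sketch, where expected and actual stand in for the values being compared:

// JUnit 4 style:
// assertEquals("Failed on iteration " + cnt, expected, actual, 0.001);
// JUnit 5 - message moves last:
assertEquals(expected, actual, 0.001, "Failed on iteration " + cnt);
// or lazily, avoiding string concatenation on the passing path:
assertEquals(expected, actual, 0.001, () -> "Failed on iteration " + cnt);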

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import org.datavec.api.records.reader.RecordReader;
@ -27,34 +26,33 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.Rule;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.rules.Timeout;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.resources.Resources;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicLong;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*;
public class MultipleEpochsIteratorTest extends BaseDL4JTest {
@DisplayName("Multiple Epochs Iterator Test")
class MultipleEpochsIteratorTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Test
public void testNextAndReset() throws Exception {
@DisplayName("Test Next And Reset")
void testNextAndReset() throws Exception {
int epochs = 3;
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);
assertTrue(multiIter.hasNext());
while (multiIter.hasNext()) {
DataSet path = multiIter.next();
@ -64,18 +62,15 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
}
@Test
public void testLoadFullDataSet() throws Exception {
@DisplayName("Test Load Full Data Set")
void testLoadFullDataSet() throws Exception {
int epochs = 3;
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
DataSet ds = iter.next(50);
assertEquals(50, ds.getFeatures().size(0));
MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
assertTrue(multiIter.hasNext());
int count = 0;
while (multiIter.hasNext()) {
@ -89,28 +84,26 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
}
@Test
public void testLoadBatchDataSet() throws Exception {
@DisplayName("Test Load Batch Data Set")
void testLoadBatchDataSet() throws Exception {
int epochs = 2;
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3);
DataSet ds = iter.next(20);
assertEquals(20, ds.getFeatures().size(0));
MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
while (multiIter.hasNext()) {
DataSet path = multiIter.next(10);
assertNotNull(path);
assertEquals(10, path.numExamples(), 0.0);
}
assertEquals(epochs, multiIter.epochs);
}
@Test
public void testMEDIWithLoad1() throws Exception {
@DisplayName("Test MEDI With Load 1")
void testMEDIWithLoad1() throws Exception {
ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100));
MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 1);
@ -119,38 +112,39 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
}
@Test
public void testMEDIWithLoad2() throws Exception {
@DisplayName("Test MEDI With Load 2")
void testMEDIWithLoad2() throws Exception {
ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100));
MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2);
long num1 = 0;
for (; num1 < 150; num1++) {
consumer.consumeOnce(iterator.next(), true);
}
iterator.reset();
long num2 = consumer.consumeWhileHasNext(true);
assertEquals((10 * 100) + 150, num1 + num2);
}
@Test
public void testMEDIWithLoad3() throws Exception {
@DisplayName("Test MEDI With Load 3")
void testMEDIWithLoad3() throws Exception {
ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(10000));
MultipleEpochsIterator iterator = new MultipleEpochsIterator(iter, 24, 136);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2);
long num1 = 0;
while (iterator.hasNext()) {
consumer.consumeOnce(iterator.next(), true);
num1++;
}
assertEquals(136, num1);
}
@DisplayName("Iterable Without Exception")
private class IterableWithoutException implements Iterable<DataSet> {
private final AtomicLong counter = new AtomicLong(0);
private final int datasets;
public IterableWithoutException(int datasets) {
@ -161,6 +155,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
public Iterator<DataSet> iterator() {
counter.set(0);
return new Iterator<DataSet>() {
@Override
public boolean hasNext() {
return counter.get() < datasets;
@ -174,7 +169,6 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
@Override
public void remove() {
}
};
}
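MultipleEpochsIteratorTest still declares @Rule public Timeout timeout = Timeout.seconds(300), but Jupiter does not run JUnit 4 rules, so the timeout is silently dropped. Since Jupiter 5.5 the equivalent is the @Timeout annotation; a sketch of the class-level replacement (illustrative only):

import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Timeout;

// Applies to every test method in the class, replacing the JUnit 4 Timeout rule
@Timeout(value = 300, unit = TimeUnit.SECONDS)
@DisplayName("Multiple Epochs Iterator Test")
class MultipleEpochsIteratorTest extends BaseDL4JTest {
    // ... test methods unchanged ...
}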

View File

@ -17,36 +17,34 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class RandomDataSetIteratorTest extends BaseDL4JTest {
@DisplayName("Random Data Set Iterator Test")
class RandomDataSetIteratorTest extends BaseDL4JTest {
@Test
public void testDSI(){
DataSetIterator iter = new RandomDataSetIterator(5, new long[]{3,4}, new long[]{3,5}, RandomDataSetIterator.Values.RANDOM_UNIFORM,
RandomDataSetIterator.Values.ONE_HOT);
@DisplayName("Test DSI")
void testDSI() {
DataSetIterator iter = new RandomDataSetIterator(5, new long[] { 3, 4 }, new long[] { 3, 5 }, RandomDataSetIterator.Values.RANDOM_UNIFORM, RandomDataSetIterator.Values.ONE_HOT);
int count = 0;
while(iter.hasNext()){
while (iter.hasNext()) {
count++;
DataSet ds = iter.next();
assertArrayEquals(new long[]{3,4}, ds.getFeatures().shape());
assertArrayEquals(new long[]{3,5}, ds.getLabels().shape());
assertArrayEquals(new long[] { 3, 4 }, ds.getFeatures().shape());
assertArrayEquals(new long[] { 3, 5 }, ds.getLabels().shape());
assertTrue(ds.getFeatures().minNumber().doubleValue() >= 0.0 && ds.getFeatures().maxNumber().doubleValue() <= 1.0);
assertEquals(Nd4j.ones(3), ds.getLabels().sum(1));
}
@ -54,31 +52,23 @@ public class RandomDataSetIteratorTest extends BaseDL4JTest {
}
@Test
public void testMDSI(){
@DisplayName("Test MDSI")
void testMDSI() {
Nd4j.getRandom().setSeed(12345);
MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5)
.addFeatures(new long[]{3,4}, RandomMultiDataSetIterator.Values.INTEGER_0_100)
.addFeatures(new long[]{3,5}, RandomMultiDataSetIterator.Values.BINARY)
.addLabels(new long[]{3,6}, RandomMultiDataSetIterator.Values.ZEROS)
.build();
MultiDataSetIterator iter = new RandomMultiDataSetIterator.Builder(5).addFeatures(new long[] { 3, 4 }, RandomMultiDataSetIterator.Values.INTEGER_0_100).addFeatures(new long[] { 3, 5 }, RandomMultiDataSetIterator.Values.BINARY).addLabels(new long[] { 3, 6 }, RandomMultiDataSetIterator.Values.ZEROS).build();
int count = 0;
while(iter.hasNext()){
while (iter.hasNext()) {
count++;
MultiDataSet mds = iter.next();
assertEquals(2, mds.numFeatureArrays());
assertEquals(1, mds.numLabelsArrays());
assertArrayEquals(new long[]{3,4}, mds.getFeatures(0).shape());
assertArrayEquals(new long[]{3,5}, mds.getFeatures(1).shape());
assertArrayEquals(new long[]{3,6}, mds.getLabels(0).shape());
assertTrue(mds.getFeatures(0).minNumber().doubleValue() >= 0 && mds.getFeatures(0).maxNumber().doubleValue() <= 100.0
&& mds.getFeatures(0).maxNumber().doubleValue() > 2.0);
assertArrayEquals(new long[] { 3, 4 }, mds.getFeatures(0).shape());
assertArrayEquals(new long[] { 3, 5 }, mds.getFeatures(1).shape());
assertArrayEquals(new long[] { 3, 6 }, mds.getLabels(0).shape());
assertTrue(mds.getFeatures(0).minNumber().doubleValue() >= 0 && mds.getFeatures(0).maxNumber().doubleValue() <= 100.0 && mds.getFeatures(0).maxNumber().doubleValue() > 2.0);
assertTrue(mds.getFeatures(1).minNumber().doubleValue() == 0.0 && mds.getFeatures(1).maxNumber().doubleValue() == 1.0);
assertEquals(0.0, mds.getLabels(0).sumNumber().doubleValue(), 0.0);
}
assertEquals(5, count);
}
}
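The per-batch checks in testDSI and testMDSI stop at the first failing assertion; Jupiter's assertAll groups related checks so every failure in the group is reported together. A sketch against the shape checks above (illustrative, not part of the commit):

import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;

assertAll("dataset shapes",
        () -> assertArrayEquals(new long[] { 3, 4 }, ds.getFeatures().shape()),
        () -> assertArrayEquals(new long[] { 3, 5 }, ds.getLabels().shape()));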

View File

@ -17,27 +17,28 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
* @author Adam Gibson
*/
public class SamplingTest extends BaseDL4JTest {
@DisplayName("Sampling Test")
class SamplingTest extends BaseDL4JTest {
@Test
public void testSample() throws Exception {
@DisplayName("Test Sample")
void testSample() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(10, 10);
//batch size and total
// batch size and total
DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
assertEquals(10, sampling.next().numExamples());
}
}
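Every migrated class in this commit carries a hand-written @DisplayName derived from the identifier; note also that Jupiter no longer requires test classes and methods to be public, which is why the migration drops those modifiers. Jupiter can generate display names automatically via @DisplayNameGeneration (since 5.4), which would have kept diffs like the one above smaller. Purely illustrative sketch:

import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;

// With ReplaceUnderscores, test_sample is reported as "test sample" with no per-method annotation
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class SamplingTest extends BaseDL4JTest {

    @Test
    void test_sample() throws Exception {
        // body as above
    }
}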

View File

@ -17,50 +17,46 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.eval;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import static junit.framework.TestCase.assertNull;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class EvalJsonTest extends BaseDL4JTest {
@DisplayName("Eval Json Test")
class EvalJsonTest extends BaseDL4JTest {
@Test
public void testSerdeEmpty() {
@DisplayName("Test Serde Empty")
void testSerdeEmpty() {
boolean print = false;
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10),
new ROCMultiClass(10), new RegressionEvaluation(3), new RegressionEvaluation(),
new EvaluationCalibration()};
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] { new Evaluation(), new EvaluationBinary(), new ROCBinary(10), new ROCMultiClass(10), new RegressionEvaluation(3), new RegressionEvaluation(), new EvaluationCalibration() };
for (org.nd4j.evaluation.IEvaluation e : arr) {
String json = e.toJson();
String stats = e.stats();
if (print) {
System.out.println(e.getClass() + "\n" + json + "\n\n");
}
IEvaluation fromJson = (IEvaluation) org.nd4j.evaluation.BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class);
assertEquals(e.toJson(), fromJson.toJson());
}
}
@Test
public void testSerde() {
@DisplayName("Test Serde")
void testSerde() {
boolean print = false;
Nd4j.getRandom().setSeed(12345);
Evaluation evaluation = new Evaluation();
EvaluationBinary evaluationBinary = new EvaluationBinary();
ROC roc = new ROC(2);
@ -68,56 +64,43 @@ public class EvalJsonTest extends BaseDL4JTest {
ROCMultiClass roc3 = new ROCMultiClass(2);
RegressionEvaluation regressionEvaluation = new RegressionEvaluation();
EvaluationCalibration ec = new EvaluationCalibration();
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec};
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] { evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec };
INDArray evalLabel = Nd4j.create(10, 3);
for (int i = 0; i < 10; i++) {
evalLabel.putScalar(i, i % 3, 1.0);
}
INDArray evalProb = Nd4j.rand(10, 3);
evalProb.diviColumnVector(evalProb.sum(true,1));
evalProb.diviColumnVector(evalProb.sum(true, 1));
evaluation.eval(evalLabel, evalProb);
roc3.eval(evalLabel, evalProb);
ec.eval(evalLabel, evalProb);
evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5));
evalProb = Nd4j.rand(10, 3);
evaluationBinary.eval(evalLabel, evalProb);
roc2.eval(evalLabel, evalProb);
evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5));
evalProb = Nd4j.rand(10, 1);
roc.eval(evalLabel, evalProb);
regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3));
for (org.nd4j.evaluation.IEvaluation e : arr) {
String json = e.toJson();
if (print) {
System.out.println(e.getClass() + "\n" + json + "\n\n");
}
IEvaluation fromJson = (IEvaluation) BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class);
assertEquals(e.toJson(), fromJson.toJson());
}
}
@Test
public void testSerdeExactRoc() {
@DisplayName("Test Serde Exact Roc")
void testSerdeExactRoc() {
Nd4j.getRandom().setSeed(12345);
boolean print = false;
ROC roc = new ROC(0);
ROCBinary roc2 = new ROCBinary(0);
ROCMultiClass roc3 = new ROCMultiClass(0);
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] {roc, roc2, roc3};
org.nd4j.evaluation.IEvaluation[] arr = new org.nd4j.evaluation.IEvaluation[] { roc, roc2, roc3 };
INDArray evalLabel = Nd4j.create(100, 3);
for (int i = 0; i < 100; i++) {
evalLabel.putScalar(i, i % 3, 1.0);
@ -125,15 +108,12 @@ public class EvalJsonTest extends BaseDL4JTest {
INDArray evalProb = Nd4j.rand(100, 3);
evalProb.diviColumnVector(evalProb.sum(1));
roc3.eval(evalLabel, evalProb);
evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 3), 0.5));
evalProb = Nd4j.rand(100, 3);
roc2.eval(evalLabel, evalProb);
evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
evalProb = Nd4j.rand(100, 1);
roc.eval(evalLabel, evalProb);
for (org.nd4j.evaluation.IEvaluation e : arr) {
System.out.println(e.getClass());
String json = e.toJson();
@ -143,37 +123,34 @@ public class EvalJsonTest extends BaseDL4JTest {
}
org.nd4j.evaluation.IEvaluation fromJson = BaseEvaluation.fromJson(json, org.nd4j.evaluation.BaseEvaluation.class);
assertEquals(e, fromJson);
if (fromJson instanceof ROC) {
//Shouldn't have probAndLabel, but should have stored AUC and AUPRC
// Shouldn't have probAndLabel, but should have stored AUC and AUPRC
assertNull(((ROC) fromJson).getProbAndLabel());
assertTrue(((ROC) fromJson).calculateAUC() > 0.0);
assertTrue(((ROC) fromJson).calculateAUCPR() > 0.0);
assertEquals(((ROC) e).getRocCurve(), ((ROC) fromJson).getRocCurve());
assertEquals(((ROC) e).getPrecisionRecallCurve(), ((ROC) fromJson).getPrecisionRecallCurve());
} else if (e instanceof ROCBinary) {
org.nd4j.evaluation.classification.ROC[] rocs = ((ROCBinary) fromJson).getUnderlying();
org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCBinary) e).getUnderlying();
// for(ROC r : rocs ){
// for(ROC r : rocs ){
for (int i = 0; i < origRocs.length; i++) {
org.nd4j.evaluation.classification.ROC r = rocs[i];
org.nd4j.evaluation.classification.ROC origR = origRocs[i];
//Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
// Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
assertNull(r.getProbAndLabel());
assertEquals(origR.calculateAUC(), r.calculateAUC(), 1e-6);
assertEquals(origR.calculateAUCPR(), r.calculateAUCPR(), 1e-6);
assertEquals(origR.getRocCurve(), r.getRocCurve());
assertEquals(origR.getPrecisionRecallCurve(), r.getPrecisionRecallCurve());
}
} else if (e instanceof ROCMultiClass) {
org.nd4j.evaluation.classification.ROC[] rocs = ((ROCMultiClass) fromJson).getUnderlying();
org.nd4j.evaluation.classification.ROC[] origRocs = ((ROCMultiClass) e).getUnderlying();
for (int i = 0; i < origRocs.length; i++) {
org.nd4j.evaluation.classification.ROC r = rocs[i];
org.nd4j.evaluation.classification.ROC origR = origRocs[i];
//Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
// Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
assertNull(r.getProbAndLabel());
assertEquals(origR.calculateAUC(), r.calculateAUC(), 1e-6);
assertEquals(origR.calculateAUCPR(), r.calculateAUCPR(), 1e-6);
@ -185,32 +162,23 @@ public class EvalJsonTest extends BaseDL4JTest {
}
@Test
public void testJsonYamlCurves() {
@DisplayName("Test Json Yaml Curves")
void testJsonYamlCurves() {
ROC roc = new ROC(0);
INDArray evalLabel =
Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
INDArray evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
INDArray evalProb = Nd4j.rand(100, 1);
roc.eval(evalLabel, evalProb);
RocCurve c = roc.getRocCurve();
PrecisionRecallCurve prc = roc.getPrecisionRecallCurve();
String json1 = c.toJson();
String json2 = prc.toJson();
RocCurve c2 = RocCurve.fromJson(json1);
PrecisionRecallCurve prc2 = PrecisionRecallCurve.fromJson(json2);
assertEquals(c, c2);
assertEquals(prc, prc2);
// System.out.println(json1);
//Also test: histograms
// System.out.println(json1);
// Also test: histograms
EvaluationCalibration ec = new EvaluationCalibration();
evalLabel = Nd4j.create(10, 3);
for (int i = 0; i < 10; i++) {
evalLabel.putScalar(i, i % 3, 1.0);
@ -218,67 +186,45 @@ public class EvalJsonTest extends BaseDL4JTest {
evalProb = Nd4j.rand(10, 3);
evalProb.diviColumnVector(evalProb.sum(1));
ec.eval(evalLabel, evalProb);
Histogram[] histograms = new Histogram[] {ec.getResidualPlotAllClasses(), ec.getResidualPlot(0),
ec.getResidualPlot(1), ec.getProbabilityHistogramAllClasses(), ec.getProbabilityHistogram(0),
ec.getProbabilityHistogram(1)};
Histogram[] histograms = new Histogram[] { ec.getResidualPlotAllClasses(), ec.getResidualPlot(0), ec.getResidualPlot(1), ec.getProbabilityHistogramAllClasses(), ec.getProbabilityHistogram(0), ec.getProbabilityHistogram(1) };
for (Histogram h : histograms) {
String json = h.toJson();
String yaml = h.toYaml();
Histogram h2 = Histogram.fromJson(json);
Histogram h3 = Histogram.fromYaml(yaml);
assertEquals(h, h2);
assertEquals(h2, h3);
}
}
@Test
public void testJsonWithCustomThreshold() {
//Evaluation - binary threshold
@DisplayName("Test Json With Custom Threshold")
void testJsonWithCustomThreshold() {
// Evaluation - binary threshold
Evaluation e = new Evaluation(0.25);
String json = e.toJson();
String yaml = e.toYaml();
Evaluation eFromJson = Evaluation.fromJson(json);
Evaluation eFromYaml = Evaluation.fromYaml(yaml);
assertEquals(0.25, eFromJson.getBinaryDecisionThreshold(), 1e-6);
assertEquals(0.25, eFromYaml.getBinaryDecisionThreshold(), 1e-6);
//Evaluation: custom cost array
INDArray costArray = Nd4j.create(new double[] {1.0, 2.0, 3.0});
// Evaluation: custom cost array
INDArray costArray = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
Evaluation e2 = new Evaluation(costArray);
json = e2.toJson();
yaml = e2.toYaml();
eFromJson = Evaluation.fromJson(json);
eFromYaml = Evaluation.fromYaml(yaml);
assertEquals(e2.getCostArray(), eFromJson.getCostArray());
assertEquals(e2.getCostArray(), eFromYaml.getCostArray());
//EvaluationBinary - per-output binary threshold
INDArray threshold = Nd4j.create(new double[] {1.0, 0.5, 0.25});
// EvaluationBinary - per-output binary threshold
INDArray threshold = Nd4j.create(new double[] { 1.0, 0.5, 0.25 });
EvaluationBinary eb = new EvaluationBinary(threshold);
json = eb.toJson();
yaml = eb.toYaml();
EvaluationBinary ebFromJson = EvaluationBinary.fromJson(json);
EvaluationBinary ebFromYaml = EvaluationBinary.fromYaml(yaml);
assertEquals(threshold, ebFromJson.getDecisionThreshold());
assertEquals(threshold, ebFromYaml.getDecisionThreshold());
}
}
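EvalJsonTest still imports assertNull from junit.framework.TestCase (the JUnit 3 API); it presumably compiles here because JUnit 4 remains on the test classpath, but the Jupiter equivalent is a drop-in replacement, with any message moving to the last argument:

import static org.junit.jupiter.api.Assertions.assertNull;

// unchanged call site, now backed by Jupiter:
assertNull(((ROC) fromJson).getProbAndLabel());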

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.eval;
import org.datavec.api.records.metadata.RecordMetaData;
@ -45,7 +44,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
@ -58,78 +57,60 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.resources.Resources;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*;
public class EvalTest extends BaseDL4JTest {
@DisplayName("Eval Test")
class EvalTest extends BaseDL4JTest {
@Test
public void testIris() {
@DisplayName("Test Iris")
void testIris() {
// Network config
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42)
.updater(new Sgd(1e-6)).list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42).updater(new Sgd(1e-6)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).build();
// Instantiate model
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.addListeners(new ScoreIterationListener(1));
// Train-test split
DataSetIterator iter = new IrisDataSetIterator(150, 150);
DataSet next = iter.next();
next.shuffle();
SplitTestAndTrain trainTest = next.splitTestAndTrain(5, new Random(42));
// Train
DataSet train = trainTest.getTrain();
train.normalizeZeroMeanZeroUnitVariance();
// Test
DataSet test = trainTest.getTest();
test.normalizeZeroMeanZeroUnitVariance();
INDArray testFeature = test.getFeatures();
INDArray testLabel = test.getLabels();
// Fitting model
model.fit(train);
// Get predictions from test feature
INDArray testPredictedLabel = model.output(testFeature);
// Eval with class number
org.nd4j.evaluation.classification.Evaluation eval = new org.nd4j.evaluation.classification.Evaluation(3); //// Specify class num here
// // Specify class num here
org.nd4j.evaluation.classification.Evaluation eval = new org.nd4j.evaluation.classification.Evaluation(3);
eval.eval(testLabel, testPredictedLabel);
double eval1F1 = eval.f1();
double eval1Acc = eval.accuracy();
// Eval without class number
org.nd4j.evaluation.classification.Evaluation eval2 = new org.nd4j.evaluation.classification.Evaluation(); //// No class num
// // No class num
org.nd4j.evaluation.classification.Evaluation eval2 = new org.nd4j.evaluation.classification.Evaluation();
eval2.eval(testLabel, testPredictedLabel);
double eval2F1 = eval2.f1();
double eval2Acc = eval2.accuracy();
//Assert the two implementations give same f1 and accuracy (since one batch)
// Assert the two implementations give same f1 and accuracy (since one batch)
assertTrue(eval1F1 == eval2F1 && eval1Acc == eval2Acc);
org.nd4j.evaluation.classification.Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test)));
checkEvaluationEquality(eval, evalViaMethod);
// System.out.println(eval.getConfusionMatrix().toString());
// System.out.println(eval.getConfusionMatrix().toCSV());
// System.out.println(eval.getConfusionMatrix().toHTML());
// System.out.println(eval.confusionToString());
// System.out.println(eval.getConfusionMatrix().toString());
// System.out.println(eval.getConfusionMatrix().toCSV());
// System.out.println(eval.getConfusionMatrix().toHTML());
// System.out.println(eval.confusionToString());
eval.getConfusionMatrix().toString();
eval.getConfusionMatrix().toCSV();
eval.getConfusionMatrix().toHTML();
@ -160,99 +141,79 @@ public class EvalTest extends BaseDL4JTest {
}
@Test
public void testEvaluationWithMetaData() throws Exception {
@DisplayName("Test Evaluation With Meta Data")
void testEvaluationWithMetaData() throws Exception {
RecordReader csv = new CSVRecordReader();
csv.initialize(new FileSplit(Resources.asFile("iris.txt")));
int batchSize = 10;
int labelIdx = 4;
int numClasses = 3;
RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses);
NormalizerStandardize ns = new NormalizerStandardize();
ns.fit(rrdsi);
rrdsi.setPreProcessor(ns);
rrdsi.reset();
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1))
.list()
.layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)).list().layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(3).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int i = 0; i < 4; i++) {
net.fit(rrdsi);
rrdsi.reset();
}
org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation();
rrdsi.setCollectMetaData(true); //*** New: Enable collection of metadata (stored in the DataSets) ***
// *** New: Enable collection of metadata (stored in the DataSets) ***
rrdsi.setCollectMetaData(true);
while (rrdsi.hasNext()) {
DataSet ds = rrdsi.next();
List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class); //*** New - cross dependencies here make types difficult, using Object internally in DataSet for this***
// *** New - cross dependencies here make types difficult, using Object internally in DataSet for this***
List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class);
INDArray out = net.output(ds.getFeatures());
e.eval(ds.getLabels(), out, meta); //*** New - evaluate and also store metadata ***
// *** New - evaluate and also store metadata ***
e.eval(ds.getLabels(), out, meta);
}
// System.out.println(e.stats());
// System.out.println(e.stats());
e.stats();
// System.out.println("\n\n*** Prediction Errors: ***");
List<org.nd4j.evaluation.meta.Prediction> errors = e.getPredictionErrors(); //*** New - get list of prediction errors from evaluation ***
// System.out.println("\n\n*** Prediction Errors: ***");
// *** New - get list of prediction errors from evaluation ***
List<org.nd4j.evaluation.meta.Prediction> errors = e.getPredictionErrors();
List<RecordMetaData> metaForErrors = new ArrayList<>();
for (org.nd4j.evaluation.meta.Prediction p : errors) {
metaForErrors.add((RecordMetaData) p.getRecordMetaData());
}
DataSet ds = rrdsi.loadFromMetaData(metaForErrors); //*** New - dynamically load a subset of the data, just for prediction errors ***
// *** New - dynamically load a subset of the data, just for prediction errors ***
DataSet ds = rrdsi.loadFromMetaData(metaForErrors);
INDArray output = net.output(ds.getFeatures());
int count = 0;
for (org.nd4j.evaluation.meta.Prediction t : errors) {
String s = t + "\t\tRaw Data: "
+ csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() //*** New - load subset of data from MetaData object (usually batched for efficiency) ***
+ "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: "
+ ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count);
// System.out.println(s);
String s = t + "\t\tRaw Data: " + // *** New - load subset of data from MetaData object (usually batched for efficiency) ***
csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() + "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: " + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count);
// System.out.println(s);
count++;
}
int errorCount = errors.size();
double expAcc = 1.0 - errorCount / 150.0;
assertEquals(expAcc, e.accuracy(), 1e-5);
org.nd4j.evaluation.classification.ConfusionMatrix<Integer> confusion = e.getConfusionMatrix();
int[] actualCounts = new int[3];
int[] predictedCounts = new int[3];
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
int entry = confusion.getCount(i, j); //(actual,predicted)
// (actual,predicted)
int entry = confusion.getCount(i, j);
List<org.nd4j.evaluation.meta.Prediction> list = e.getPredictions(i, j);
assertEquals(entry, list.size());
actualCounts[i] += entry;
predictedCounts[j] += entry;
}
}
for (int i = 0; i < 3; i++) {
List<org.nd4j.evaluation.meta.Prediction> actualClassI = e.getPredictionsByActualClass(i);
List<org.nd4j.evaluation.meta.Prediction> predictedClassI = e.getPredictionByPredictedClass(i);
assertEquals(actualCounts[i], actualClassI.size());
assertEquals(predictedCounts[i], predictedClassI.size());
}
//Finally: test doEvaluation methods
// Finally: test doEvaluation methods
rrdsi.reset();
org.nd4j.evaluation.classification.Evaluation e2 = new org.nd4j.evaluation.classification.Evaluation();
net.doEvaluation(rrdsi, e2);
@ -262,7 +223,6 @@ public class EvalTest extends BaseDL4JTest {
assertEquals(actualCounts[i], actualClassI.size());
assertEquals(predictedCounts[i], predictedClassI.size());
}
ComputationGraph cg = net.toComputationGraph();
rrdsi.reset();
e2 = new org.nd4j.evaluation.classification.Evaluation();
@ -273,7 +233,6 @@ public class EvalTest extends BaseDL4JTest {
assertEquals(actualCounts[i], actualClassI.size());
assertEquals(predictedCounts[i], predictedClassI.size());
}
}
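For readers skimming the diff, the metadata-evaluation workflow exercised above reduces to a handful of calls; a condensed restatement using the same names as the test:

rrdsi.setCollectMetaData(true);                                  // store RecordMetaData inside each DataSet
List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class);
e.eval(ds.getLabels(), net.output(ds.getFeatures()), meta);      // evaluate and keep the metadata
List<Prediction> errors = e.getPredictionErrors();               // misclassified examples only
DataSet subset = rrdsi.loadFromMetaData(metaForErrors);          // reload just those records for inspection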
private static void apply(org.nd4j.evaluation.classification.Evaluation e, int nTimes, INDArray predicted, INDArray actual) {
@ -283,138 +242,28 @@ public class EvalTest extends BaseDL4JTest {
}
@Test
public void testEvalSplitting(){
//Test for "tbptt-like" functionality
for(WorkspaceMode ws : WorkspaceMode.values()) {
@DisplayName("Test Eval Splitting")
void testEvalSplitting() {
// Test for "tbptt-like" functionality
for (WorkspaceMode ws : WorkspaceMode.values()) {
System.out.println("Starting test for workspace mode: " + ws);
int nIn = 4;
int layerSize = 5;
int nOut = 6;
int tbpttLength = 10;
int tsLength = 5 * tbpttLength + tbpttLength / 2;
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
.seed(12345)
.trainingWorkspaceMode(ws)
.inferenceWorkspaceMode(ws)
.list()
.layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build())
.layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
.activation(Activation.SOFTMAX)
.build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.seed(12345)
.trainingWorkspaceMode(ws)
.inferenceWorkspaceMode(ws)
.list()
.layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build())
.layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
.activation(Activation.SOFTMAX).build())
.tBPTTLength(10)
.backpropType(BackpropType.TruncatedBPTT)
.build();
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build()).layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build()).build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build()).layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build()).tBPTTLength(10).backpropType(BackpropType.TruncatedBPTT).build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
net2.setParams(net1.params());
for(boolean useMask : new boolean[]{false, true}) {
INDArray in1 = Nd4j.rand(new int[]{3, nIn, tsLength});
for (boolean useMask : new boolean[] { false, true }) {
INDArray in1 = Nd4j.rand(new int[] { 3, nIn, tsLength });
INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength);
INDArray in2 = Nd4j.rand(new int[]{5, nIn, tsLength});
INDArray in2 = Nd4j.rand(new int[] { 5, nIn, tsLength });
INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength);
INDArray lMask1 = null;
INDArray lMask2 = null;
if(useMask){
lMask1 = Nd4j.create(3, tsLength);
lMask2 = Nd4j.create(5, tsLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5));
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5));
}
List<DataSet> l = Arrays.asList(new DataSet(in1, out1, null, lMask1), new DataSet(in2, out2, null, lMask2));
DataSetIterator iter = new ExistingDataSetIterator(l);
// System.out.println("Net 1 eval");
org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
// System.out.println("Net 2 eval");
org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
assertEquals(e1[0], e2[0]);
assertEquals(e1[1], e2[1]);
assertEquals(e1[2], e2[2]);
}
}
}
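testEvalSplitting (and its ComputationGraph twin below) loops over WorkspaceMode.values() inside a single @Test, so all modes share one pass/fail result. With the junit-jupiter-params artifact (not shown being added in this commit) this becomes a parameterized test that reports each mode separately; an illustrative sketch:

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

@ParameterizedTest
@EnumSource(WorkspaceMode.class)
@DisplayName("Test Eval Splitting")
void testEvalSplitting(WorkspaceMode ws) {
    // body of the former for-loop, unchanged, using the injected ws
    System.out.println("Starting test for workspace mode: " + ws);
    // ...
}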
@Test
public void testEvalSplittingCompGraph(){
//Test for "tbptt-like" functionality
for(WorkspaceMode ws : WorkspaceMode.values()) {
System.out.println("Starting test for workspace mode: " + ws);
int nIn = 4;
int layerSize = 5;
int nOut = 6;
int tbpttLength = 10;
int tsLength = 5 * tbpttLength + tbpttLength / 2;
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
.seed(12345)
.trainingWorkspaceMode(ws)
.inferenceWorkspaceMode(ws)
.graphBuilder()
.addInputs("in")
.addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in")
.addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
.activation(Activation.SOFTMAX)
.build(), "0")
.setOutputs("1")
.build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder()
.seed(12345)
.trainingWorkspaceMode(ws)
.inferenceWorkspaceMode(ws)
.graphBuilder()
.addInputs("in")
.addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in")
.addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut)
.activation(Activation.SOFTMAX)
.build(), "0")
.setOutputs("1")
.tBPTTLength(10)
.backpropType(BackpropType.TruncatedBPTT)
.build();
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
net2.setParams(net1.params());
for (boolean useMask : new boolean[]{false, true}) {
INDArray in1 = Nd4j.rand(new int[]{3, nIn, tsLength});
INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength);
INDArray in2 = Nd4j.rand(new int[]{5, nIn, tsLength});
INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength);
INDArray lMask1 = null;
INDArray lMask2 = null;
if (useMask) {
@ -423,15 +272,12 @@ public class EvalTest extends BaseDL4JTest {
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5));
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5));
}
List<DataSet> l = Arrays.asList(new DataSet(in1, out1), new DataSet(in2, out2));
List<DataSet> l = Arrays.asList(new DataSet(in1, out1, null, lMask1), new DataSet(in2, out2, null, lMask2));
DataSetIterator iter = new ExistingDataSetIterator(l);
// System.out.println("Eval net 1");
// System.out.println("Net 1 eval");
org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
// System.out.println("Eval net 2");
// System.out.println("Net 2 eval");
org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
assertEquals(e1[0], e2[0]);
assertEquals(e1[1], e2[1]);
assertEquals(e1[2], e2[2]);
@ -440,192 +286,170 @@ public class EvalTest extends BaseDL4JTest {
}
@Test
public void testEvalSplitting2(){
@DisplayName("Test Eval Splitting Comp Graph")
void testEvalSplittingCompGraph() {
// Test for "tbptt-like" functionality
for (WorkspaceMode ws : WorkspaceMode.values()) {
System.out.println("Starting test for workspace mode: " + ws);
int nIn = 4;
int layerSize = 5;
int nOut = 6;
int tbpttLength = 10;
int tsLength = 5 * tbpttLength + tbpttLength / 2;
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in").addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build(), "0").setOutputs("1").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new LSTM.Builder().nIn(nIn).nOut(layerSize).build(), "in").addLayer("1", new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.SOFTMAX).build(), "0").setOutputs("1").tBPTTLength(10).backpropType(BackpropType.TruncatedBPTT).build();
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
net2.setParams(net1.params());
for (boolean useMask : new boolean[] { false, true }) {
INDArray in1 = Nd4j.rand(new int[] { 3, nIn, tsLength });
INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength);
INDArray in2 = Nd4j.rand(new int[] { 5, nIn, tsLength });
INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength);
INDArray lMask1 = null;
INDArray lMask2 = null;
if (useMask) {
lMask1 = Nd4j.create(3, tsLength);
lMask2 = Nd4j.create(5, tsLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask1, 0.5));
Nd4j.getExecutioner().exec(new BernoulliDistribution(lMask2, 0.5));
}
List<DataSet> l = Arrays.asList(new DataSet(in1, out1), new DataSet(in2, out2));
DataSetIterator iter = new ExistingDataSetIterator(l);
// System.out.println("Eval net 1");
org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
// System.out.println("Eval net 2");
org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
assertEquals(e1[0], e2[0]);
assertEquals(e1[1], e2[1]);
assertEquals(e1[2], e2[2]);
}
}
}
@Test
@DisplayName("Test Eval Splitting 2")
void testEvalSplitting2() {
List<List<Writable>> seqFeatures = new ArrayList<>();
List<Writable> step = Arrays.<Writable>asList(new FloatWritable(0), new FloatWritable(0), new FloatWritable(0));
for( int i=0; i<30; i++ ){
for (int i = 0; i < 30; i++) {
seqFeatures.add(step);
}
List<List<Writable>> seqLabels = Collections.singletonList(Collections.<Writable>singletonList(new FloatWritable(0)));
SequenceRecordReader fsr = new CollectionSequenceRecordReader(Collections.singletonList(seqFeatures));
SequenceRecordReader lsr = new CollectionSequenceRecordReader(Collections.singletonList(seqLabels));
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(fsr, lsr, 1, -1, true,
SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123)
.list()
.layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(3).nOut(3).build())
.layer(1, new RnnOutputLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.XENT)
.nIn(3).nOut(1).build())
.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(10).tBPTTBackwardLength(10)
.build();
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(fsr, lsr, 1, -1, true, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()).layer(1, new RnnOutputLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.XENT).nIn(3).nOut(1).build()).backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(10).tBPTTBackwardLength(10).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.evaluate(testData);
}
@Test
public void testEvaluativeListenerSimple(){
//Sanity check: https://github.com/eclipse/deeplearning4j/issues/5351
@DisplayName("Test Evaluative Listener Simple")
void testEvaluativeListenerSimple() {
// Sanity check: https://github.com/eclipse/deeplearning4j/issues/5351
// Network config
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42)
.updater(new Sgd(1e-6)).list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42).updater(new Sgd(1e-6)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).build();
// Instantiate model
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
// Train-test split
DataSetIterator iter = new IrisDataSetIterator(30, 150);
DataSetIterator iterTest = new IrisDataSetIterator(30, 150);
net.setListeners(new EvaluativeListener(iterTest, 3));
for( int i=0; i<3; i++ ){
for (int i = 0; i < 3; i++) {
net.fit(iter);
}
}
@Test
public void testMultiOutputEvalSimple(){
@DisplayName("Test Multi Output Eval Simple")
void testMultiOutputEvalSimple() {
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345)
.graphBuilder()
.addInputs("in")
.addLayer("out1", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in")
.addLayer("out2", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in")
.setOutputs("out1", "out2")
.build();
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("out1", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in").addLayer("out2", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in").setOutputs("out1", "out2").build();
ComputationGraph cg = new ComputationGraph(conf);
cg.init();
List<MultiDataSet> list = new ArrayList<>();
DataSetIterator iter = new IrisDataSetIterator(30, 150);
while(iter.hasNext()){
while (iter.hasNext()) {
DataSet ds = iter.next();
list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{ds.getFeatures()}, new INDArray[]{ds.getLabels(), ds.getLabels()}));
list.add(new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { ds.getFeatures() }, new INDArray[] { ds.getLabels(), ds.getLabels() }));
}
org.nd4j.evaluation.classification.Evaluation e = new org.nd4j.evaluation.classification.Evaluation();
org.nd4j.evaluation.regression.RegressionEvaluation e2 = new org.nd4j.evaluation.regression.RegressionEvaluation();
Map<Integer,org.nd4j.evaluation.IEvaluation[]> evals = new HashMap<>();
evals.put(0, new org.nd4j.evaluation.IEvaluation[]{e});
evals.put(1, new org.nd4j.evaluation.IEvaluation[]{e2});
Map<Integer, org.nd4j.evaluation.IEvaluation[]> evals = new HashMap<>();
evals.put(0, new org.nd4j.evaluation.IEvaluation[] { e });
evals.put(1, new org.nd4j.evaluation.IEvaluation[] { e2 });
cg.evaluate(new IteratorMultiDataSetIterator(list.iterator(), 30), evals);
assertEquals(150, e.getNumRowCounter());
assertEquals(150, e2.getExampleCountPerColumn().getInt(0));
}
@Test
public void testMultiOutputEvalCG(){
//Simple sanity check on evaluation
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.graphBuilder()
.addInputs("in")
.layer("0", new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build(), "in")
.layer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0")
.layer("2", new LSTM.Builder().nIn(10).nOut(10).build(), "0")
.layer("out1", new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1")
.layer("out2", new RnnOutputLayer.Builder().nIn(10).nOut(20).activation(Activation.SOFTMAX).build(), "2")
.setOutputs("out1", "out2")
.build();
@DisplayName("Test Multi Output Eval CG")
void testMultiOutputEvalCG() {
// Simple sanity check on evaluation
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").layer("0", new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build(), "in").layer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0").layer("2", new LSTM.Builder().nIn(10).nOut(10).build(), "0").layer("out1", new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1").layer("out2", new RnnOutputLayer.Builder().nIn(10).nOut(20).activation(Activation.SOFTMAX).build(), "2").setOutputs("out1", "out2").build();
ComputationGraph cg = new ComputationGraph(conf);
cg.init();
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(
new INDArray[]{Nd4j.create(10, 1, 10)},
new INDArray[]{Nd4j.create(10, 10, 10), Nd4j.create(10, 20, 10)});
Map<Integer,org.nd4j.evaluation.IEvaluation[]> m = new HashMap<>();
m.put(0, new org.nd4j.evaluation.IEvaluation[]{new org.nd4j.evaluation.classification.Evaluation()});
m.put(1, new org.nd4j.evaluation.IEvaluation[]{new org.nd4j.evaluation.classification.Evaluation()});
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { Nd4j.create(10, 1, 10) }, new INDArray[] { Nd4j.create(10, 10, 10), Nd4j.create(10, 20, 10) });
Map<Integer, org.nd4j.evaluation.IEvaluation[]> m = new HashMap<>();
m.put(0, new org.nd4j.evaluation.IEvaluation[] { new org.nd4j.evaluation.classification.Evaluation() });
m.put(1, new org.nd4j.evaluation.IEvaluation[] { new org.nd4j.evaluation.classification.Evaluation() });
cg.evaluate(new SingletonMultiDataSetIterator(mds), m);
}
@Test
public void testInvalidEvaluation(){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new DenseLayer.Builder().nIn(4).nOut(10).build())
.layer(new OutputLayer.Builder().nIn(10).nOut(3).lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.RELU).build())
.build();
@DisplayName("Test Invalid Evaluation")
void testInvalidEvaluation() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new DenseLayer.Builder().nIn(4).nOut(10).build()).layer(new OutputLayer.Builder().nIn(10).nOut(3).lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.RELU).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DataSetIterator iter = new IrisDataSetIterator(150, 150);
try {
net.evaluate(iter);
fail("Expected exception");
} catch (IllegalStateException e){
} catch (IllegalStateException e) {
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("Evaluation"));
}
try {
net.evaluateROC(iter, 0);
fail("Expected exception");
} catch (IllegalStateException e){
} catch (IllegalStateException e) {
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
}
try {
net.evaluateROCMultiClass(iter, 0);
fail("Expected exception");
} catch (IllegalStateException e){
} catch (IllegalStateException e) {
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
}
ComputationGraph cg = net.toComputationGraph();
try {
cg.evaluate(iter);
fail("Expected exception");
} catch (IllegalStateException e){
} catch (IllegalStateException e) {
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("Evaluation"));
}
try {
cg.evaluateROC(iter, 0);
fail("Expected exception");
} catch (IllegalStateException e){
} catch (IllegalStateException e) {
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROC"));
}
try {
cg.evaluateROCMultiClass(iter, 0);
fail("Expected exception");
} catch (IllegalStateException e){
} catch (IllegalStateException e) {
assertTrue(e.getMessage().contains("Classifier") && e.getMessage().contains("ROCMultiClass"));
}
//Disable validation, and check same thing:
// Disable validation, and check same thing:
net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false);
net.evaluate(iter);
net.evaluateROCMultiClass(iter, 0);
cg.getConfiguration().setValidateOutputLayerConfig(false);
cg.evaluate(iter);
cg.evaluateROCMultiClass(iter, 0);
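Given that this commit moves the suite to JUnit 5, the repeated try { ... fail(...) } catch blocks above could be expressed with Jupiter's assertThrows, which returns the thrown exception so the message checks stay intact; a sketch for the first case:

import static org.junit.jupiter.api.Assertions.assertThrows;

// Jupiter-idiomatic form of the try/fail/catch pattern used above:
IllegalStateException ex = assertThrows(IllegalStateException.class, () -> net.evaluate(iter));
assertTrue(ex.getMessage().contains("Classifier") && ex.getMessage().contains("Evaluation"));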



@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.eval;
import org.deeplearning4j.BaseDL4JTest;
@ -28,48 +27,53 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.*;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class ROCTest extends BaseDL4JTest {
@DisplayName("Roc Test")
class ROCTest extends BaseDL4JTest {
private static Map<Double, Double> expTPR;
private static Map<Double, Double> expFPR;
static {
expTPR = new HashMap<>();
double totalPositives = 5.0;
expTPR.put(0 / 10.0, 5.0 / totalPositives); //All 10 predicted as class 1, of which 5 of 5 are correct
// All 10 predicted as class 1, of which 5 of 5 are correct
expTPR.put(0 / 10.0, 5.0 / totalPositives);
expTPR.put(1 / 10.0, 5.0 / totalPositives);
expTPR.put(2 / 10.0, 5.0 / totalPositives);
expTPR.put(3 / 10.0, 5.0 / totalPositives);
expTPR.put(4 / 10.0, 5.0 / totalPositives);
expTPR.put(5 / 10.0, 5.0 / totalPositives);
expTPR.put(6 / 10.0, 4.0 / totalPositives); //Threshold: 0.4 -> last 4 predicted; last 5 actual
// Threshold: 0.4 -> last 4 predicted; last 5 actual
expTPR.put(6 / 10.0, 4.0 / totalPositives);
expTPR.put(7 / 10.0, 3.0 / totalPositives);
expTPR.put(8 / 10.0, 2.0 / totalPositives);
expTPR.put(9 / 10.0, 1.0 / totalPositives);
expTPR.put(10 / 10.0, 0.0 / totalPositives);
expFPR = new HashMap<>();
double totalNegatives = 5.0;
expFPR.put(0 / 10.0, 5.0 / totalNegatives); //All 10 predicted as class 1, but all 5 true negatives are predicted positive
expFPR.put(1 / 10.0, 4.0 / totalNegatives); //1 true negative is predicted as negative; 4 false positives
expFPR.put(2 / 10.0, 3.0 / totalNegatives); //2 true negatives are predicted as negative; 3 false positives
// All 10 predicted as class 1, but all 5 true negatives are predicted positive
expFPR.put(0 / 10.0, 5.0 / totalNegatives);
// 1 true negative is predicted as negative; 4 false positives
expFPR.put(1 / 10.0, 4.0 / totalNegatives);
// 2 true negatives are predicted as negative; 3 false positives
expFPR.put(2 / 10.0, 3.0 / totalNegatives);
expFPR.put(3 / 10.0, 2.0 / totalNegatives);
expFPR.put(4 / 10.0, 1.0 / totalNegatives);
expFPR.put(5 / 10.0, 0.0 / totalNegatives);
@ -81,56 +85,41 @@ public class ROCTest extends BaseDL4JTest {
}
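The expected TPR/FPR entries above all follow one threshold rule: a score is predicted positive iff it is at least the threshold, and TPR/FPR divide the surviving true/false positives by the class totals. A worked sketch for threshold 0.6; the scores and labels here are illustrative assumptions consistent with the comments, not the test's actual fixture:

// Assumed fixture: 10 ascending scores, last 5 labeled positive.
double[] scores = { 0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95 };
boolean[] positive = { false, false, false, false, false, true, true, true, true, true };
double t = 0.6;
int tp = 0, fp = 0;
for (int i = 0; i < scores.length; i++) {
    if (scores[i] >= t) {
        if (positive[i]) tp++; else fp++;
    }
}
double tpr = tp / 5.0; // 4/5: one positive (0.55) falls below 0.6, matching expTPR
double fpr = fp / 5.0; // 0/5: no negative reaches 0.6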
@Test
public void RocEvalSanityCheck() {
@DisplayName("Roc Eval Sanity Check")
void RocEvalSanityCheck() {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).seed(12345)
.list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1,
new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).seed(12345).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
NormalizerStandardize ns = new NormalizerStandardize();
DataSet ds = iter.next();
ns.fit(ds);
ns.transform(ds);
iter.setPreProcessor(ns);
for (int i = 0; i < 10; i++) {
net.fit(ds);
}
for (int steps : new int[] {32, 0}) { //Steps = 0: exact
for (int steps : new int[] { 32, 0 }) {
// Steps = 0: exact
System.out.println("steps: " + steps);
iter.reset();
ds = iter.next();
INDArray f = ds.getFeatures();
INDArray l = ds.getLabels();
INDArray out = net.output(f);
// System.out.println(f);
// System.out.println(out);
// System.out.println(f);
// System.out.println(out);
ROCMultiClass manual = new ROCMultiClass(steps);
manual.eval(l, out);
iter.reset();
ROCMultiClass roc = net.evaluateROCMultiClass(iter, steps);
for (int i = 0; i < 3; i++) {
double rocExp = manual.calculateAUC(i);
double rocAct = roc.calculateAUC(i);
assertEquals(rocExp, rocAct, 1e-6);
RocCurve rc = roc.getRocCurve(i);
RocCurve rm = manual.getRocCurve(i);
assertEquals(rc, rm);
}
}


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.eval;
import org.deeplearning4j.BaseDL4JTest;
@ -29,59 +28,43 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Collections;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class RegressionEvalTest extends BaseDL4JTest {
@DisplayName("Regression Eval Test")
class RegressionEvalTest extends BaseDL4JTest {
@Test
public void testRegressionEvalMethods() {
//Basic sanity check
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).list()
.layer(0, new OutputLayer.Builder().activation(Activation.TANH)
.lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(5).build())
.build();
@DisplayName("Test Regression Eval Methods")
void testRegressionEvalMethods() {
// Basic sanity check
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).list().layer(0, new OutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(5).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray f = Nd4j.zeros(4, 10);
INDArray l = Nd4j.ones(4, 5);
DataSet ds = new DataSet(f, l);
DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds));
org.nd4j.evaluation.regression.RegressionEvaluation re = net.evaluateRegression(iter);
for (int i = 0; i < 5; i++) {
assertEquals(1.0, re.meanSquaredError(i), 1e-6);
assertEquals(1.0, re.meanAbsoluteError(i), 1e-6);
}
ComputationGraphConfiguration graphConf =
new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).graphBuilder()
.addInputs("in").addLayer("0", new OutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.MSE)
.activation(Activation.TANH).nIn(10).nOut(5).build(), "in")
.setOutputs("0").build();
ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).graphBuilder().addInputs("in").addLayer("0", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(10).nOut(5).build(), "in").setOutputs("0").build();
ComputationGraph cg = new ComputationGraph(graphConf);
cg.init();
RegressionEvaluation re2 = cg.evaluateRegression(iter);
for (int i = 0; i < 5; i++) {
assertEquals(1.0, re2.meanSquaredError(i), 1e-6);
assertEquals(1.0, re2.meanAbsoluteError(i), 1e-6);
@ -89,25 +72,16 @@ public class RegressionEvalTest extends BaseDL4JTest {
}
@Test
public void testRegressionEvalPerOutputMasking() {
INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}});
@DisplayName("Test Regression Eval Per Output Masking")
void testRegressionEvalPerOutputMasking() {
INDArray l = Nd4j.create(new double[][] { { 1, 2, 3 }, { 10, 20, 30 }, { -5, -10, -20 } });
INDArray predictions = Nd4j.zeros(l.shape());
INDArray mask = Nd4j.create(new double[][] {{0, 1, 1}, {1, 1, 0}, {0, 1, 0}});
INDArray mask = Nd4j.create(new double[][] { { 0, 1, 1 }, { 1, 1, 0 }, { 0, 1, 0 } });
RegressionEvaluation re = new RegressionEvaluation();
re.eval(l, predictions, mask);
double[] mse = new double[] {(10 * 10) / 1.0, (2 * 2 + 20 * 20 + 10 * 10) / 3, (3 * 3) / 1.0};
double[] mae = new double[] {10.0, (2 + 20 + 10) / 3.0, 3.0};
double[] rmse = new double[] {10.0, Math.sqrt((2 * 2 + 20 * 20 + 10 * 10) / 3.0), 3.0};
double[] mse = new double[] { (10 * 10) / 1.0, (2 * 2 + 20 * 20 + 10 * 10) / 3, (3 * 3) / 1.0 };
double[] mae = new double[] { 10.0, (2 + 20 + 10) / 3.0, 3.0 };
double[] rmse = new double[] { 10.0, Math.sqrt((2 * 2 + 20 * 20 + 10 * 10) / 3.0), 3.0 };
for (int i = 0; i < 3; i++) {
assertEquals(mse[i], re.meanSquaredError(i), 1e-6);
assertEquals(mae[i], re.meanAbsoluteError(i), 1e-6);
@ -116,24 +90,19 @@ public class RegressionEvalTest extends BaseDL4JTest {
}
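The expected arrays follow directly from the mask: each column's error statistics run only over rows whose mask entry is 1. A worked check for column 0, using the labels and mask from the test above:

// Column 0 is masked out in rows 0 and 2, so only row 1 (label 10, prediction 0) counts:
double mse0 = (10 * 10) / 1.0;   // 100.0, matches mse[0]
double mae0 = 10.0 / 1.0;        // 10.0,  matches mae[0]
double rmse0 = Math.sqrt(mse0);  // 10.0,  matches rmse[0]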
@Test
public void testRegressionEvalTimeSeriesSplit(){
INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
INDArray outSub1 = out1.get(all(), all(), interval(0,10));
@DisplayName("Test Regression Eval Time Series Split")
void testRegressionEvalTimeSeriesSplit() {
INDArray out1 = Nd4j.rand(new int[] { 3, 5, 20 });
INDArray outSub1 = out1.get(all(), all(), interval(0, 10));
INDArray outSub2 = out1.get(all(), all(), interval(10, 20));
INDArray label1 = Nd4j.rand(new int[]{3, 5, 20});
INDArray labelSub1 = label1.get(all(), all(), interval(0,10));
INDArray label1 = Nd4j.rand(new int[] { 3, 5, 20 });
INDArray labelSub1 = label1.get(all(), all(), interval(0, 10));
INDArray labelSub2 = label1.get(all(), all(), interval(10, 20));
RegressionEvaluation e1 = new RegressionEvaluation();
RegressionEvaluation e2 = new RegressionEvaluation();
e1.eval(label1, out1);
e2.eval(labelSub1, outSub1);
e2.eval(labelSub2, outSub2);
assertEquals(e1, e2);
}
}
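The split test works because regression evaluation is accumulative: two eval calls over disjoint time ranges add up to one call over the full range. The same property should hold via IEvaluation.merge; a sketch under that assumption, reusing the arrays from the test above:

// Sketch: evaluate the halves in separate instances, then merge.
RegressionEvaluation a = new RegressionEvaluation();
RegressionEvaluation b = new RegressionEvaluation();
a.eval(labelSub1, outSub1);
b.eval(labelSub2, outSub2);
a.merge(b); // assumes merge(...) accumulates the per-column sums and counts
// a should now equal e1 from the test above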


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.gradientcheck;
import org.deeplearning4j.BaseDL4JTest;
@ -32,9 +31,9 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Ignore;
import org.junit.jupiter.api.Disabled;
import org.junit.Rule;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@ -42,13 +41,15 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertTrue;
@Disabled
@DisplayName("Attention Layer Test")
class AttentionLayerTest extends BaseDL4JTest {
@Ignore
public class AttentionLayerTest extends BaseDL4JTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
@ -58,19 +59,18 @@ public class AttentionLayerTest extends BaseDL4JTest {
}
@Test
public void testSelfAttentionLayer() {
@DisplayName("Test Self Attention Layer")
void testSelfAttentionLayer() {
int nIn = 3;
int nOut = 2;
int tsLength = 4;
int layerSize = 4;
for (int mb : new int[]{1, 3}) {
for (boolean inputMask : new boolean[]{false, true}) {
for (boolean projectInput : new boolean[]{false, true}) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength});
for (int mb : new int[] { 1, 3 }) {
for (boolean inputMask : new boolean[] { false, true }) {
for (boolean projectInput : new boolean[] { false, true }) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
INDArray labels = TestUtils.randomOneHot(mb, nOut);
String maskType = (inputMask ? "inputMask" : "none");
INDArray inMask = null;
if (inputMask) {
inMask = Nd4j.ones(mb, tsLength);
@ -84,54 +84,32 @@ public class AttentionLayerTest extends BaseDL4JTest {
}
}
}
String name = "testSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
System.out.println("Starting test: " + name);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.activation(Activation.TANH)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(layerSize).build())
.layer( projectInput ?
new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build()
: new SelfAttentionLayer.Builder().nHeads(1).projectInput(false).build()
)
.layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
.layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn))
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(projectInput ? new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build() : new SelfAttentionLayer.Builder().nHeads(1).projectInput(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(name, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in).labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(gradOK, name);
}
}
}
}
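The only semantic change in the assertion above is argument order: JUnit 4's Assert takes the failure message first, Jupiter's Assertions takes the condition first. A minimal side-by-side, including Jupiter's lazy-message overload:

// JUnit 4 (org.junit.Assert):                assertTrue(name, gradOK);
// JUnit 5 (org.junit.jupiter.api.Assertions): condition first, message last
assertTrue(gradOK, name);
// Jupiter also accepts a Supplier<String>, so the message is built only on failure:
assertTrue(gradOK, () -> "Gradient check failed: " + name);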
@Test
public void testLearnedSelfAttentionLayer() {
@DisplayName("Test Learned Self Attention Layer")
void testLearnedSelfAttentionLayer() {
int nIn = 3;
int nOut = 2;
int tsLength = 4;
int layerSize = 4;
int numQueries = 3;
for (boolean inputMask : new boolean[]{false, true}) {
for (int mb : new int[]{3, 1}) {
for (boolean projectInput : new boolean[]{false, true}) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength});
for (boolean inputMask : new boolean[] { false, true }) {
for (int mb : new int[] { 3, 1 }) {
for (boolean projectInput : new boolean[] { false, true }) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
INDArray labels = TestUtils.randomOneHot(mb, nOut);
String maskType = (inputMask ? "inputMask" : "none");
INDArray inMask = null;
if (inputMask) {
inMask = Nd4j.ones(mb, tsLength);
@ -145,75 +123,36 @@ public class AttentionLayerTest extends BaseDL4JTest {
}
}
}
String name = "testLearnedSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
System.out.println("Starting test: " + name);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.activation(Activation.TANH)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(layerSize).build())
.layer( projectInput ?
new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
: new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
)
.layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
.layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn))
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(projectInput ? new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build() : new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(name, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in).labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(gradOK, name);
}
}
}
}
@Test
public void testLearnedSelfAttentionLayer_differentMiniBatchSizes() {
@DisplayName("Test Learned Self Attention Layer _ different Mini Batch Sizes")
void testLearnedSelfAttentionLayer_differentMiniBatchSizes() {
int nIn = 3;
int nOut = 2;
int tsLength = 4;
int layerSize = 4;
int numQueries = 3;
Random r = new Random(12345);
for (boolean inputMask : new boolean[]{false, true}) {
for (boolean projectInput : new boolean[]{false, true}) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.activation(Activation.TANH)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(layerSize).build())
.layer( projectInput ?
new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
: new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
)
.layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
.layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int mb : new int[]{3, 1}) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength});
for (boolean inputMask : new boolean[] { false, true }) {
for (boolean projectInput : new boolean[] { false, true }) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(projectInput ? new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build() : new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int mb : new int[] { 3, 1 }) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
INDArray labels = TestUtils.randomOneHot(mb, nOut);
String maskType = (inputMask ? "inputMask" : "none");
INDArray inMask = null;
if (inputMask) {
inMask = Nd4j.ones(DataType.INT, mb, tsLength);
@ -227,68 +166,47 @@ public class AttentionLayerTest extends BaseDL4JTest {
}
}
}
String name = "testLearnedSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
System.out.println("Starting test: " + name);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(name, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in).labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(gradOK, name);
}
}
}
}
@Test
public void testRecurrentAttentionLayer_differingTimeSteps(){
@DisplayName("Test Recurrent Attention Layer _ differing Time Steps")
void testRecurrentAttentionLayer_differingTimeSteps() {
int nIn = 9;
int nOut = 5;
int layerSize = 8;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.activation(Activation.IDENTITY)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(layerSize).build())
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
.layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build())
.layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn))
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.IDENTITY).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
final INDArray initialInput = Nd4j.rand(new int[]{8, nIn, 7});
final INDArray goodNextInput = Nd4j.rand(new int[]{8, nIn, 7});
final INDArray badNextInput = Nd4j.rand(new int[]{8, nIn, 12});
final INDArray labels = Nd4j.rand(new int[]{8, nOut});
final INDArray initialInput = Nd4j.rand(new int[] { 8, nIn, 7 });
final INDArray goodNextInput = Nd4j.rand(new int[] { 8, nIn, 7 });
final INDArray badNextInput = Nd4j.rand(new int[] { 8, nIn, 12 });
final INDArray labels = Nd4j.rand(new int[] { 8, nOut });
net.fit(initialInput, labels);
net.fit(goodNextInput, labels);
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("This layer only supports fixed length mini-batches. Expected 7 time steps but got 12.");
net.fit(badNextInput, labels);
}
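Note that the @Rule ExpectedException kept above is a JUnit 4 construct; under the Jupiter engine it is silently ignored unless this class still runs on the vintage engine. A Jupiter-native sketch of the same expectation:

import static org.junit.jupiter.api.Assertions.assertThrows;

// Jupiter replacement for exceptionRule.expect(...)/expectMessage(...):
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class,
        () -> net.fit(badNextInput, labels));
// expectMessage(...) does substring matching, so mirror that with contains():
assertTrue(ex.getMessage().contains("Expected 7 time steps but got 12"));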
@Test
public void testRecurrentAttentionLayer() {
@DisplayName("Test Recurrent Attention Layer")
void testRecurrentAttentionLayer() {
int nIn = 4;
int nOut = 2;
int tsLength = 3;
int layerSize = 3;
for (int mb : new int[]{3, 1}) {
for (boolean inputMask : new boolean[]{true, false}) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength});
for (int mb : new int[] { 3, 1 }) {
for (boolean inputMask : new boolean[] { true, false }) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
INDArray labels = TestUtils.randomOneHot(mb, nOut);
String maskType = (inputMask ? "inputMask" : "none");
INDArray inMask = null;
if (inputMask) {
inMask = Nd4j.ones(mb, tsLength);
@ -302,51 +220,32 @@ public class AttentionLayerTest extends BaseDL4JTest {
}
}
}
String name = "testRecurrentAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType;
System.out.println("Starting test: " + name);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.activation(Activation.IDENTITY)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(layerSize).build())
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
.layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build())
.layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn))
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.IDENTITY).updater(new NoOp()).weightInit(WeightInit.XAVIER).list().layer(new LSTM.Builder().nOut(layerSize).build()).layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()).layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()).layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).setInputType(InputType.recurrent(nIn)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
//System.out.println("Original");
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(name, gradOK);
// System.out.println("Original");
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in).labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(gradOK, name);
}
}
}
@Test
public void testAttentionVertex() {
@DisplayName("Test Attention Vertex")
void testAttentionVertex() {
int nIn = 3;
int nOut = 2;
int tsLength = 3;
int layerSize = 3;
Random r = new Random(12345);
for (boolean inputMask : new boolean[]{false, true}) {
for (int mb : new int[]{3, 1}) {
for (boolean projectInput : new boolean[]{false, true}) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength});
for (boolean inputMask : new boolean[] { false, true }) {
for (int mb : new int[] { 3, 1 }) {
for (boolean projectInput : new boolean[] { false, true }) {
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[] { mb, nIn, tsLength });
INDArray labels = TestUtils.randomOneHot(mb, nOut);
String maskType = (inputMask ? "inputMask" : "none");
INDArray inMask = null;
if (inputMask) {
inMask = Nd4j.ones(mb, tsLength);
@ -360,57 +259,32 @@ public class AttentionLayerTest extends BaseDL4JTest {
}
}
}
String name = "testAttentionVertex() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
System.out.println("Starting test: " + name);
ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.activation(Activation.TANH)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.graphBuilder()
.addInputs("input")
.addLayer("rnnKeys", new SimpleRnn.Builder().nOut(layerSize).build(), "input")
.addLayer("rnnQueries", new SimpleRnn.Builder().nOut(layerSize).build(), "input")
.addLayer("rnnValues", new SimpleRnn.Builder().nOut(layerSize).build(), "input")
.addVertex("attention",
projectInput ?
new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build()
: new AttentionVertex.Builder().nOut(3).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(), "rnnQueries", "rnnKeys", "rnnValues")
.addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention")
.addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling")
.setOutputs("output")
.setInputTypes(InputType.recurrent(nIn))
.build();
ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("input").addLayer("rnnKeys", new SimpleRnn.Builder().nOut(layerSize).build(), "input").addLayer("rnnQueries", new SimpleRnn.Builder().nOut(layerSize).build(), "input").addLayer("rnnValues", new SimpleRnn.Builder().nOut(layerSize).build(), "input").addVertex("attention", projectInput ? new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build() : new AttentionVertex.Builder().nOut(3).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(), "rnnQueries", "rnnKeys", "rnnValues").addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention").addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling").setOutputs("output").setInputTypes(InputType.recurrent(nIn)).build();
ComputationGraph net = new ComputationGraph(graph);
net.init();
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in})
.labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null).subset(true).maxPerParam(100));
assertTrue(name, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { in }).labels(new INDArray[] { labels }).inputMask(inMask != null ? new INDArray[] { inMask } : null).subset(true).maxPerParam(100));
assertTrue(gradOK, name);
}
}
}
}
@Test
public void testAttentionVertexSameInput() {
@DisplayName("Test Attention Vertex Same Input")
void testAttentionVertexSameInput() {
int nIn = 3;
int nOut = 2;
int tsLength = 4;
int layerSize = 4;
Random r = new Random(12345);
for (boolean inputMask : new boolean[]{false, true}) {
for (int mb : new int[]{3, 1}) {
for (boolean projectInput : new boolean[]{false, true}) {
INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength});
for (boolean inputMask : new boolean[] { false, true }) {
for (int mb : new int[] { 3, 1 }) {
for (boolean projectInput : new boolean[] { false, true }) {
INDArray in = Nd4j.rand(new int[] { mb, nIn, tsLength });
INDArray labels = TestUtils.randomOneHot(mb, nOut);
String maskType = (inputMask ? "inputMask" : "none");
INDArray inMask = null;
if (inputMask) {
inMask = Nd4j.ones(mb, tsLength);
@ -424,35 +298,13 @@ public class AttentionLayerTest extends BaseDL4JTest {
}
}
}
String name = "testAttentionVertex() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
System.out.println("Starting test: " + name);
ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.activation(Activation.TANH)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.graphBuilder()
.addInputs("input")
.addLayer("rnn", new SimpleRnn.Builder().activation(Activation.TANH).nOut(layerSize).build(), "input")
.addVertex("attention",
projectInput ?
new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build()
: new AttentionVertex.Builder().nOut(4).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(), "rnn", "rnn", "rnn")
.addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention")
.addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling")
.setOutputs("output")
.setInputTypes(InputType.recurrent(nIn))
.build();
ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).activation(Activation.TANH).updater(new NoOp()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("input").addLayer("rnn", new SimpleRnn.Builder().activation(Activation.TANH).nOut(layerSize).build(), "input").addVertex("attention", projectInput ? new AttentionVertex.Builder().nOut(4).nHeads(2).projectInput(true).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build() : new AttentionVertex.Builder().nOut(4).nHeads(1).projectInput(false).nInQueries(layerSize).nInKeys(layerSize).nInValues(layerSize).build(), "rnn", "rnn", "rnn").addLayer("pooling", new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build(), "attention").addLayer("output", new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "pooling").setOutputs("output").setInputTypes(InputType.recurrent(nIn)).build();
ComputationGraph net = new ComputationGraph(graph);
net.init();
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in})
.labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null));
assertTrue(name, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { in }).labels(new INDArray[] { labels }).inputMask(inMask != null ? new INDArray[] { inMask } : null));
assertTrue(gradOK, name);
}
}
}


@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.gradientcheck;
import org.deeplearning4j.BaseDL4JTest;
@ -34,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -48,18 +47,18 @@ import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
*
*/
public class BNGradientCheckTest extends BaseDL4JTest {
@DisplayName("Bn Gradient Check Test")
class BNGradientCheckTest extends BaseDL4JTest {
static {
Nd4j.setDataType(DataType.DOUBLE);
@ -71,7 +70,8 @@ public class BNGradientCheckTest extends BaseDL4JTest {
}
@Test
public void testGradient2dSimple() {
@DisplayName("Test Gradient 2 d Simple")
void testGradient2dSimple() {
DataNormalization scaler = new NormalizerMinMaxScaler();
DataSetIterator iter = new IrisDataSetIterator(150, 150);
scaler.fit(iter);
@ -79,181 +79,117 @@ public class BNGradientCheckTest extends BaseDL4JTest {
DataSet ds = iter.next();
INDArray input = ds.getFeatures();
INDArray labels = ds.getLabels();
for (boolean useLogStd : new boolean[]{true, false}) {
MultiLayerConfiguration.Builder builder =
new NeuralNetConfiguration.Builder().updater(new NoOp())
.dataType(DataType.DOUBLE)
.seed(12345L)
.dist(new NormalDistribution(0, 1)).list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(3)
.activation(Activation.IDENTITY).build())
.layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).nOut(3).build())
.layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3).build());
for (boolean useLogStd : new boolean[] { true, false }) {
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).seed(12345L).dist(new NormalDistribution(0, 1)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).nOut(3).build()).layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build());
MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
mln.init();
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
// Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
// i.e., runningMean = decay * runningMean + (1-decay) * batchMean
// However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
.labels(labels).excludeParams(excludeParams));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
}
}
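The excluded-parameter comment above encodes the running-statistics update. A standalone sketch (decay value illustrative) of the exponential moving average, and of why those variables have zero numerical gradient:

// runningMean = decay * runningMean + (1 - decay) * batchMean
double decay = 0.9; // illustrative; corresponds to the BatchNormalization decay hyperparameter
double runningMean = 0.0, batchMean = 2.5;
runningMean = decay * runningMean + (1 - decay) * batchMean; // 0.25
// The training-time forward pass normalizes with the *batch* statistics, so
// perturbing runningMean leaves the score unchanged -> numerical gradient 0,
// which is why "1_mean"/"1_var"/"1_log10stdev" are excluded from the check.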
@Test
public void testGradientCnnSimple() {
@DisplayName("Test Gradient Cnn Simple")
void testGradientCnnSimple() {
Nd4j.getRandom().setSeed(12345);
int minibatch = 10;
int depth = 1;
int hw = 4;
int nOut = 4;
INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw});
INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw });
INDArray labels = Nd4j.zeros(minibatch, nOut);
Random r = new Random(12345);
for (int i = 0; i < minibatch; i++) {
labels.putScalar(i, r.nextInt(nOut), 1.0);
}
for (boolean useLogStd : new boolean[]{true, false}) {
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).seed(12345L)
.dist(new NormalDistribution(0, 2)).list()
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
.activation(Activation.IDENTITY).build())
.layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build())
.layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(hw, hw, depth));
for (boolean useLogStd : new boolean[] { true, false }) {
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).seed(12345L).dist(new NormalDistribution(0, 2)).list().layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()).layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(nOut).build()).setInputType(InputType.convolutional(hw, hw, depth));
MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
mln.init();
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
// Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
// i.e., runningMean = decay * runningMean + (1-decay) * batchMean
// However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
.labels(labels).excludeParams(excludeParams));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
}
}
@Test
public void testGradientBNWithCNNandSubsampling() {
//Parameterized test, testing combinations of:
@DisplayName("Test Gradient BN With CN Nand Subsampling")
void testGradientBNWithCNNandSubsampling() {
// Parameterized test, testing combinations of:
// (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
// (c) Loss function (with specified output activations)
// (d) l1 and l2 values
Activation[] activFns = {Activation.SIGMOID, Activation.TANH, Activation.IDENTITY};
boolean[] characteristic = {true}; //If true: run some backprop steps first
LossFunctions.LossFunction[] lossFunctions =
{LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE};
Activation[] outputActivations = {Activation.SOFTMAX, Activation.TANH}; //i.e., lossFunctions[i] used with outputActivations[i] here
double[] l2vals = {0.0, 0.1, 0.1};
double[] l1vals = {0.0, 0.0, 0.2}; //i.e., use l2vals[j] with l1vals[j]
Activation[] activFns = { Activation.SIGMOID, Activation.TANH, Activation.IDENTITY };
// If true: run some backprop steps first
boolean[] characteristic = { true };
LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE };
// i.e., lossFunctions[i] used with outputActivations[i] here
Activation[] outputActivations = { Activation.SOFTMAX, Activation.TANH };
double[] l2vals = { 0.0, 0.1, 0.1 };
// i.e., use l2vals[j] with l1vals[j]
double[] l1vals = { 0.0, 0.0, 0.2 };
Nd4j.getRandom().setSeed(12345);
int minibatch = 4;
int depth = 2;
int hw = 5;
int nOut = 2;
INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}).muli(5).subi(2.5);
INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw }).muli(5).subi(2.5);
INDArray labels = TestUtils.randomOneHot(minibatch, nOut);
DataSet ds = new DataSet(input, labels);
Random rng = new Random(12345);
for (boolean useLogStd : new boolean[]{true, false}) {
for (boolean useLogStd : new boolean[] { true, false }) {
for (Activation afn : activFns) {
for (boolean doLearningFirst : characteristic) {
for (int i = 0; i < lossFunctions.length; i++) {
for (int j = 0; j < l2vals.length; j++) {
//Skip 2 of every 3 tests: from 24 cases to 8, still with decent coverage
// Skip 2 of every 3 tests: from 24 cases to 8, still with decent coverage
if (rng.nextInt(3) != 0)
continue;
LossFunctions.LossFunction lf = lossFunctions[i];
Activation outputActivation = outputActivations[i];
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.l2(l2vals[j])
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
.updater(new NoOp())
.dist(new UniformDistribution(-2, 2)).seed(12345L).list()
.layer(0, new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3)
.activation(afn).build())
.layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build())
.layer(2, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2).stride(1, 1).build())
.layer(3, new BatchNormalization())
.layer(4, new ActivationLayer.Builder().activation(afn).build())
.layer(5, new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut)
.build())
.setInputType(InputType.convolutional(hw, hw, depth));
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345).dataType(DataType.DOUBLE).l2(l2vals[j]).optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).updater(new NoOp()).dist(new UniformDistribution(-2, 2)).seed(12345L).list().layer(0, new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3).activation(afn).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()).layer(2, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(1, 1).build()).layer(3, new BatchNormalization()).layer(4, new ActivationLayer.Builder().activation(afn).build()).layer(5, new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut).build()).setInputType(InputType.convolutional(hw, hw, depth));
MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
String name = new Object() {
}.getClass().getEnclosingMethod().getName();
// System.out.println("Num params: " + mln.numParams());
// System.out.println("Num params: " + mln.numParams());
if (doLearningFirst) {
//Run a number of iterations of learning
// Run a number of iterations of learning
mln.setInput(ds.getFeatures());
mln.setLabels(ds.getLabels());
mln.computeGradientAndScore();
double scoreBefore = mln.score();
for (int k = 0; k < 20; k++)
mln.fit(ds);
for (int k = 0; k < 20; k++) mln.fit(ds);
mln.computeGradientAndScore();
double scoreAfter = mln.score();
//Can't test in 'characteristic mode of operation' if not learning
String msg = name
+ " - score did not (sufficiently) decrease during learning - activationFn="
+ afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
+ ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore
+ ", scoreAfter=" + scoreAfter + ")";
assertTrue(msg, scoreAfter < 0.9 * scoreBefore);
// Can't test in 'characteristic mode of operation' if not learning
String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")";
assertTrue(scoreAfter < 0.9 * scoreBefore,msg);
}
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
+ doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
// for (int k = 0; k < mln.getnLayers(); k++)
// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
// for (int k = 0; k < mln.getnLayers(); k++)
// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
// Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
// i.e., runningMean = decay * runningMean + (1-decay) * batchMean
// However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
.labels(labels).excludeParams(excludeParams).subset(true).maxPerParam(25)); //Most params are in output layer, only these should be skipped with this threshold
// Most params are in the output layer; only those should be skipped with this threshold
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams).subset(true).maxPerParam(25));
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
}
@ -263,101 +199,68 @@ public class BNGradientCheckTest extends BaseDL4JTest {
}
}
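The hand-rolled combination loops above (activation x loss/output-activation x l1/l2, with random skipping) are a natural candidate for Jupiter's parameterized tests now that the suite is on JUnit 5. A hedged sketch of the shape such a migration could take; the method and source names are hypothetical, and junit-jupiter-params is assumed to be on the classpath:

import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

// Hypothetical source enumerating the combinations the loops above cover:
static Stream<Arguments> bnGradientCases() {
    return Stream.of(
        Arguments.of(Activation.SIGMOID, LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, Activation.SOFTMAX, 0.0, 0.0),
        Arguments.of(Activation.TANH, LossFunctions.LossFunction.MSE, Activation.TANH, 0.0, 0.1));
}

@ParameterizedTest
@MethodSource("bnGradientCases")
void testGradientBNCase(Activation afn, LossFunctions.LossFunction lf,
        Activation outputActivation, double l1, double l2) {
    // build one network for this combination and run the gradient check
}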
@Test
public void testGradientDense() {
//Parameterized test, testing combinations of:
@DisplayName("Test Gradient Dense")
void testGradientDense() {
// Parameterized test, testing combinations of:
// (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
// (c) Loss function (with specified output activations)
// (d) l1 and l2 values
Activation[] activFns = {Activation.TANH, Activation.IDENTITY};
boolean[] characteristic = {true}; //If true: run some backprop steps first
LossFunctions.LossFunction[] lossFunctions =
{LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE};
Activation[] outputActivations = {Activation.SOFTMAX, Activation.TANH}; //i.e., lossFunctions[i] used with outputActivations[i] here
double[] l2vals = {0.0, 0.1};
double[] l1vals = {0.0, 0.2}; //i.e., use l2vals[j] with l1vals[j]
Activation[] activFns = { Activation.TANH, Activation.IDENTITY };
// If true: run some backprop steps first
boolean[] characteristic = { true };
LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD, LossFunctions.LossFunction.MSE };
// i.e., lossFunctions[i] used with outputActivations[i] here
Activation[] outputActivations = { Activation.SOFTMAX, Activation.TANH };
double[] l2vals = { 0.0, 0.1 };
// i.e., use l2vals[j] with l1vals[j]
double[] l1vals = { 0.0, 0.2 };
Nd4j.getRandom().setSeed(12345);
int minibatch = 10;
int nIn = 5;
int nOut = 3;
INDArray input = Nd4j.rand(new int[]{minibatch, nIn});
INDArray input = Nd4j.rand(new int[] { minibatch, nIn });
INDArray labels = Nd4j.zeros(minibatch, nOut);
Random r = new Random(12345);
for (int i = 0; i < minibatch; i++) {
labels.putScalar(i, r.nextInt(nOut), 1.0);
}
DataSet ds = new DataSet(input, labels);
for (boolean useLogStd : new boolean[]{true, false}) {
for (boolean useLogStd : new boolean[] { true, false }) {
for (Activation afn : activFns) {
for (boolean doLearningFirst : characteristic) {
for (int i = 0; i < lossFunctions.length; i++) {
for (int j = 0; j < l2vals.length; j++) {
LossFunctions.LossFunction lf = lossFunctions[i];
Activation outputActivation = outputActivations[i];
MultiLayerConfiguration.Builder builder =
new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.l2(l2vals[j])
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
.updater(new NoOp())
.dist(new UniformDistribution(-2, 2)).seed(12345L).list()
.layer(0, new DenseLayer.Builder().nIn(nIn).nOut(4)
.activation(afn).build())
.layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build())
.layer(2, new DenseLayer.Builder().nIn(4).nOut(4).build())
.layer(3, new BatchNormalization.Builder().useLogStd(useLogStd).build())
.layer(4, new OutputLayer.Builder(lf)
.activation(outputActivation).nOut(nOut)
.build());
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).l2(l2vals[j]).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()).dist(new UniformDistribution(-2, 2)).seed(12345L).list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(4).activation(afn).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()).layer(2, new DenseLayer.Builder().nIn(4).nOut(4).build()).layer(3, new BatchNormalization.Builder().useLogStd(useLogStd).build()).layer(4, new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut).build());
MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
String name = new Object() { }.getClass().getEnclosingMethod().getName();
if (doLearningFirst) {
//Run a number of iterations of learning
// Run a number of iterations of learning
mln.setInput(ds.getFeatures());
mln.setLabels(ds.getLabels());
mln.computeGradientAndScore();
double scoreBefore = mln.score();
for (int k = 0; k < 10; k++)
mln.fit(ds);
for (int k = 0; k < 10; k++) mln.fit(ds);
mln.computeGradientAndScore();
double scoreAfter = mln.score();
//Can't test in 'characteristic mode of operation' if not learning
String msg = name
+ " - score did not (sufficiently) decrease during learning - activationFn="
+ afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
+ ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore
+ ", scoreAfter=" + scoreAfter + ")";
assertTrue(msg, scoreAfter < 0.8 * scoreBefore);
// Can't test in 'characteristic mode of operation' if not learning
String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")";
assertTrue(scoreAfter < 0.8 * scoreBefore, msg);
}
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
+ doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
// for (int k = 0; k < mln.getnLayers(); k++)
// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
// for (int k = 0; k < mln.getnLayers(); k++)
// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
// Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
// i.e., runningMean = decay * runningMean + (1-decay) * batchMean
// However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
.labels(labels).excludeParams(excludeParams));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
}
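The mechanical change repeated throughout these hunks is the assertion argument order: JUnit 4's org.junit.Assert puts the failure message first, while JUnit 5's org.junit.jupiter.api.Assertions puts the condition first and the message last, and also accepts a lazy Supplier. A minimal illustration:

import static org.junit.jupiter.api.Assertions.assertTrue;

public class AssertOrderSketch {
    public static void main(String[] args) {
        boolean gradOK = true;
        String msg = "gradient check failed";

        // JUnit 4 (org.junit.Assert): message first
        // org.junit.Assert.assertTrue(msg, gradOK);

        // JUnit 5 (org.junit.jupiter.api.Assertions): condition first, message last
        assertTrue(gradOK, msg);

        // JUnit 5 also takes a Supplier, so costly messages are only built on failure
        assertTrue(gradOK, () -> msg + " for some expensive-to-format state");
    }
}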
@ -368,7 +271,8 @@ public class BNGradientCheckTest extends BaseDL4JTest {
}
@Test
public void testGradient2dFixedGammaBeta() {
@DisplayName("Test Gradient 2 d Fixed Gamma Beta")
void testGradient2dFixedGammaBeta() {
DataNormalization scaler = new NormalizerMinMaxScaler();
DataSetIterator iter = new IrisDataSetIterator(150, 150);
scaler.fit(iter);
@ -376,219 +280,142 @@ public class BNGradientCheckTest extends BaseDL4JTest {
DataSet ds = iter.next();
INDArray input = ds.getFeatures();
INDArray labels = ds.getLabels();
for (boolean useLogStd : new boolean[]{true, false}) {
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp())
.dataType(DataType.DOUBLE)
.seed(12345L)
.dist(new NormalDistribution(0, 1)).list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build())
.layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true).gamma(2.0).beta(0.5).nOut(3)
.build())
.layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3).build());
for (boolean useLogStd : new boolean[] { true, false }) {
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).seed(12345L).dist(new NormalDistribution(0, 1)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true).gamma(2.0).beta(0.5).nOut(3).build()).layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build());
MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
mln.init();
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
// Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
// i.e., runningMean = decay * runningMean + (1-decay) * batchMean
// However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
.labels(labels).excludeParams(excludeParams));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
}
}
@Test
public void testGradientCnnFixedGammaBeta() {
@DisplayName("Test Gradient Cnn Fixed Gamma Beta")
void testGradientCnnFixedGammaBeta() {
Nd4j.getRandom().setSeed(12345);
int minibatch = 10;
int depth = 1;
int hw = 4;
int nOut = 4;
INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw});
INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw });
INDArray labels = Nd4j.zeros(minibatch, nOut);
Random r = new Random(12345);
for (int i = 0; i < minibatch; i++) {
labels.putScalar(i, r.nextInt(nOut), 1.0);
}
for (boolean useLogStd : new boolean[]{true, false}) {
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp())
.dataType(DataType.DOUBLE)
.seed(12345L)
.dist(new NormalDistribution(0, 2)).list()
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2)
.activation(Activation.IDENTITY).build())
.layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true).gamma(2.0).beta(0.5).build())
.layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(hw, hw, depth));
for (boolean useLogStd : new boolean[] { true, false }) {
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()).dataType(DataType.DOUBLE).seed(12345L).dist(new NormalDistribution(0, 2)).list().layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(depth).nOut(2).activation(Activation.IDENTITY).build()).layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).lockGammaBeta(true).gamma(2.0).beta(0.5).build()).layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(nOut).build()).setInputType(InputType.convolutional(hw, hw, depth));
MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
mln.init();
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
// Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
// i.e., runningMean = decay * runningMean + (1-decay) * batchMean
// However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
.labels(labels).excludeParams(excludeParams));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input).labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
}
}
@Test
public void testBatchNormCompGraphSimple() {
@DisplayName("Test Batch Norm Comp Graph Simple")
void testBatchNormCompGraphSimple() {
int numClasses = 2;
int height = 3;
int width = 3;
int channels = 1;
long seed = 123;
int minibatchSize = 3;
for (boolean useLogStd : new boolean[]{true, false}) {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).updater(new NoOp())
.dataType(DataType.DOUBLE)
.weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
.setInputTypes(InputType.convolutional(height, width, channels))
.addLayer("bn", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "in")
.addLayer("out", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(numClasses).build(), "bn")
.setOutputs("out").build();
for (boolean useLogStd : new boolean[] { true, false }) {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).updater(new NoOp()).dataType(DataType.DOUBLE).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in").setInputTypes(InputType.convolutional(height, width, channels)).addLayer("bn", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "in").addLayer("out", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(numClasses).build(), "bn").setOutputs("out").build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(new int[]{minibatchSize, channels, height, width}); //Order: examples, channels, height, width
// Order: examples, channels, height, width
INDArray input = Nd4j.rand(new int[] { minibatchSize, channels, height, width });
INDArray labels = Nd4j.zeros(minibatchSize, numClasses);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, r.nextInt(numClasses)}, 1.0);
labels.putScalar(new int[] { i, r.nextInt(numClasses) }, 1.0);
}
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
// Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
// i.e., runningMean = decay * runningMean + (1-decay) * batchMean
// However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("bn_mean", "bn_var"));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}).excludeParams(excludeParams));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { input }).labels(new INDArray[] { labels }).excludeParams(excludeParams));
assertTrue(gradOK);
TestUtils.testModelSerialization(net);
}
}
@Test
public void testGradientBNWithCNNandSubsamplingCompGraph() {
//Parameterized test, testing combinations of:
@DisplayName("Test Gradient BN With CN Nand Subsampling Comp Graph")
void testGradientBNWithCNNandSubsamplingCompGraph() {
// Parameterized test, testing combinations of:
// (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
// (c) Loss function (with specified output activations)
// (d) l1 and l2 values
Activation[] activFns = {Activation.TANH, Activation.IDENTITY};
Activation[] activFns = { Activation.TANH, Activation.IDENTITY };
boolean doLearningFirst = true;
LossFunctions.LossFunction[] lossFunctions = {LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD};
Activation[] outputActivations = {Activation.SOFTMAX}; //i.e., lossFunctions[i] used with outputActivations[i] here
double[] l2vals = {0.0, 0.1};
double[] l1vals = {0.0, 0.2}; //i.e., use l2vals[j] with l1vals[j]
LossFunctions.LossFunction[] lossFunctions = { LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD };
// i.e., lossFunctions[i] used with outputActivations[i] here
Activation[] outputActivations = { Activation.SOFTMAX };
double[] l2vals = { 0.0, 0.1 };
// i.e., use l2vals[j] with l1vals[j]
double[] l1vals = { 0.0, 0.2 };
Nd4j.getRandom().setSeed(12345);
int minibatch = 10;
int depth = 2;
int hw = 5;
int nOut = 3;
INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw});
INDArray input = Nd4j.rand(new int[] { minibatch, depth, hw, hw });
INDArray labels = Nd4j.zeros(minibatch, nOut);
Random r = new Random(12345);
for (int i = 0; i < minibatch; i++) {
labels.putScalar(i, r.nextInt(nOut), 1.0);
}
DataSet ds = new DataSet(input, labels);
for (boolean useLogStd : new boolean[]{true, false}) {
for (boolean useLogStd : new boolean[] { true, false }) {
for (Activation afn : activFns) {
for (int i = 0; i < lossFunctions.length; i++) {
for (int j = 0; j < l2vals.length; j++) {
LossFunctions.LossFunction lf = lossFunctions[i];
Activation outputActivation = outputActivations[i];
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
.updater(new NoOp())
.dist(new UniformDistribution(-2, 2)).seed(12345L).graphBuilder()
.addInputs("in")
.addLayer("0", new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3)
.activation(afn).build(), "in")
.addLayer("1", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "0")
.addLayer("2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2).stride(1, 1).build(), "1")
.addLayer("3", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "2")
.addLayer("4", new ActivationLayer.Builder().activation(afn).build(), "3")
.addLayer("5", new OutputLayer.Builder(lf).activation(outputActivation)
.nOut(nOut).build(), "4")
.setOutputs("5").setInputTypes(InputType.convolutional(hw, hw, depth))
.build();
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).dataType(DataType.DOUBLE).optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).updater(new NoOp()).dist(new UniformDistribution(-2, 2)).seed(12345L).graphBuilder().addInputs("in").addLayer("0", new ConvolutionLayer.Builder(2, 2).stride(1, 1).nOut(3).activation(afn).build(), "in").addLayer("1", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "0").addLayer("2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(1, 1).build(), "1").addLayer("3", new BatchNormalization.Builder().useLogStd(useLogStd).build(), "2").addLayer("4", new ActivationLayer.Builder().activation(afn).build(), "3").addLayer("5", new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut).build(), "4").setOutputs("5").setInputTypes(InputType.convolutional(hw, hw, depth)).build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
String name = new Object() { }.getClass().getEnclosingMethod().getName();
if (doLearningFirst) {
//Run a number of iterations of learning
// Run a number of iterations of learning
net.setInput(0, ds.getFeatures());
net.setLabels(ds.getLabels());
net.computeGradientAndScore();
double scoreBefore = net.score();
for (int k = 0; k < 20; k++)
net.fit(ds);
for (int k = 0; k < 20; k++) net.fit(ds);
net.computeGradientAndScore();
double scoreAfter = net.score();
//Can't test in 'characteristic mode of operation' if not learning
String msg = name
+ " - score did not (sufficiently) decrease during learning - activationFn="
+ afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation
+ ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore
+ ", scoreAfter=" + scoreAfter + ")";
assertTrue(msg, scoreAfter < 0.9 * scoreBefore);
// Can't test in 'characteristic mode of operation' if not learning
String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst= " + doLearningFirst + " (before=" + scoreBefore + ", scoreAfter=" + scoreAfter + ")";
assertTrue(scoreAfter < 0.9 * scoreBefore, msg);
}
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
+ doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
// for (int k = 0; k < net.getNumLayers(); k++)
// System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams());
//Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
//i.e., runningMean = decay * runningMean + (1-decay) * batchMean
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
// for (int k = 0; k < net.getNumLayers(); k++)
// System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams());
// Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
// i.e., runningMean = decay * runningMean + (1-decay) * batchMean
// However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}).excludeParams(excludeParams));
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[] { input }).labels(new INDArray[] { labels }).excludeParams(excludeParams));
assertTrue(gradOK);
TestUtils.testModelSerialization(net);
}
@ -596,5 +423,4 @@ public class BNGradientCheckTest extends BaseDL4JTest {
}
}
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.gradientcheck;
import lombok.extern.slf4j.Slf4j;
@ -35,7 +34,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.Convolution1DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -44,18 +43,24 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
public class CNN1DGradientCheckTest extends BaseDL4JTest {
@DisplayName("Cnn 1 D Gradient Check Test")
class CNN1DGradientCheckTest extends BaseDL4JTest {
private static final boolean PRINT_RESULTS = true;
private static final boolean RETURN_ON_FIRST_FAILURE = false;
private static final double DEFAULT_EPS = 1e-6;
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
static {
@ -68,148 +73,91 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
}
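DEFAULT_EPS, DEFAULT_MAX_REL_ERROR and DEFAULT_MIN_ABS_ERROR parameterize a central-difference numerical gradient check. A stripped-down sketch of the comparison these tests delegate to GradientCheckUtil.checkGradients; the one-parameter loss function and the exact pass rule here are illustrative assumptions, not the DL4J source:

import java.util.function.DoubleUnaryOperator;

public class GradCheckSketch {
    public static void main(String[] args) {
        DoubleUnaryOperator loss = x -> x * x * x;          // loss as a function of one parameter
        DoubleUnaryOperator analytic = x -> 3 * x * x;      // "backprop" gradient

        double eps = 1e-6;           // DEFAULT_EPS
        double maxRelError = 1e-3;   // DEFAULT_MAX_REL_ERROR
        double minAbsError = 1e-8;   // DEFAULT_MIN_ABS_ERROR

        double x = 0.7;
        double numeric = (loss.applyAsDouble(x + eps) - loss.applyAsDouble(x - eps)) / (2 * eps);
        double backprop = analytic.applyAsDouble(x);

        double absError = Math.abs(numeric - backprop);
        double relError = absError / Math.max(Math.abs(numeric), Math.abs(backprop));

        // Pass if the relative error is small, or the absolute error is tiny
        // (the min-abs threshold guards the near-zero-gradient case).
        boolean gradOK = relError < maxRelError || absError < minAbsError;
        System.out.println("numeric=" + numeric + ", backprop=" + backprop + ", ok=" + gradOK);
    }
}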
@Test
public void testCnn1DWithLocallyConnected1D() {
@DisplayName("Test Cnn 1 D With Locally Connected 1 D")
void testCnn1DWithLocallyConnected1D() {
Nd4j.getRandom().setSeed(1337);
int[] minibatchSizes = {2, 3};
int[] minibatchSizes = { 2, 3 };
int length = 7;
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 4;
int[] kernels = {1};
int[] kernels = { 1 };
int stride = 1;
int padding = 0;
Activation[] activations = {Activation.SIGMOID};
Activation[] activations = { Activation.SIGMOID };
for (Activation afn : activations) {
for (int minibatchSize : minibatchSizes) {
for (int kernel : kernels) {
INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length});
INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length });
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
for (int i = 0; i < minibatchSize; i++) {
for (int j = 0; j < length; j++) {
labels.putScalar(new int[]{i, i % finalNOut, j}, 1.0);
labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0);
}
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list()
.layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
.stride(stride).padding(padding).nIn(convNIn).nOut(convNOut1)
.rnnDataFormat(RNNFormat.NCW)
.build())
.layer(new LocallyConnected1D.Builder().activation(afn).kernelSize(kernel)
.stride(stride).padding(padding).nIn(convNOut1).nOut(convNOut2).hasBias(false)
.build())
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list().layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nIn(convNIn).nOut(convNOut1).rnnDataFormat(RNNFormat.NCW).build()).layer(new LocallyConnected1D.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nIn(convNOut1).nOut(convNOut2).hasBias(false).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "Minibatch=" + minibatchSize + ", activationFn="
+ afn + ", kernel = " + kernel;
String msg = "Minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel;
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
}
}
@Test
public void testCnn1DWithCropping1D() {
@DisplayName("Test Cnn 1 D With Cropping 1 D")
void testCnn1DWithCropping1D() {
Nd4j.getRandom().setSeed(1337);
int[] minibatchSizes = {1, 3};
int[] minibatchSizes = { 1, 3 };
int length = 7;
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 4;
int[] kernels = {1, 2, 4};
int[] kernels = { 1, 2, 4 };
int stride = 1;
int padding = 0;
int cropping = 1;
int croppedLength = length - 2 * cropping;
Activation[] activations = {Activation.SIGMOID};
SubsamplingLayer.PoolingType[] poolingTypes =
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
Activation[] activations = { Activation.SIGMOID };
SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM };
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
for (int kernel : kernels) {
INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length});
INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length });
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, croppedLength);
for (int i = 0; i < minibatchSize; i++) {
for (int j = 0; j < croppedLength; j++) {
labels.putScalar(new int[]{i, i % finalNOut, j}, 1.0);
labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0);
}
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list()
.layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
.stride(stride).padding(padding).nOut(convNOut1)
.build())
.layer(new Cropping1D.Builder(cropping).build())
.layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
.stride(stride).padding(padding).nOut(convNOut2)
.build())
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list().layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut1).build()).layer(new Cropping1D.Builder(cropping).build()).layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut2).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn + ", kernel = " + kernel;
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel;
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
@ -218,82 +166,50 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
}
}
@Test
public void testCnn1DWithZeroPadding1D() {
@DisplayName("Test Cnn 1 D With Zero Padding 1 D")
void testCnn1DWithZeroPadding1D() {
Nd4j.getRandom().setSeed(1337);
int[] minibatchSizes = {1, 3};
int[] minibatchSizes = { 1, 3 };
int length = 7;
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 4;
int[] kernels = {1, 2, 4};
int[] kernels = { 1, 2, 4 };
int stride = 1;
int pnorm = 2;
int padding = 0;
int zeroPadding = 2;
int paddedLength = length + 2 * zeroPadding;
Activation[] activations = {Activation.SIGMOID};
SubsamplingLayer.PoolingType[] poolingTypes =
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
Activation[] activations = { Activation.SIGMOID };
SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM };
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
for (int kernel : kernels) {
INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length});
INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length });
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, paddedLength);
for (int i = 0; i < minibatchSize; i++) {
for (int j = 0; j < paddedLength; j++) {
labels.putScalar(new int[]{i, i % finalNOut, j}, 1.0);
labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0);
}
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list()
.layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
.stride(stride).padding(padding).nOut(convNOut1)
.build())
.layer(new ZeroPadding1DLayer.Builder(zeroPadding).build())
.layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
.stride(stride).padding(padding).nOut(convNOut2)
.build())
.layer(new ZeroPadding1DLayer.Builder(0).build())
.layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel)
.stride(stride).padding(padding).pnorm(pnorm).build())
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list().layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut1).build()).layer(new ZeroPadding1DLayer.Builder(zeroPadding).build()).layer(new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut2).build()).layer(new ZeroPadding1DLayer.Builder(0).build()).layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel).stride(stride).padding(padding).pnorm(pnorm).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn + ", kernel = " + kernel;
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel;
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK, msg);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
}
}
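Cropping1D and ZeroPadding1D change only the sequence length, and they do so symmetrically, which is why the two tests above size their label arrays with croppedLength and paddedLength. The arithmetic, spelled out with the tests' values:

public class LengthArithmeticSketch {
    public static void main(String[] args) {
        int length = 7;
        int cropping = 1;      // Cropping1D.Builder(cropping) trims both ends
        int zeroPadding = 2;   // ZeroPadding1DLayer.Builder(zeroPadding) pads both ends

        int croppedLength = length - 2 * cropping;     // 5, labels shape in testCnn1DWithCropping1D
        int paddedLength = length + 2 * zeroPadding;   // 11, labels shape in testCnn1DWithZeroPadding1D

        System.out.println(croppedLength + ", " + paddedLength);
    }
}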
@ -301,76 +217,48 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
}
}
@Test
public void testCnn1DWithSubsampling1D() {
@DisplayName("Test Cnn 1 D With Subsampling 1 D")
void testCnn1DWithSubsampling1D() {
Nd4j.getRandom().setSeed(12345);
int[] minibatchSizes = {1, 3};
int[] minibatchSizes = { 1, 3 };
int length = 7;
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 4;
int[] kernels = {1, 2, 4};
int[] kernels = { 1, 2, 4 };
int stride = 1;
int padding = 0;
int pnorm = 2;
Activation[] activations = {Activation.SIGMOID, Activation.TANH};
SubsamplingLayer.PoolingType[] poolingTypes =
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
Activation[] activations = { Activation.SIGMOID, Activation.TANH };
SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM };
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
for (int kernel : kernels) {
INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length});
INDArray input = Nd4j.rand(new int[] { minibatchSize, convNIn, length });
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length);
for (int i = 0; i < minibatchSize; i++) {
for (int j = 0; j < length; j++) {
labels.putScalar(new int[]{i, i % finalNOut, j}, 1.0);
labels.putScalar(new int[] { i, i % finalNOut, j }, 1.0);
}
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list()
.layer(0, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
.stride(stride).padding(padding).nOut(convNOut1)
.build())
.layer(1, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel)
.stride(stride).padding(padding).nOut(convNOut2)
.build())
.layer(2, new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel)
.stride(stride).padding(padding).pnorm(pnorm).build())
.layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list().layer(0, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut1).build()).layer(1, new Convolution1DLayer.Builder().activation(afn).kernelSize(kernel).stride(stride).padding(padding).nOut(convNOut2).build()).layer(2, new Subsampling1DLayer.Builder(poolingType).kernelSize(kernel).stride(stride).padding(padding).pnorm(pnorm).build()).layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn + ", kernel = " + kernel;
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn + ", kernel = " + kernel;
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK, msg);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
}
}
@ -379,66 +267,34 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
}
@Test
public void testCnn1dWithMasking(){
@DisplayName("Test Cnn 1 d With Masking")
void testCnn1dWithMasking() {
int length = 12;
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 3;
int pnorm = 2;
SubsamplingLayer.PoolingType[] poolingTypes =
new SubsamplingLayer.PoolingType[] {SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG};
SubsamplingLayer.PoolingType[] poolingTypes = new SubsamplingLayer.PoolingType[] { SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG };
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for(ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Same, ConvolutionMode.Truncate}) {
for( int stride : new int[]{1, 2}){
for (ConvolutionMode cm : new ConvolutionMode[] { ConvolutionMode.Same, ConvolutionMode.Truncate }) {
for (int stride : new int[] { 1, 2 }) {
String s = cm + ", stride=" + stride + ", pooling=" + poolingType;
log.info("Starting test: " + s);
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 1)).convolutionMode(cm)
.seed(12345)
.list()
.layer(new Convolution1DLayer.Builder().kernelSize(2)
.rnnDataFormat(RNNFormat.NCW)
.stride(stride).nIn(convNIn).nOut(convNOut1)
.build())
.layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(2)
.stride(stride).pnorm(pnorm).build())
.layer(new Convolution1DLayer.Builder().kernelSize(2)
.rnnDataFormat(RNNFormat.NCW)
.stride(stride).nIn(convNOut1).nOut(convNOut2)
.build())
.layer(new GlobalPoolingLayer(PoolingType.AVG))
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).activation(Activation.TANH).dist(new NormalDistribution(0, 1)).convolutionMode(cm).seed(12345).list().layer(new Convolution1DLayer.Builder().kernelSize(2).rnnDataFormat(RNNFormat.NCW).stride(stride).nIn(convNIn).nOut(convNOut1).build()).layer(new Subsampling1DLayer.Builder(poolingType).kernelSize(2).stride(stride).pnorm(pnorm).build()).layer(new Convolution1DLayer.Builder().kernelSize(2).rnnDataFormat(RNNFormat.NCW).stride(stride).nIn(convNOut1).nOut(convNOut2).build()).layer(new GlobalPoolingLayer(PoolingType.AVG)).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray f = Nd4j.rand(new int[]{2, convNIn, length});
INDArray f = Nd4j.rand(new int[] { 2, convNIn, length });
INDArray fm = Nd4j.create(2, length);
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0,6)).assign(1);
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, 6)).assign(1);
INDArray label = TestUtils.randomOneHot(2, finalNOut);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
.labels(label).inputMask(fm));
assertTrue(s, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
assertTrue(gradOK, s);
TestUtils.testModelSerialization(net);
//TODO also check that masked step values don't impact forward pass, score or gradients
DataSet ds = new DataSet(f,label,fm,null);
// TODO also check that masked step values don't impact forward pass, score or gradients
DataSet ds = new DataSet(f, label, fm, null);
double scoreBefore = net.score(ds);
net.setInput(f);
net.setLabels(label);
@ -453,7 +309,6 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
net.setLayerMaskArrays(fm, null);
net.computeGradientAndScore();
INDArray gradAfter = net.getFlattenedGradients().dup();
assertEquals(scoreBefore, scoreAfter, 1e-6);
assertEquals(gradBefore, gradAfter);
}
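The TODO above would pin down the mask semantics: values at masked time steps must not influence the score or gradients. A hedged sketch of that extra check, written as a fragment that reuses the test's f, label, fm, length and net variables (the perturbation value 123.0 is arbitrary):

// Hypothetical extension of the masking test: overwrite the masked tail of
// example 1 (steps 6..length-1) and verify the score does not move.
INDArray fPerturbed = f.dup();
fPerturbed.get(NDArrayIndex.point(1), NDArrayIndex.all(),
        NDArrayIndex.interval(6, length)).assign(123.0);

net.setLayerMaskArrays(fm, null);
double scoreOriginal = net.score(new DataSet(f, label, fm, null));
double scorePerturbed = net.score(new DataSet(fPerturbed, label, fm, null));
assertEquals(scoreOriginal, scorePerturbed, 1e-6);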
@ -462,18 +317,18 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
}
@Test
public void testCnn1Causal() throws Exception {
@DisplayName("Test Cnn 1 Causal")
void testCnn1Causal() throws Exception {
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int finalNOut = 3;
int[] lengths = {11, 12, 13, 9, 10, 11};
int[] kernels = {2, 3, 2, 4, 2, 3};
int[] dilations = {1, 1, 2, 1, 2, 1};
int[] strides = {1, 2, 1, 2, 1, 1};
boolean[] masks = {false, true, false, true, false, true};
boolean[] hasB = {true, false, true, false, true, true};
int[] lengths = { 11, 12, 13, 9, 10, 11 };
int[] kernels = { 2, 3, 2, 4, 2, 3 };
int[] dilations = { 1, 1, 2, 1, 2, 1 };
int[] strides = { 1, 2, 1, 2, 1, 1 };
boolean[] masks = { false, true, false, true, false, true };
boolean[] hasB = { true, false, true, false, true, true };
for (int i = 0; i < lengths.length; i++) {
int length = lengths[i];
int k = kernels[i];
@ -481,36 +336,13 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
int st = strides[i];
boolean mask = masks[i];
boolean hasBias = hasB[i];
//TODO has bias
// TODO has bias
String s = "k=" + k + ", s=" + st + " d=" + d + ", seqLen=" + length;
log.info("Starting test: " + s);
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.activation(Activation.TANH)
.weightInit(new NormalDistribution(0, 1))
.seed(12345)
.list()
.layer(new Convolution1DLayer.Builder().kernelSize(k)
.dilation(d)
.hasBias(hasBias)
.convolutionMode(ConvolutionMode.Causal)
.stride(st).nOut(convNOut1)
.build())
.layer(new Convolution1DLayer.Builder().kernelSize(k)
.dilation(d)
.convolutionMode(ConvolutionMode.Causal)
.stride(st).nOut(convNOut2)
.build())
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).activation(Activation.TANH).weightInit(new NormalDistribution(0, 1)).seed(12345).list().layer(new Convolution1DLayer.Builder().kernelSize(k).dilation(d).hasBias(hasBias).convolutionMode(ConvolutionMode.Causal).stride(st).nOut(convNOut1).build()).layer(new Convolution1DLayer.Builder().kernelSize(k).dilation(d).convolutionMode(ConvolutionMode.Causal).stride(st).nOut(convNOut2).build()).layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.recurrent(convNIn, length, RNNFormat.NCW)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length);
INDArray fm = null;
if (mask) {
@ -518,16 +350,11 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1);
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length - 2)).assign(1);
}
long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d);
long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d);
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
.labels(label).inputMask(fm));
assertTrue(s, gradOK);
INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int) outSize2);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f).labels(label).inputMask(fm));
assertTrue(gradOK, s);
TestUtils.testModelSerialization(net);
}
}
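The two getOutputSize calls above encode the causal-convolution length rule: Causal mode left-pads by (kernel - 1) * dilation, so an input of length L produces ceil(L / stride) outputs. A self-contained re-derivation of that arithmetic (it mirrors, but is not, the Convolution1DUtils source):

public class CausalOutSizeSketch {

    static long causalOutputSize(long length, int k, int stride, int dilation) {
        long effectiveKernel = (long) (k - 1) * dilation + 1;
        long leftPadding = effectiveKernel - 1;
        // (L + pad - effK) / s + 1 simplifies to (L - 1) / s + 1, i.e. ceil(L / s)
        return (length + leftPadding - effectiveKernel) / stride + 1;
    }

    public static void main(String[] args) {
        // First case above: length=11, k=2, stride=1, dilation=1 -> 11, and
        // stacking the two Convolution1DLayers keeps it at 11
        System.out.println(causalOutputSize(causalOutputSize(11, 2, 1, 1), 2, 1, 1));
        // Second case: length=12, k=3, stride=2 -> ceil(12 / 2) = 6
        System.out.println(causalOutputSize(12, 3, 2, 1));
    }
}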

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.gradientcheck;
import lombok.extern.java.Log;
@ -33,7 +32,7 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping3D;
import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
@ -41,18 +40,24 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Log
public class CNN3DGradientCheckTest extends BaseDL4JTest {
@DisplayName("Cnn 3 D Gradient Check Test")
class CNN3DGradientCheckTest extends BaseDL4JTest {
private static final boolean PRINT_RESULTS = true;
private static final boolean RETURN_ON_FIRST_FAILURE = false;
private static final double DEFAULT_EPS = 1e-6;
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
static {
@ -65,30 +70,23 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
}
@Test
public void testCnn3DPlain() {
@DisplayName("Test Cnn 3 D Plain")
void testCnn3DPlain() {
Nd4j.getRandom().setSeed(1337);
// Note: we checked this with a variety of parameters, but it takes a lot of time.
int[] depths = {6};
int[] heights = {6};
int[] widths = {6};
int[] minibatchSizes = {3};
int[] depths = { 6 };
int[] heights = { 6 };
int[] widths = { 6 };
int[] minibatchSizes = { 3 };
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int denseNOut = 5;
int finalNOut = 42;
int[][] kernels = {{2, 2, 2}};
int[][] strides = {{1, 1, 1}};
Activation[] activations = {Activation.SIGMOID};
ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same};
int[][] kernels = { { 2, 2, 2 } };
int[][] strides = { { 1, 1, 1 } };
Activation[] activations = { Activation.SIGMOID };
ConvolutionMode[] modes = { ConvolutionMode.Truncate, ConvolutionMode.Same };
for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) {
for (int depth : depths) {
@ -98,71 +96,34 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
for (int[] kernel : kernels) {
for (int[] stride : strides) {
for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) {
int outDepth = mode == ConvolutionMode.Same ?
depth / stride[0] : (depth - kernel[0]) / stride[0] + 1;
int outHeight = mode == ConvolutionMode.Same ?
height / stride[1] : (height - kernel[1]) / stride[1] + 1;
int outWidth = mode == ConvolutionMode.Same ?
width / stride[2] : (width - kernel[2]) / stride[2] + 1;
int outDepth = mode == ConvolutionMode.Same ? depth / stride[0] : (depth - kernel[0]) / stride[0] + 1;
int outHeight = mode == ConvolutionMode.Same ? height / stride[1] : (height - kernel[1]) / stride[1] + 1;
int outWidth = mode == ConvolutionMode.Same ? width / stride[2] : (width - kernel[2]) / stride[2] + 1;
INDArray input;
if(df == Convolution3D.DataFormat.NDHWC){
input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn});
if (df == Convolution3D.DataFormat.NDHWC) {
input = Nd4j.rand(new int[] { miniBatchSize, depth, height, width, convNIn });
} else {
input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width});
input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width });
}
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0);
labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
.dist(new NormalDistribution(0, 1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel)
.stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNOut1).nOut(convNOut2).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(2, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(2,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,
convNOut2, df == Convolution3D.DataFormat.NCDHW))
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNOut1).nOut(convNOut2).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(2, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, df == Convolution3D.DataFormat.NCDHW)).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", stride = "
+ Arrays.toString(stride) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", stride = " + Arrays.toString(stride) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
if (PRINT_RESULTS) {
log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
.labels(labels).subset(true).maxPerParam(128));
assertTrue(msg, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(128));
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
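The outDepth/outHeight/outWidth expressions above are the per-dimension convolution size rule used throughout these tests: Same mode gives in / stride (ceil(in / stride) in general; the test sizes divide evenly), Truncate gives (in - kernel) / stride + 1. Checked against the test's values:

public class Conv3DOutSizeSketch {

    static int outSize(boolean same, int in, int kernel, int stride) {
        return same ? in / stride : (in - kernel) / stride + 1;
    }

    public static void main(String[] args) {
        int in = 6, kernel = 2, stride = 1;  // depths/heights/widths = {6}, kernels = {{2, 2, 2}}
        System.out.println(outSize(true, in, kernel, stride));   // Same:     6
        System.out.println(outSize(false, in, kernel, stride));  // Truncate: 5
        // These sizes are what the Cnn3DToFeedForwardPreProcessor(outDepth, outHeight,
        // outWidth, ...) needs so the dense layer sees a correctly sized flat input.
    }
}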
@ -176,186 +137,98 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
}
@Test
public void testCnn3DZeroPadding() {
@DisplayName("Test Cnn 3 D Zero Padding")
void testCnn3DZeroPadding() {
Nd4j.getRandom().setSeed(42);
int depth = 4;
int height = 4;
int width = 4;
int[] minibatchSizes = {3};
int[] minibatchSizes = { 3 };
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int denseNOut = 5;
int finalNOut = 42;
int[] kernel = {2, 2, 2};
int[] zeroPadding = {1, 1, 2, 2, 3, 3};
Activation[] activations = {Activation.SIGMOID};
ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same};
int[] kernel = { 2, 2, 2 };
int[] zeroPadding = { 1, 1, 2, 2, 3, 3 };
Activation[] activations = { Activation.SIGMOID };
ConvolutionMode[] modes = { ConvolutionMode.Truncate, ConvolutionMode.Same };
for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) {
for (ConvolutionMode mode : modes) {
int outDepth = mode == ConvolutionMode.Same ?
depth : (depth - kernel[0]) + 1;
int outHeight = mode == ConvolutionMode.Same ?
height : (height - kernel[1]) + 1;
int outWidth = mode == ConvolutionMode.Same ?
width : (width - kernel[2]) + 1;
int outDepth = mode == ConvolutionMode.Same ? depth : (depth - kernel[0]) + 1;
int outHeight = mode == ConvolutionMode.Same ? height : (height - kernel[1]) + 1;
int outWidth = mode == ConvolutionMode.Same ? width : (width - kernel[2]) + 1;
outDepth += zeroPadding[0] + zeroPadding[1];
outHeight += zeroPadding[2] + zeroPadding[3];
outWidth += zeroPadding[4] + zeroPadding[5];
INDArray input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width});
INDArray input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width });
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0);
labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
.dist(new NormalDistribution(0, 1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel)
.nIn(convNIn).nOut(convNOut1).hasBias(false)
.convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW)
.build())
.layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNOut1).nOut(convNOut2).hasBias(false)
.convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW)
.build())
.layer(2, new ZeroPadding3DLayer.Builder(zeroPadding).build())
.layer(3, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(3,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,
convNOut2, true))
.setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).nIn(convNIn).nOut(convNOut1).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNOut1).nOut(convNOut2).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(2, new ZeroPadding3DLayer.Builder(zeroPadding).build()).layer(3, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(3, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, true)).setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
if (PRINT_RESULTS) {
log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
.labels(labels).subset(true).maxPerParam(512));
assertTrue(msg, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(512));
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
}
}
@Test
public void testCnn3DPooling() {
@DisplayName("Test Cnn 3 D Pooling")
void testCnn3DPooling() {
Nd4j.getRandom().setSeed(42);
int depth = 4;
int height = 4;
int width = 4;
int[] minibatchSizes = {3};
int[] minibatchSizes = { 3 };
int convNIn = 2;
int convNOut = 4;
int denseNOut = 5;
int finalNOut = 42;
int[] kernel = {2, 2, 2};
Activation[] activations = {Activation.SIGMOID};
Subsampling3DLayer.PoolingType[] poolModes = {Subsampling3DLayer.PoolingType.AVG};
ConvolutionMode[] modes = {ConvolutionMode.Truncate};
int[] kernel = { 2, 2, 2 };
Activation[] activations = { Activation.SIGMOID };
Subsampling3DLayer.PoolingType[] poolModes = { Subsampling3DLayer.PoolingType.AVG };
ConvolutionMode[] modes = { ConvolutionMode.Truncate };
for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) {
for (Subsampling3DLayer.PoolingType pool : poolModes) {
for (ConvolutionMode mode : modes) {
for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) {
int outDepth = depth / kernel[0];
int outHeight = height / kernel[1];
int outWidth = width / kernel[2];
INDArray input = Nd4j.rand(
df == Convolution3D.DataFormat.NCDHW ? new int[]{miniBatchSize, convNIn, depth, height, width}
: new int[]{miniBatchSize, depth, height, width, convNIn});
INDArray input = Nd4j.rand(df == Convolution3D.DataFormat.NCDHW ? new int[] { miniBatchSize, convNIn, depth, height, width } : new int[] { miniBatchSize, depth, height, width, convNIn });
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0);
labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.dist(new NormalDistribution(0, 1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNIn).nOut(convNOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(1, new Subsampling3DLayer.Builder(kernel)
.poolingType(pool).convolutionMode(mode).dataFormat(df).build())
.layer(2, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(2,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,convNOut, df))
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.XAVIER).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNIn).nOut(convNOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Subsampling3DLayer.Builder(kernel).poolingType(pool).convolutionMode(mode).dataFormat(df).build()).layer(2, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut, df)).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width + ", dataFormat=" + df;
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width + ", dataFormat=" + df;
if (PRINT_RESULTS) {
log.info(msg);
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
@ -365,87 +238,47 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
}
@Test
public void testCnn3DUpsampling() {
@DisplayName("Test Cnn 3 D Upsampling")
void testCnn3DUpsampling() {
Nd4j.getRandom().setSeed(42);
int depth = 2;
int height = 2;
int width = 2;
int[] minibatchSizes = {3};
int[] minibatchSizes = { 3 };
int convNIn = 2;
int convNOut = 4;
int denseNOut = 5;
int finalNOut = 42;
int[] upsamplingSize = {2, 2, 2};
Activation[] activations = {Activation.SIGMOID};
ConvolutionMode[] modes = {ConvolutionMode.Truncate};
int[] upsamplingSize = { 2, 2, 2 };
Activation[] activations = { Activation.SIGMOID };
ConvolutionMode[] modes = { ConvolutionMode.Truncate };
for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) {
for (ConvolutionMode mode : modes) {
for(Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) {
for (Convolution3D.DataFormat df : Convolution3D.DataFormat.values()) {
int outDepth = depth * upsamplingSize[0];
int outHeight = height * upsamplingSize[1];
int outWidth = width * upsamplingSize[2];
INDArray input = df == Convolution3D.DataFormat.NCDHW ? Nd4j.rand(miniBatchSize, convNIn, depth, height, width) : Nd4j.rand(miniBatchSize, depth, height, width, convNIn);
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0);
labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
.dist(new NormalDistribution(0, 1))
.seed(12345)
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNIn).nOut(convNOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(1, new Upsampling3D.Builder(upsamplingSize[0]).dataFormat(df).build())
.layer(2, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(2,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,
convNOut, true))
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).seed(12345).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNIn).nOut(convNOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Upsampling3D.Builder(upsamplingSize[0]).dataFormat(df).build()).layer(2, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut, true)).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(upsamplingSize) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(upsamplingSize) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
if (PRINT_RESULTS) {
log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
@ -454,126 +287,74 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
}
@Test
public void testCnn3DCropping() {
@DisplayName("Test Cnn 3 D Cropping")
void testCnn3DCropping() {
Nd4j.getRandom().setSeed(42);
int depth = 6;
int height = 6;
int width = 6;
int[] minibatchSizes = {3};
int[] minibatchSizes = { 3 };
int convNIn = 2;
int convNOut1 = 3;
int convNOut2 = 4;
int denseNOut = 5;
int finalNOut = 8;
int[] kernel = {1, 1, 1};
int[] cropping = {0, 0, 1, 1, 2, 2};
Activation[] activations = {Activation.SIGMOID};
ConvolutionMode[] modes = {ConvolutionMode.Same};
int[] kernel = { 1, 1, 1 };
int[] cropping = { 0, 0, 1, 1, 2, 2 };
Activation[] activations = { Activation.SIGMOID };
ConvolutionMode[] modes = { ConvolutionMode.Same };
for (Activation afn : activations) {
for (int miniBatchSize : minibatchSizes) {
for (ConvolutionMode mode : modes) {
int outDepth = mode == ConvolutionMode.Same ?
depth : (depth - kernel[0]) + 1;
int outHeight = mode == ConvolutionMode.Same ?
height : (height - kernel[1]) + 1;
int outWidth = mode == ConvolutionMode.Same ?
width : (width - kernel[2]) + 1;
int outDepth = mode == ConvolutionMode.Same ? depth : (depth - kernel[0]) + 1;
int outHeight = mode == ConvolutionMode.Same ? height : (height - kernel[1]) + 1;
int outWidth = mode == ConvolutionMode.Same ? width : (width - kernel[2]) + 1;
outDepth -= cropping[0] + cropping[1];
outHeight -= cropping[2] + cropping[3];
outWidth -= cropping[4] + cropping[5];
INDArray input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width});
INDArray input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width });
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int i = 0; i < miniBatchSize; i++) {
labels.putScalar(new int[]{i, i % finalNOut}, 1.0);
labels.putScalar(new int[] { i, i % finalNOut }, 1.0);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
.dist(new NormalDistribution(0, 1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel)
.nIn(convNIn).nOut(convNOut1).hasBias(false)
.convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW)
.build())
.layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1)
.nIn(convNOut1).nOut(convNOut2).hasBias(false)
.convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW)
.build())
.layer(2, new Cropping3D.Builder(cropping).build())
.layer(3, new DenseLayer.Builder().nOut(denseNOut).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.inputPreProcessor(3,
new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,
convNOut2, true))
.setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL).dist(new NormalDistribution(0, 1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).nIn(convNIn).nOut(convNOut1).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(1, new Convolution3D.Builder().activation(afn).kernelSize(1, 1, 1).nIn(convNOut1).nOut(convNOut2).hasBias(false).convolutionMode(mode).dataFormat(Convolution3D.DataFormat.NCDHW).build()).layer(2, new Cropping3D.Builder(cropping).build()).layer(3, new DenseLayer.Builder().nOut(denseNOut).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).inputPreProcessor(3, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, true)).setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
if (PRINT_RESULTS) {
log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS,
RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
}
}
@Test
public void testDeconv3d() {
@DisplayName("Test Deconv 3 d")
void testDeconv3d() {
Nd4j.getRandom().setSeed(12345);
// Note: we checked this with a variety of parameters, but it takes a lot of time.
int[] depths = {8, 8, 9};
int[] heights = {8, 9, 9};
int[] widths = {8, 8, 9};
int[][] kernels = {{2, 2, 2}, {3, 3, 3}, {2, 3, 2}};
int[][] strides = {{1, 1, 1}, {1, 1, 1}, {2, 2, 2}};
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.IDENTITY};
ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same, ConvolutionMode.Same};
int[] mbs = {1, 3, 2};
Convolution3D.DataFormat[] dataFormats = new Convolution3D.DataFormat[]{Convolution3D.DataFormat.NCDHW, Convolution3D.DataFormat.NDHWC, Convolution3D.DataFormat.NCDHW};
int[] depths = { 8, 8, 9 };
int[] heights = { 8, 9, 9 };
int[] widths = { 8, 8, 9 };
int[][] kernels = { { 2, 2, 2 }, { 3, 3, 3 }, { 2, 3, 2 } };
int[][] strides = { { 1, 1, 1 }, { 1, 1, 1 }, { 2, 2, 2 } };
Activation[] activations = { Activation.SIGMOID, Activation.TANH, Activation.IDENTITY };
ConvolutionMode[] modes = { ConvolutionMode.Truncate, ConvolutionMode.Same, ConvolutionMode.Same };
int[] mbs = { 1, 3, 2 };
Convolution3D.DataFormat[] dataFormats = new Convolution3D.DataFormat[] { Convolution3D.DataFormat.NCDHW, Convolution3D.DataFormat.NDHWC, Convolution3D.DataFormat.NCDHW };
int convNIn = 2;
int finalNOut = 2;
int[] deconvOut = {2, 3, 4};
int[] deconvOut = { 2, 3, 4 };
for (int i = 0; i < activations.length; i++) {
Activation afn = activations[i];
int miniBatchSize = mbs[i];
@ -585,57 +366,28 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
int[] stride = strides[i];
Convolution3D.DataFormat df = dataFormats[i];
int dOut = deconvOut[i];
INDArray input;
if (df == Convolution3D.DataFormat.NDHWC) {
input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn});
input = Nd4j.rand(new int[] { miniBatchSize, depth, height, width, convNIn });
} else {
input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width});
input = Nd4j.rand(new int[] { miniBatchSize, convNIn, depth, height, width });
}
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int j = 0; j < miniBatchSize; j++) {
labels.putScalar(new int[]{j, j % finalNOut}, 1.0);
labels.putScalar(new int[] { j, j % finalNOut }, 1.0);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.weightInit(new NormalDistribution(0, 0.1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel)
.stride(stride).nIn(convNIn).nOut(dOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(1, new Deconvolution3D.Builder().activation(afn).kernelSize(kernel)
.stride(stride).nOut(dOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).updater(new NoOp()).weightInit(new NormalDistribution(0, 0.1)).list().layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel).stride(stride).nIn(convNIn).nOut(dOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(1, new Deconvolution3D.Builder().activation(afn).kernelSize(kernel).stride(stride).nOut(dOut).hasBias(false).convolutionMode(mode).dataFormat(df).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(finalNOut).build()).setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", stride = "
+ Arrays.toString(stride) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", stride = " + Arrays.toString(stride) + ", mode = " + mode.toString() + ", input depth " + depth + ", input height " + height + ", input width " + width;
if (PRINT_RESULTS) {
log.info(msg);
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
.labels(labels).subset(true).maxPerParam(64));
assertTrue(msg, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(64));
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
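The recurring mechanical change in this file is the assertion argument order: JUnit 4's org.junit.Assert.assertTrue(String message, boolean condition) takes the failure message first, while JUnit 5's org.junit.jupiter.api.Assertions.assertTrue(boolean condition, String message) takes it last. A minimal before/after sketch of the pattern (the class name and values below are illustrative, not part of this commit):

import static org.junit.jupiter.api.Assertions.assertTrue;

class ArgumentOrderSketch {
    void check() {
        boolean gradOK = true; // stand-in for a GradientCheckUtil.checkGradients(...) result
        String msg = "context reported on failure";
        // JUnit 4: assertTrue(msg, gradOK); -- message first
        // JUnit 5: condition first, message last
        assertTrue(gradOK, msg);
    }
}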

View File

@ -17,11 +17,9 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.gradientcheck;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -35,19 +33,21 @@ import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import java.util.Random;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Ignore
public class CapsnetGradientCheckTest extends BaseDL4JTest {
@Disabled
@DisplayName("Capsnet Gradient Check Test")
class CapsnetGradientCheckTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
@ -55,71 +55,39 @@ public class CapsnetGradientCheckTest extends BaseDL4JTest {
}
@Test
public void testCapsNet() {
int[] minibatchSizes = {8, 16};
@DisplayName("Test Caps Net")
void testCapsNet() {
int[] minibatchSizes = { 8, 16 };
int width = 6;
int height = 6;
int inputDepth = 4;
int[] primaryCapsDims = {2, 4};
int[] primaryCapsChannels = {8};
int[] capsules = {5};
int[] capsuleDims = {4, 8};
int[] routings = {1};
int[] primaryCapsDims = { 2, 4 };
int[] primaryCapsChannels = { 8 };
int[] capsules = { 5 };
int[] capsuleDims = { 4, 8 };
int[] routings = { 1 };
Nd4j.getRandom().setSeed(12345);
for (int routing : routings) {
for (int primaryCapsDim : primaryCapsDims) {
for (int primaryCapsChannel : primaryCapsChannels) {
for (int capsule : capsules) {
for (int capsuleDim : capsuleDims) {
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, inputDepth * height * width).mul(10)
.reshape(-1, inputDepth, height, width);
INDArray input = Nd4j.rand(minibatchSize, inputDepth * height * width).mul(10).reshape(-1, inputDepth, height, width);
INDArray labels = Nd4j.zeros(minibatchSize, capsule);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % capsule}, 1.0);
labels.putScalar(new int[] { i, i % capsule }, 1.0);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.seed(123)
.updater(new NoOp())
.weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6)))
.list()
.layer(new PrimaryCapsules.Builder(primaryCapsDim, primaryCapsChannel)
.kernelSize(3, 3)
.stride(2, 2)
.build())
.layer(new CapsuleLayer.Builder(capsule, capsuleDim, routing).build())
.layer(new CapsuleStrengthLayer.Builder().build())
.layer(new ActivationLayer.Builder(new ActivationSoftmax()).build())
.layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build())
.setInputType(InputType.convolutional(height, width, inputDepth))
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dataType(DataType.DOUBLE).seed(123).updater(new NoOp()).weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6))).list().layer(new PrimaryCapsules.Builder(primaryCapsDim, primaryCapsChannel).kernelSize(3, 3).stride(2, 2).build()).layer(new CapsuleLayer.Builder(capsule, capsuleDim, routing).build()).layer(new CapsuleStrengthLayer.Builder().build()).layer(new ActivationLayer.Builder(new ActivationSoftmax()).build()).layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build()).setInputType(InputType.convolutional(height, width, inputDepth)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int i = 0; i < 4; i++) {
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
}
String msg = "minibatch=" + minibatchSize +
", PrimaryCaps: " + primaryCapsChannel +
" channels, " + primaryCapsDim + " dimensions, Capsules: " + capsule +
" capsules with " + capsuleDim + " dimensions and " + routing + " routings";
String msg = "minibatch=" + minibatchSize + ", PrimaryCaps: " + primarpCapsChannel + " channels, " + primaryCapsDim + " dimensions, Capsules: " + capsule + " capsules with " + capsuleDim + " dimensions and " + routing + " routings";
System.out.println(msg);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
.labels(labels).subset(true).maxPerParam(100));
assertTrue(msg, gradOK);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels).subset(true).maxPerParam(100));
assertTrue(gradOK, msg);
TestUtils.testModelSerialization(net);
}
}
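Beyond assertions, the migration applies a fixed annotation mapping, visible in this file: JUnit 4's @Ignore becomes JUnit 5's @Disabled, test classes and methods drop the public modifier, and generated @DisplayName annotations carry the human-readable test names. A minimal sketch of the migrated shape (the class and method names are illustrative):

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@Disabled // was @org.junit.Ignore under JUnit 4
@DisplayName("Example Migrated Test")
class ExampleMigratedTest {

    @Test
    @DisplayName("Test Something")
    void testSomething() {
        // test body is unchanged by the migration
    }
}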

View File

@ -17,34 +17,34 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.adapters;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*;
@DisplayName("Argmax Adapter Test")
class ArgmaxAdapterTest extends BaseDL4JTest {
public class ArgmaxAdapterTest extends BaseDL4JTest {
@Test
public void testSoftmax_2D_1() {
val in = new double[][] {{1, 3, 2}, { 4, 5, 6}};
@DisplayName("Test Softmax _ 2 D _ 1")
void testSoftmax_2D_1() {
val in = new double[][] { { 1, 3, 2 }, { 4, 5, 6 } };
val adapter = new ArgmaxAdapter();
val result = adapter.apply(Nd4j.create(in));
assertArrayEquals(new int[]{1, 2}, result);
assertArrayEquals(new int[] { 1, 2 }, result);
}
@Test
public void testSoftmax_1D_1() {
val in = new double[] {1, 3, 2};
@DisplayName("Test Softmax _ 1 D _ 1")
void testSoftmax_1D_1() {
val in = new double[] { 1, 3, 2 };
val adapter = new ArgmaxAdapter();
val result = adapter.apply(Nd4j.create(in));
assertArrayEquals(new int[]{1}, result);
assertArrayEquals(new int[] { 1 }, result);
}
}
}

View File

@ -17,35 +17,37 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.adapters;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.util.ArrayUtil;
import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.Assert.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Regression 2 d Adapter Test")
class Regression2dAdapterTest extends BaseDL4JTest {
public class Regression2dAdapterTest extends BaseDL4JTest {
@Test
public void testRegressionAdapter_2D_1() throws Exception {
val in = new double[][] {{1, 2, 3}, { 4, 5, 6}};
@DisplayName("Test Regression Adapter _ 2 D _ 1")
void testRegressionAdapter_2D_1() throws Exception {
val in = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
val adapter = new Regression2dAdapter();
val result = adapter.apply(Nd4j.create(in));
assertArrayEquals(ArrayUtil.flatten(in), ArrayUtil.flatten(result), 1e-5);
}
@Test
public void testRegressionAdapter_2D_2() throws Exception {
val in = new double[]{1, 2, 3};
@DisplayName("Test Regression Adapter _ 2 D _ 2")
void testRegressionAdapter_2D_2() throws Exception {
val in = new double[] { 1, 2, 3 };
val adapter = new Regression2dAdapter();
val result = adapter.apply(Nd4j.create(in));
assertArrayEquals(in, ArrayUtil.flatten(result), 1e-5);
}
}
}
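In the adapter tests only the static imports move, from org.junit.Assert to org.junit.jupiter.api.Assertions; assertArrayEquals(double[] expected, double[] actual, double delta) keeps the same parameter order in both frameworks, so the call sites are untouched. A minimal sketch (array contents are illustrative):

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
// was: import static org.junit.Assert.assertArrayEquals;

class DeltaAssertSketch {
    void check() {
        double[] expected = { 1, 2, 3 };
        double[] actual = { 1, 2, 3 };
        // same order under JUnit 4 and JUnit 5: expected, actual, tolerance
        assertArrayEquals(expected, actual, 1e-5);
    }
}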

View File

@ -17,10 +17,8 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
@ -43,296 +41,158 @@ import org.deeplearning4j.nn.conf.misc.TestGraphVertex;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
public class ComputationGraphConfigurationTest extends BaseDL4JTest {
@DisplayName("Computation Graph Configuration Test")
class ComputationGraphConfigurationTest extends BaseDL4JTest {
@Test
public void testJSONBasic() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new NormalDistribution(0, 1)).updater(new NoOp())
.graphBuilder().addInputs("input")
.appendLayer("firstLayer",
new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.TANH).build())
.addLayer("outputLayer",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(5).nOut(3).build(),
"firstLayer")
.setOutputs("outputLayer").build();
@DisplayName("Test JSON Basic")
void testJSONBasic() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).dist(new NormalDistribution(0, 1)).updater(new NoOp()).graphBuilder().addInputs("input").appendLayer("firstLayer", new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.TANH).build()).addLayer("outputLayer", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(5).nOut(3).build(), "firstLayer").setOutputs("outputLayer").build();
String json = conf.toJson();
ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
assertEquals(json, conf2.toJson());
assertEquals(conf, conf2);
}
@Test
public void testJSONBasic2() {
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.graphBuilder().addInputs("input")
.addLayer("cnn1",
new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5)
.build(),
"input")
.addLayer("cnn2",
new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5)
.build(),
"input")
.addLayer("max1",
new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2).build(),
"cnn1", "cnn2")
.addLayer("dnn1", new DenseLayer.Builder().nOut(7).build(), "max1")
.addLayer("max2", new SubsamplingLayer.Builder().build(), "max1")
.addLayer("output", new OutputLayer.Builder().nIn(7).nOut(10).activation(Activation.SOFTMAX).build(), "dnn1",
"max2")
.setOutputs("output")
.inputPreProcessor("cnn1", new FeedForwardToCnnPreProcessor(32, 32, 3))
.inputPreProcessor("cnn2", new FeedForwardToCnnPreProcessor(32, 32, 3))
.inputPreProcessor("dnn1", new CnnToFeedForwardPreProcessor(8, 8, 5))
.build();
@DisplayName("Test JSON Basic 2")
void testJSONBasic2() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input").addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input").addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input").addLayer("max1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).build(), "cnn1", "cnn2").addLayer("dnn1", new DenseLayer.Builder().nOut(7).build(), "max1").addLayer("max2", new SubsamplingLayer.Builder().build(), "max1").addLayer("output", new OutputLayer.Builder().nIn(7).nOut(10).activation(Activation.SOFTMAX).build(), "dnn1", "max2").setOutputs("output").inputPreProcessor("cnn1", new FeedForwardToCnnPreProcessor(32, 32, 3)).inputPreProcessor("cnn2", new FeedForwardToCnnPreProcessor(32, 32, 3)).inputPreProcessor("dnn1", new CnnToFeedForwardPreProcessor(8, 8, 5)).build();
String json = conf.toJson();
ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
assertEquals(json, conf2.toJson());
assertEquals(conf, conf2);
}
@Test
public void testJSONWithGraphNodes() {
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.graphBuilder().addInputs("input1", "input2")
.addLayer("cnn1",
new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5)
.build(),
"input1")
.addLayer("cnn2",
new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5)
.build(),
"input2")
.addVertex("merge1", new MergeVertex(), "cnn1", "cnn2")
.addVertex("subset1", new SubsetVertex(0, 1), "merge1")
.addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
.addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
.addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1",
"dense2")
.addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "add")
.setOutputs("out").build();
@DisplayName("Test JSON With Graph Nodes")
void testJSONWithGraphNodes() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2").addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input1").addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input2").addVertex("merge1", new MergeVertex(), "cnn1", "cnn2").addVertex("subset1", new SubsetVertex(0, 1), "merge1").addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1").addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1").addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2").addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "add").setOutputs("out").build();
String json = conf.toJson();
// System.out.println(json);
// System.out.println(json);
ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
assertEquals(json, conf2.toJson());
assertEquals(conf, conf2);
}
@Test
public void testInvalidConfigurations() {
//Test no inputs for a layer:
@DisplayName("Test Invalid Configurations")
void testInvalidConfigurations() {
// Test no inputs for a layer:
try {
new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1")
.addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1")
.addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build()).setOutputs("out")
.build();
new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1").addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1").addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build()).setOutputs("out").build();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
// OK - exception is good
log.info(e.toString());
}
// Use appendLayer on first layer
try {
new NeuralNetConfiguration.Builder().graphBuilder()
.appendLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build())
.addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build()).setOutputs("out")
.build();
new NeuralNetConfiguration.Builder().graphBuilder().appendLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build()).addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build()).setOutputs("out").build();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
// OK - exception is good
log.info(e.toString());
}
//Test no network inputs
// Test no network inputs
try {
new NeuralNetConfiguration.Builder().graphBuilder()
.addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1")
.addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "dense1")
.setOutputs("out").build();
new NeuralNetConfiguration.Builder().graphBuilder().addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1").addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "dense1").setOutputs("out").build();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
// OK - exception is good
log.info(e.toString());
}
//Test no network outputs
// Test no network outputs
try {
new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1")
.addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1")
.addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "dense1").build();
new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1").addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1").addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "dense1").build();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
// OK - exception is good
log.info(e.toString());
}
//Test: invalid input
// Test: invalid input
try {
new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1")
.addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1")
.addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "thisDoesntExist")
.setOutputs("out").build();
new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1").addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1").addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "thisDoesntExist").setOutputs("out").build();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
// OK - exception is good
log.info(e.toString());
}
//Test: graph with cycles
// Test: graph with cycles
try {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1")
.addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1", "dense3")
.addLayer("dense2", new DenseLayer.Builder().nIn(2).nOut(2).build(), "dense1")
.addLayer("dense3", new DenseLayer.Builder().nIn(2).nOut(2).build(), "dense2")
.addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).lossFunction(LossFunctions.LossFunction.MSE).build(), "dense1")
.setOutputs("out").build();
//Cycle detection happens in ComputationGraph.init()
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1").addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1", "dense3").addLayer("dense2", new DenseLayer.Builder().nIn(2).nOut(2).build(), "dense1").addLayer("dense3", new DenseLayer.Builder().nIn(2).nOut(2).build(), "dense2").addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).lossFunction(LossFunctions.LossFunction.MSE).build(), "dense1").setOutputs("out").build();
// Cycle detection happens in ComputationGraph.init()
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
// OK - exception is good
log.info(e.toString());
}
//Test: input != inputType count mismatch
// Test: input != inputType count mismatch
try {
new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2")
.setInputTypes(new InputType.InputTypeRecurrent(10, 12))
.addLayer("cnn1",
new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5)
.build(),
"input1")
.addLayer("cnn2",
new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5)
.build(),
"input2")
.addVertex("merge1", new MergeVertex(), "cnn1", "cnn2")
.addVertex("subset1", new SubsetVertex(0, 1), "merge1")
.addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
.addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
.addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1",
"dense2")
.addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "add")
.setOutputs("out").build();
new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2").setInputTypes(new InputType.InputTypeRecurrent(10, 12)).addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input1").addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input2").addVertex("merge1", new MergeVertex(), "cnn1", "cnn2").addVertex("subset1", new SubsetVertex(0, 1), "merge1").addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1").addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1").addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2").addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "add").setOutputs("out").build();
fail("No exception thrown for invalid configuration");
} catch (IllegalArgumentException e) {
//OK - exception is good
// OK - exception is good
log.info(e.toString());
}
}
@Test
public void testConfigurationWithRuntimeJSONSubtypes() {
//Idea: suppose someone wants to use a ComputationGraph with a custom GraphVertex
@DisplayName("Test Configuration With Runtime JSON Subtypes")
void testConfigurationWithRuntimeJSONSubtypes() {
// Idea: suppose someone wants to use a ComputationGraph with a custom GraphVertex
// (i.e., one not built into DL4J). Check that this works for JSON serialization
// using runtime/reflection subtype mechanism in ComputationGraphConfiguration.fromJson()
//Check a standard GraphVertex implementation, plus a static inner graph vertex
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addVertex("test", new TestGraphVertex(3, 7), "in")
.addVertex("test2", new StaticInnerGraphVertex(4, 5), "in").setOutputs("test", "test2").build();
// Check a standard GraphVertex implementation, plus a static inner graph vertex
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").addVertex("test", new TestGraphVertex(3, 7), "in").addVertex("test2", new StaticInnerGraphVertex(4, 5), "in").setOutputs("test", "test2").build();
String json = conf.toJson();
// System.out.println(json);
// System.out.println(json);
ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
assertEquals(conf, conf2);
assertEquals(json, conf2.toJson());
TestGraphVertex tgv = (TestGraphVertex) conf2.getVertices().get("test");
assertEquals(3, tgv.getFirstVal());
assertEquals(7, tgv.getSecondVal());
StaticInnerGraphVertex sigv = (StaticInnerGraphVertex) conf.getVertices().get("test2");
assertEquals(4, sigv.getFirstVal());
assertEquals(5, sigv.getSecondVal());
}
@Test
public void testOutputOrderDoesntChangeWhenCloning() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("out1", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in")
.addLayer("out2", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in")
.addLayer("out3", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in")
.validateOutputLayerConfig(false)
.setOutputs("out1", "out2", "out3").build();
@DisplayName("Test Output Order Doesnt Change When Cloning")
void testOutputOrderDoesntChangeWhenCloning() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").addLayer("out1", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in").addLayer("out2", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in").addLayer("out3", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in").validateOutputLayerConfig(false).setOutputs("out1", "out2", "out3").build();
ComputationGraphConfiguration cloned = conf.clone();
String json = conf.toJson();
String jsonCloned = cloned.toJson();
assertEquals(json, jsonCloned);
}
@Test
public void testAllowDisconnectedLayers() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("bidirectional",
new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()),
"in")
.addLayer("out", new RnnOutputLayer.Builder().nOut(6)
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.build(), "bidirectional")
.addLayer("disconnected_layer",
new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()),
"in")
.setOutputs("out")
.setInputTypes(new InputType.InputTypeRecurrent(10, 12))
.allowDisconnected(true)
.build();
@DisplayName("Test Allow Disconnected Layers")
void testAllowDisconnectedLayers() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").addLayer("bidirectional", new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()), "in").addLayer("out", new RnnOutputLayer.Builder().nOut(6).lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).build(), "bidirectional").addLayer("disconnected_layer", new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()), "in").setOutputs("out").setInputTypes(new InputType.InputTypeRecurrent(10, 12)).allowDisconnected(true).build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
}
@Test
public void testBidirectionalGraphSummary() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("bidirectional",
new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()),
"in")
.addLayer("out", new RnnOutputLayer.Builder().nOut(6)
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.build(), "bidirectional")
.setOutputs("out")
.setInputTypes(new InputType.InputTypeRecurrent(10, 12))
.build();
@DisplayName("Test Bidirectional Graph Summary")
void testBidirectionalGraphSummary() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").addLayer("bidirectional", new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()), "in").addLayer("out", new RnnOutputLayer.Builder().nOut(6).lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).build(), "bidirectional").setOutputs("out").setInputTypes(new InputType.InputTypeRecurrent(10, 12)).build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
graph.summary();
@ -342,9 +202,11 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
@NoArgsConstructor
@Data
@EqualsAndHashCode(callSuper = false)
public static class StaticInnerGraphVertex extends GraphVertex {
@DisplayName("Static Inner Graph Vertex")
static class StaticInnerGraphVertex extends GraphVertex {
private int firstVal;
private int secondVal;
@Override
@ -368,8 +230,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
}
@Override
public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
throw new UnsupportedOperationException("Not supported");
}
@ -384,9 +245,9 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
}
}
@Test
public void testInvalidOutputLayer(){
@DisplayName("Test Invalid Output Layer")
void testInvalidOutputLayer() {
/*
Test case (invalid configs)
1. nOut=1 + softmax
@ -395,35 +256,24 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
4. xent + relu
5. mcxent + sigmoid
*/
LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[]{
LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.XENT,
LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT};
int[] nOut = new int[]{1, 3, 3, 3, 3};
Activation[] activations = new Activation[]{Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID};
for( int i=0; i<lf.length; i++ ){
for(boolean lossLayer : new boolean[]{false, true}) {
for (boolean validate : new boolean[]{true, false}) {
LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[] { LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT };
int[] nOut = new int[] { 1, 3, 3, 3, 3 };
Activation[] activations = new Activation[] { Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID };
for (int i = 0; i < lf.length; i++) {
for (boolean lossLayer : new boolean[] { false, true }) {
for (boolean validate : new boolean[] { true, false }) {
String s = "nOut=" + nOut[i] + ",lossFn=" + lf[i] + ",lossLayer=" + lossLayer + ",validate=" + validate;
if(nOut[i] == 1 && lossLayer)
continue; //nOuts are not available in loss layer, can't expect it to detect this case
if (nOut[i] == 1 && lossLayer)
// nOuts are not available in loss layer, can't expect it to detect this case
continue;
try {
new NeuralNetConfiguration.Builder()
.graphBuilder()
.addInputs("in")
.layer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
.layer("1",
!lossLayer ? new OutputLayer.Builder().nIn(10).nOut(nOut[i]).activation(activations[i]).lossFunction(lf[i]).build()
: new LossLayer.Builder().activation(activations[i]).lossFunction(lf[i]).build(), "0")
.setOutputs("1")
.validateOutputLayerConfig(validate)
.build();
new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in").layer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").layer("1", !lossLayer ? new OutputLayer.Builder().nIn(10).nOut(nOut[i]).activation(activations[i]).lossFunction(lf[i]).build() : new LossLayer.Builder().activation(activations[i]).lossFunction(lf[i]).build(), "0").setOutputs("1").validateOutputLayerConfig(validate).build();
if (validate) {
fail("Expected exception: " + s);
}
} catch (DL4JInvalidConfigException e) {
if (validate) {
assertTrue(s, e.getMessage().toLowerCase().contains("invalid output"));
assertTrue(e.getMessage().toLowerCase().contains("invalid output"), s);
} else {
fail("Validation should not be enabled");
}
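The invalid-configuration tests keep the JUnit 4-era try/catch-plus-fail() idiom. JUnit 5 also offers assertThrows for the same check; this commit does not adopt it, but a sketch of that alternative looks like the following (the thrown exception here is a stand-in for the failing graph-builder call):

import static org.junit.jupiter.api.Assertions.assertThrows;

class AssertThrowsSketch {
    void check() {
        // equivalent to: try { build(); fail(...); } catch (IllegalStateException e) { /* expected */ }
        IllegalStateException e = assertThrows(IllegalStateException.class,
                () -> { throw new IllegalStateException("invalid configuration"); });
        System.out.println(e.toString()); // mirrors log.info(e.toString()) in the tests
    }
}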

View File

@ -17,102 +17,86 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
public class JsonTest extends BaseDL4JTest {
@DisplayName("Json Test")
class JsonTest extends BaseDL4JTest {
@Test
public void testJsonLossFunctions() {
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),
new LossCosineProximity(), new LossHinge(), new LossKLD(), new LossKLD(), new LossL1(),
new LossL1(), new LossL2(), new LossL2(), new LossMAE(), new LossMAE(), new LossMAPE(),
new LossMAPE(), new LossMCXENT(), new LossMSE(), new LossMSE(), new LossMSLE(), new LossMSLE(),
new LossNegativeLogLikelihood(), new LossNegativeLogLikelihood(), new LossPoisson(),
new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0)};
Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
Activation.SIGMOID, //xent
Activation.TANH, //cosine
Activation.TANH, //hinge -> trying to predict 1 or -1
Activation.SIGMOID, //kld -> probab so should be between 0 and 1
Activation.SOFTMAX, //kld + softmax
Activation.TANH, //l1
Activation.SOFTMAX, //l1 + softmax
Activation.TANH, //l2
Activation.SOFTMAX, //l2 + softmax
Activation.IDENTITY, //mae
Activation.SOFTMAX, //mae + softmax
Activation.IDENTITY, //mape
Activation.SOFTMAX, //mape + softmax
Activation.SOFTMAX, //mcxent
Activation.IDENTITY, //mse
Activation.SOFTMAX, //mse + softmax
Activation.SIGMOID, //msle - requires positive labels/activations due to log
Activation.SOFTMAX, //msle + softmax
Activation.SIGMOID, //nll
Activation.SOFTMAX, //nll + softmax
Activation.SIGMOID, //poisson - requires positive predictions due to log... not sure if this is the best option
Activation.TANH, //squared hinge
Activation.SIGMOID, //f-measure (binary, single sigmoid output)
Activation.SOFTMAX //f-measure (binary, 2-label softmax output)
};
int[] nOut = new int[] {1, //xent
3, //xent
5, //cosine
3, //hinge
3, //kld
3, //kld + softmax
3, //l1
3, //l1 + softmax
3, //l2
3, //l2 + softmax
3, //mae
3, //mae + softmax
3, //mape
3, //mape + softmax
3, //mcxent
3, //mse
3, //mse + softmax
3, //msle
3, //msle + softmax
3, //nll
3, //nll + softmax
3, //poisson
3, //squared hinge
1, //f-measure (binary, single sigmoid output)
2, //f-measure (binary, 2-label softmax output)
};
@DisplayName("Test Json Loss Functions")
void testJsonLossFunctions() {
ILossFunction[] lossFunctions = new ILossFunction[] { new LossBinaryXENT(), new LossBinaryXENT(), new LossCosineProximity(), new LossHinge(), new LossKLD(), new LossKLD(), new LossL1(), new LossL1(), new LossL2(), new LossL2(), new LossMAE(), new LossMAE(), new LossMAPE(), new LossMAPE(), new LossMCXENT(), new LossMSE(), new LossMSE(), new LossMSLE(), new LossMSLE(), new LossNegativeLogLikelihood(), new LossNegativeLogLikelihood(), new LossPoisson(), new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0) };
Activation[] outputActivationFn = new Activation[] {
Activation.SIGMOID, // xent
Activation.SIGMOID, // xent
Activation.TANH, // cosine
Activation.TANH, // hinge -> trying to predict 1 or -1
Activation.SIGMOID, // kld -> probab so should be between 0 and 1
Activation.SOFTMAX, // kld + softmax
Activation.TANH, // l1
Activation.SOFTMAX, // l1 + softmax
Activation.TANH, // l2
Activation.SOFTMAX, // l2 + softmax
Activation.IDENTITY, // mae
Activation.SOFTMAX, // mae + softmax
Activation.IDENTITY, // mape
Activation.SOFTMAX, // mape + softmax
Activation.SOFTMAX, // mcxent
Activation.IDENTITY, // mse
Activation.SOFTMAX, // mse + softmax
Activation.SIGMOID, // msle - requires positive labels/activations due to log
Activation.SOFTMAX, // msle + softmax
Activation.SIGMOID, // nll
Activation.SOFTMAX, // nll + softmax
Activation.SIGMOID, // poisson - requires positive predictions due to log... not sure if this is the best option
Activation.TANH, // squared hinge
Activation.SIGMOID, // f-measure (binary, single sigmoid output)
Activation.SOFTMAX // f-measure (binary, 2-label softmax output)
};
int[] nOut = new int[] {
1, // xent
3, // xent
5, // cosine
3, // hinge
3, // kld
3, // kld + softmax
3, // l1
3, // l1 + softmax
3, // l2
3, // l2 + softmax
3, // mae
3, // mae + softmax
3, // mape
3, // mape + softmax
3, // mcxent
3, // mse
3, // mse + softmax
3, // msle
3, // msle + softmax
3, // nll
3, // nll + softmax
3, // poisson
3, // squared hinge
1, // f-measure (binary, single sigmoid output)
2 // f-measure (binary, 2-label softmax output)
};
for (int i = 0; i < lossFunctions.length; i++) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(Updater.ADAM).list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(nOut[i]).activation(Activation.TANH).build())
.layer(1, new LossLayer.Builder().lossFunction(lossFunctions[i])
.activation(outputActivationFn[i]).build())
.validateOutputLayerConfig(false).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(Updater.ADAM).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(nOut[i]).activation(Activation.TANH).build()).layer(1, new LossLayer.Builder().lossFunction(lossFunctions[i]).activation(outputActivationFn[i]).build()).validateOutputLayerConfig(false).build();
String json = conf.toJson();
String yaml = conf.toYaml();
MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml);
assertEquals(conf, fromJson);
assertEquals(conf, fromYaml);
}
}
}
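JsonTest above shows the boilerplate pattern this commit repeats across the whole codebase: JUnit 5 drops the public-visibility requirement, so classes and test methods become package-private and gain @DisplayName annotations generated from their camel-case names. A before/after sketch of just that shape (FooTest is a placeholder, not a class from this commit):

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

// JUnit 4 required: public class FooTest { @org.junit.Test public void testBar() { ... } }
// JUnit 5 accepts package-private visibility and adds readable display names:
@DisplayName("Foo Test")
class FooTest {

    @Test
    @DisplayName("Test Bar")
    void testBar() {
        // test body is unchanged by the migration
    }
}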

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf;
import lombok.extern.slf4j.Slf4j;
@ -34,41 +33,40 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.*;
import java.util.Arrays;
import java.util.Properties;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
@DisplayName("Multi Layer Neural Net Configuration Test")
class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testJson() throws Exception {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
.layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build())
.inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build();
@DisplayName("Test Json")
void testJson() throws Exception {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()).inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build();
String json = conf.toJson();
MultiLayerConfiguration from = MultiLayerConfiguration.fromJson(json);
assertEquals(conf.getConf(0), from.getConf(0));
Properties props = new Properties();
props.put("json", json);
String key = props.getProperty("json");
assertEquals(json, key);
File f = testDir.newFile("props");
File f = testDir.resolve("props").toFile();
f.deleteOnExit();
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f));
props.store(bos, "");
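The TemporaryFolder-to-@TempDir swap above also changes the field type: JUnit 5 injects a java.nio.file.Path (or File) into the annotated field, and files are created by resolving against it rather than calling newFile. A short self-contained sketch of the new pattern (TempDirSketch is an illustrative name):

import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirSketch {

    @TempDir
    Path testDir; // injected per test by JUnit 5; replaces @Rule TemporaryFolder

    @Test
    void writesIntoTempDir() throws Exception {
        // testDir.newFile("props") from JUnit 4 becomes:
        File f = testDir.resolve("props").toFile();
        f.createNewFile(); // JUnit deletes the whole directory after the test
    }
}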
@ -82,36 +80,18 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
String json2 = props2.getProperty("json");
MultiLayerConfiguration conf3 = MultiLayerConfiguration.fromJson(json2);
assertEquals(conf.getConf(0), conf3.getConf(0));
}
@Test
public void testConvnetJson() {
@DisplayName("Test Convnet Json")
void testConvnetJson() {
final int numRows = 76;
final int numColumns = 76;
int nChannels = 3;
int outputNum = 6;
int seed = 123;
//setup the network
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
.l1(1e-1).l2(2e-4).weightNoise(new DropConnect(0.5)).miniBatch(true)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list()
.layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2})
.build())
.layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build())
.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2})
.build())
.layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutional(numRows, numColumns, nChannels));
// setup the network
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).l1(1e-1).l2(2e-4).weightNoise(new DropConnect(0.5)).miniBatch(true).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list().layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] { 2, 2 }).build()).layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(numRows, numColumns, nChannels));
MultiLayerConfiguration conf = builder.build();
String json = conf.toJson();
MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json);
@ -119,30 +99,15 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
}
@Test
public void testUpsamplingConvnetJson() {
@DisplayName("Test Upsampling Convnet Json")
void testUpsamplingConvnetJson() {
final int numRows = 76;
final int numColumns = 76;
int nChannels = 3;
int outputNum = 6;
int seed = 123;
//setup the network
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
.l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list()
.layer(new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build())
.layer(new Upsampling2D.Builder().size(2).build())
.layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build())
.layer(new Upsampling2D.Builder().size(2).build())
.layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutional(numRows, numColumns, nChannels));
// setup the network
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list().layer(new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(new Upsampling2D.Builder().size(2).build()).layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(new Upsampling2D.Builder().size(2).build()).layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(numRows, numColumns, nChannels));
MultiLayerConfiguration conf = builder.build();
String json = conf.toJson();
MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json);
@ -150,36 +115,26 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
}
@Test
public void testGlobalPoolingJson() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp())
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(5).build())
.layer(1, new GlobalPoolingLayer.Builder().poolingType(PoolingType.PNORM).pnorm(3).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(3).build())
.setInputType(InputType.convolutional(32, 32, 1)).build();
@DisplayName("Test Global Pooling Json")
void testGlobalPoolingJson() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()).dist(new NormalDistribution(0, 1.0)).seed(12345L).list().layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(5).build()).layer(1, new GlobalPoolingLayer.Builder().poolingType(PoolingType.PNORM).pnorm(3).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(3).build()).setInputType(InputType.convolutional(32, 32, 1)).build();
String str = conf.toJson();
MultiLayerConfiguration fromJson = conf.fromJson(str);
assertEquals(conf, fromJson);
}
@Test
public void testYaml() throws Exception {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
.layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build())
.inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build();
@DisplayName("Test Yaml")
void testYaml() throws Exception {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()).inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build();
String json = conf.toYaml();
MultiLayerConfiguration from = MultiLayerConfiguration.fromYaml(json);
assertEquals(conf.getConf(0), from.getConf(0));
Properties props = new Properties();
props.put("json", json);
String key = props.getProperty("json");
assertEquals(json, key);
File f = testDir.newFile("props");
File f = testDir.resolve("props").toFile();
f.deleteOnExit();
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f));
props.store(bos, "");
@ -193,17 +148,13 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
String yaml = props2.getProperty("json");
MultiLayerConfiguration conf3 = MultiLayerConfiguration.fromYaml(yaml);
assertEquals(conf.getConf(0), conf3.getConf(0));
}
@Test
public void testClone() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().build())
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).build())
.inputPreProcessor(1, new CnnToFeedForwardPreProcessor()).build();
@DisplayName("Test Clone")
void testClone() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().build()).layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).build()).inputPreProcessor(1, new CnnToFeedForwardPreProcessor()).build();
MultiLayerConfiguration conf2 = conf.clone();
assertEquals(conf, conf2);
assertNotSame(conf, conf2);
assertNotSame(conf.getConfs(), conf2.getConfs());
@ -217,174 +168,125 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
}
@Test
public void testRandomWeightInit() {
@DisplayName("Test Random Weight Init")
void testRandomWeightInit() {
MultiLayerNetwork model1 = new MultiLayerNetwork(getConf());
model1.init();
Nd4j.getRandom().setSeed(12345L);
MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
model2.init();
float[] p1 = model1.params().data().asFloat();
float[] p2 = model2.params().data().asFloat();
System.out.println(Arrays.toString(p1));
System.out.println(Arrays.toString(p2));
org.junit.Assert.assertArrayEquals(p1, p2, 0.0f);
assertArrayEquals(p1, p2, 0.0f);
}
@Test
public void testTrainingListener() {
@DisplayName("Test Training Listener")
void testTrainingListener() {
MultiLayerNetwork model1 = new MultiLayerNetwork(getConf());
model1.init();
model1.addListeners( new ScoreIterationListener(1));
model1.addListeners(new ScoreIterationListener(1));
MultiLayerNetwork model2 = new MultiLayerNetwork(getConf());
model2.addListeners( new ScoreIterationListener(1));
model2.addListeners(new ScoreIterationListener(1));
model2.init();
Layer[] l1 = model1.getLayers();
for (int i = 0; i < l1.length; i++)
assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1);
for (int i = 0; i < l1.length; i++) assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1);
Layer[] l2 = model2.getLayers();
for (int i = 0; i < l2.length; i++)
assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1);
for (int i = 0; i < l2.length; i++) assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1);
}
private static MultiLayerConfiguration getConf() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345l).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2)
.dist(new NormalDistribution(0, 1)).build())
.layer(1, new OutputLayer.Builder().nIn(2).nOut(1)
.activation(Activation.TANH)
.dist(new NormalDistribution(0, 1)).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).dist(new NormalDistribution(0, 1)).build()).layer(1, new OutputLayer.Builder().nIn(2).nOut(1).activation(Activation.TANH).dist(new NormalDistribution(0, 1)).lossFunction(LossFunctions.LossFunction.MSE).build()).build();
return conf;
}
@Test
public void testInvalidConfig() {
@DisplayName("Test Invalid Config")
void testInvalidConfig() {
try {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
log.error("",e);
// OK
log.error("", e);
} catch (Throwable e) {
log.error("",e);
log.error("", e);
fail("Unexpected exception thrown for invalid config");
}
try {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
.layer(1, new DenseLayer.Builder().nIn(3).nOut(4).build())
.layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(1, new DenseLayer.Builder().nIn(3).nOut(4).build()).layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
// OK
log.info(e.toString());
} catch (Throwable e) {
log.error("",e);
log.error("", e);
fail("Unexpected exception thrown for invalid config");
}
try {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
.layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build())
.layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()).layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
// OK
log.info(e.toString());
} catch (Throwable e) {
log.error("",e);
log.error("", e);
fail("Unexpected exception thrown for invalid config");
}
}
@Test
public void testListOverloads() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
.layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build())
.layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build())
.build();
@DisplayName("Test List Overloads")
void testListOverloads() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DenseLayer dl = (DenseLayer) conf.getConf(0).getLayer();
assertEquals(3, dl.getNIn());
assertEquals(4, dl.getNOut());
OutputLayer ol = (OutputLayer) conf.getConf(1).getLayer();
assertEquals(4, ol.getNIn());
assertEquals(5, ol.getNOut());
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list()
.layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build())
.layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()).build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder().seed(12345)
.list(new DenseLayer.Builder().nIn(3).nOut(4).build(),
new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build())
.build();
MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder().seed(12345).list(new DenseLayer.Builder().nIn(3).nOut(4).build(), new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()).build();
MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
net3.init();
assertEquals(conf, conf2);
assertEquals(conf, conf3);
}
@Test
public void testBiasLr() {
//setup the network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(1e-2))
.biasUpdater(new Adam(0.5)).list()
.layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build())
.layer(1, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build())
.layer(2, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build())
.setInputType(InputType.convolutional(28, 28, 1)).build();
@DisplayName("Test Bias Lr")
void testBiasLr() {
// setup the network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(1e-2)).biasUpdater(new Adam(0.5)).list().layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()).layer(1, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(2, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutional(28, 28, 1)).build();
org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer();
org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer();
org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer();
org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer();
assertEquals(0.5, ((Adam)l0.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam)l0.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam)l1.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam)l1.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam)l2.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam)l2.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam)l3.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam)l3.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam) l1.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam) l1.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam) l2.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam) l2.getUpdaterByParam("W")).getLearningRate(), 1e-6);
assertEquals(0.5, ((Adam) l3.getUpdaterByParam("b")).getLearningRate(), 1e-6);
assertEquals(1e-2, ((Adam) l3.getUpdaterByParam("W")).getLearningRate(), 1e-6);
}
@Test
public void testInvalidOutputLayer(){
@DisplayName("Test Invalid Output Layer")
void testInvalidOutputLayer() {
/*
Test case (invalid configs)
1. nOut=1 + softmax
@ -393,32 +295,24 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
4. xent + relu
5. mcxent + sigmoid
*/
LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[]{
LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.XENT,
LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT};
int[] nOut = new int[]{1, 3, 3, 3, 3};
Activation[] activations = new Activation[]{Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID};
for( int i=0; i<lf.length; i++ ){
for(boolean lossLayer : new boolean[]{false, true}) {
for (boolean validate : new boolean[]{true, false}) {
LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[] { LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT };
int[] nOut = new int[] { 1, 3, 3, 3, 3 };
Activation[] activations = new Activation[] { Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID };
for (int i = 0; i < lf.length; i++) {
for (boolean lossLayer : new boolean[] { false, true }) {
for (boolean validate : new boolean[] { true, false }) {
String s = "nOut=" + nOut[i] + ",lossFn=" + lf[i] + ",lossLayer=" + lossLayer + ",validate=" + validate;
if(nOut[i] == 1 && lossLayer)
continue; //nOuts are not available in loss layer, can't expect it to detect this case
if (nOut[i] == 1 && lossLayer)
// nOuts are not available in loss layer, can't expect it to detect this case
continue;
try {
new NeuralNetConfiguration.Builder()
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(!lossLayer ? new OutputLayer.Builder().nIn(10).nOut(nOut[i]).activation(activations[i]).lossFunction(lf[i]).build()
: new LossLayer.Builder().activation(activations[i]).lossFunction(lf[i]).build())
.validateOutputLayerConfig(validate)
.build();
new NeuralNetConfiguration.Builder().list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(!lossLayer ? new OutputLayer.Builder().nIn(10).nOut(nOut[i]).activation(activations[i]).lossFunction(lf[i]).build() : new LossLayer.Builder().activation(activations[i]).lossFunction(lf[i]).build()).validateOutputLayerConfig(validate).build();
if (validate) {
fail("Expected exception: " + s);
}
} catch (DL4JInvalidConfigException e) {
if (validate) {
assertTrue(s, e.getMessage().toLowerCase().contains("invalid output"));
assertTrue(e.getMessage().toLowerCase().contains("invalid output"), s);
} else {
fail("Validation should not be enabled");
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf;
import org.deeplearning4j.BaseDL4JTest;
@ -29,61 +28,70 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer.PoolingType;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
* @author Jeffrey Tang.
*/
public class MultiNeuralNetConfLayerBuilderTest extends BaseDL4JTest {
@DisplayName("Multi Neural Net Conf Layer Builder Test")
class MultiNeuralNetConfLayerBuilderTest extends BaseDL4JTest {
int numIn = 10;
int numOut = 5;
double drop = 0.3;
Activation act = Activation.SOFTMAX;
PoolingType poolType = PoolingType.MAX;
int[] filterSize = new int[] {2, 2};
int[] filterSize = new int[] { 2, 2 };
int filterDepth = 6;
int[] stride = new int[] {2, 2};
int[] stride = new int[] { 2, 2 };
int k = 1;
Convolution.Type convType = Convolution.Type.FULL;
LossFunction loss = LossFunction.MCXENT;
WeightInit weight = WeightInit.XAVIER;
double corrupt = 0.4;
double sparsity = 0.3;
@Test
public void testNeuralNetConfigAPI() {
@DisplayName("Test Neural Net Config API")
void testNeuralNetConfigAPI() {
LossFunction newLoss = LossFunction.SQUARED_LOSS;
int newNumIn = numIn + 1;
int newNumOut = numOut + 1;
WeightInit newWeight = WeightInit.UNIFORM;
double newDrop = 0.5;
int[] newFS = new int[] {3, 3};
int[] newFS = new int[] { 3, 3 };
int newFD = 7;
int[] newStride = new int[] {3, 3};
int[] newStride = new int[] { 3, 3 };
Convolution.Type newConvType = Convolution.Type.SAME;
PoolingType newPoolType = PoolingType.AVG;
double newCorrupt = 0.5;
double newSparsity = 0.5;
MultiLayerConfiguration multiConf1 =
new NeuralNetConfiguration.Builder().list()
.layer(0, new DenseLayer.Builder().nIn(newNumIn).nOut(newNumOut).activation(act)
.build())
.layer(1, new DenseLayer.Builder().nIn(newNumIn + 1).nOut(newNumOut + 1)
.activation(act).build())
.build();
MultiLayerConfiguration multiConf1 = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(newNumIn).nOut(newNumOut).activation(act).build()).layer(1, new DenseLayer.Builder().nIn(newNumIn + 1).nOut(newNumOut + 1).activation(act).build()).build();
NeuralNetConfiguration firstLayer = multiConf1.getConf(0);
NeuralNetConfiguration secondLayer = multiConf1.getConf(1);
assertFalse(firstLayer.equals(secondLayer));
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf;
import org.deeplearning4j.BaseDL4JTest;
@ -37,7 +36,7 @@ import org.deeplearning4j.nn.weights.*;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.solvers.StochasticGradientDescent;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
@ -46,65 +45,61 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertTrue;
public class NeuralNetConfigurationTest extends BaseDL4JTest {
@DisplayName("Neural Net Configuration Test")
class NeuralNetConfigurationTest extends BaseDL4JTest {
final DataSet trainingSet = createData();
public DataSet createData() {
int numFeatures = 40;
INDArray input = Nd4j.create(2, numFeatures); // have to be at least two or else output layer gradient is a scalar and cause exception
// have to be at least two or else output layer gradient is a scalar and cause exception
INDArray input = Nd4j.create(2, numFeatures);
INDArray labels = Nd4j.create(2, 2);
INDArray row0 = Nd4j.create(1, numFeatures);
row0.assign(0.1);
input.putRow(0, row0);
labels.put(0, 1, 1); // set the 4th column
// set the 4th column
labels.put(0, 1, 1);
INDArray row1 = Nd4j.create(1, numFeatures);
row1.assign(0.2);
input.putRow(1, row1);
labels.put(1, 0, 1); // set the 2nd column
// set the 2nd column
labels.put(1, 0, 1);
return new DataSet(input, labels);
}
@Test
public void testJson() {
@DisplayName("Test Json")
void testJson() {
NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitXavier(), true);
String json = conf.toJson();
NeuralNetConfiguration read = NeuralNetConfiguration.fromJson(json);
assertEquals(conf, read);
}
@Test
public void testYaml() {
@DisplayName("Test Yaml")
void testYaml() {
NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitXavier(), true);
String json = conf.toYaml();
NeuralNetConfiguration read = NeuralNetConfiguration.fromYaml(json);
assertEquals(conf, read);
}
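testJson and testYaml both apply the round-trip check this file repeats for every configuration type: serialize, parse back, and compare via the configuration's equals. Condensed into one helper (a sketch of the pattern, not code from this commit):

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

class RoundTripSketch {

    // Serialize, parse, and compare, as testJson and testYaml do above.
    static void assertRoundTrips(NeuralNetConfiguration conf) {
        assertEquals(conf, NeuralNetConfiguration.fromJson(conf.toJson()));
        assertEquals(conf, NeuralNetConfiguration.fromYaml(conf.toYaml()));
    }
}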
@Test
public void testClone() {
@DisplayName("Test Clone")
void testClone() {
NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true);
BaseLayer bl = (BaseLayer) conf.getLayer();
conf.setStepFunction(new DefaultStepFunction());
NeuralNetConfiguration conf2 = conf.clone();
assertEquals(conf, conf2);
assertNotSame(conf, conf2);
assertNotSame(conf.getLayer(), conf2.getLayer());
@ -112,97 +107,74 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
}
@Test
public void testRNG() {
DenseLayer layer = new DenseLayer.Builder().nIn(trainingSet.numInputs()).nOut(trainingSet.numOutcomes())
.weightInit(WeightInit.UNIFORM).activation(Activation.TANH).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer).build();
@DisplayName("Test RNG")
void testRNG() {
DenseLayer layer = new DenseLayer.Builder().nIn(trainingSet.numInputs()).nOut(trainingSet.numOutcomes()).weightInit(WeightInit.UNIFORM).activation(Activation.TANH).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer).build();
long numParams = conf.getLayer().initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams);
Layer model = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType());
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
DenseLayer layer2 = new DenseLayer.Builder().nIn(trainingSet.numInputs()).nOut(trainingSet.numOutcomes())
.weightInit(WeightInit.UNIFORM).activation(Activation.TANH).build();
NeuralNetConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(123)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer2).build();
DenseLayer layer2 = new DenseLayer.Builder().nIn(trainingSet.numInputs()).nOut(trainingSet.numOutcomes()).weightInit(WeightInit.UNIFORM).activation(Activation.TANH).build();
NeuralNetConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer2).build();
long numParams2 = conf2.getLayer().initializer().numParams(conf);
INDArray params2 = Nd4j.create(1, numParams);
Layer model2 = conf2.getLayer().instantiate(conf2, null, 0, params2, true, params.dataType());
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2);
}
@Test
public void testSetSeedSize() {
@DisplayName("Test Set Seed Size")
void testSetSeedSize() {
Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2);
}
@Test
public void testSetSeedNormalized() {
Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2);
}
@Test
public void testSetSeedXavier() {
@DisplayName("Test Set Seed Normalized")
void testSetSeedNormalized() {
Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitXavier(), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2);
}
@Test
@DisplayName("Test Set Seed Xavier")
void testSetSeedXavier() {
Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitUniform(), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitUniform(), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2);
}
@Test
public void testSetSeedDistribution() {
@DisplayName("Test Set Seed Distribution")
void testSetSeedDistribution() {
Nd4j.getRandom().setSeed(123);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(),
new WeightInitDistribution(new NormalDistribution(1, 1)), true);
Layer model = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitDistribution(new NormalDistribution(1, 1)), true);
INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
Nd4j.getRandom().setSeed(123);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(),
new WeightInitDistribution(new NormalDistribution(1, 1)), true);
Layer model2 = getLayer(trainingSet.numInputs(), trainingSet.numOutcomes(), new WeightInitDistribution(new NormalDistribution(1, 1)), true);
INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY);
assertEquals(modelWeights, modelWeights2);
}
private static NeuralNetConfiguration getConfig(int nIn, int nOut, IWeightInit weightInit, boolean pretrain) {
DenseLayer layer = new DenseLayer.Builder().nIn(nIn).nOut(nOut).weightInit(weightInit)
.activation(Activation.TANH).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer)
.build();
DenseLayer layer = new DenseLayer.Builder().nIn(nIn).nOut(nOut).weightInit(weightInit).activation(Activation.TANH).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer).build();
return conf;
}
@ -213,94 +185,75 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType());
}
@Test
public void testLearningRateByParam() {
@DisplayName("Test Learning Rate By Param")
void testLearningRateByParam() {
double lr = 0.01;
double biasLr = 0.02;
int[] nIns = {4, 3, 3};
int[] nOuts = {3, 3, 3};
int[] nIns = { 4, 3, 3 };
int[] nOuts = { 3, 3, 3 };
int oldScore = 1;
int newScore = 1;
int iteration = 3;
INDArray gradientW = Nd4j.ones(nIns[0], nOuts[0]);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list()
.layer(0, new DenseLayer.Builder().nIn(nIns[0]).nOut(nOuts[0])
.updater(new Sgd(lr)).biasUpdater(new Sgd(biasLr)).build())
.layer(1, new BatchNormalization.Builder().nIn(nIns[1]).nOut(nOuts[1]).updater(new Sgd(0.7)).build())
.layer(2, new OutputLayer.Builder().nIn(nIns[2]).nOut(nOuts[2]).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list().layer(0, new DenseLayer.Builder().nIn(nIns[0]).nOut(nOuts[0]).updater(new Sgd(lr)).biasUpdater(new Sgd(biasLr)).build()).layer(1, new BatchNormalization.Builder().nIn(nIns[1]).nOut(nOuts[1]).updater(new Sgd(0.7)).build()).layer(2, new OutputLayer.Builder().nIn(nIns[2]).nOut(nOuts[2]).lossFunction(LossFunctions.LossFunction.MSE).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
ConvexOptimizer opt = new StochasticGradientDescent(net.getDefaultConfiguration(),
new NegativeDefaultStepFunction(), null, net);
assertEquals(lr, ((Sgd)net.getLayer(0).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4);
assertEquals(biasLr, ((Sgd)net.getLayer(0).conf().getLayer().getUpdaterByParam("b")).getLearningRate(), 1e-4);
assertEquals(0.7, ((Sgd)net.getLayer(1).conf().getLayer().getUpdaterByParam("gamma")).getLearningRate(), 1e-4);
assertEquals(0.3, ((Sgd)net.getLayer(2).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4); //From global LR
assertEquals(0.3, ((Sgd)net.getLayer(2).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4); //From global LR
ConvexOptimizer opt = new StochasticGradientDescent(net.getDefaultConfiguration(), new NegativeDefaultStepFunction(), null, net);
assertEquals(lr, ((Sgd) net.getLayer(0).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4);
assertEquals(biasLr, ((Sgd) net.getLayer(0).conf().getLayer().getUpdaterByParam("b")).getLearningRate(), 1e-4);
assertEquals(0.7, ((Sgd) net.getLayer(1).conf().getLayer().getUpdaterByParam("gamma")).getLearningRate(), 1e-4);
// From global LR
assertEquals(0.3, ((Sgd) net.getLayer(2).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4);
// From global LR
assertEquals(0.3, ((Sgd) net.getLayer(2).conf().getLayer().getUpdaterByParam("W")).getLearningRate(), 1e-4);
}
@Test
public void testLeakyreluAlpha() {
//FIXME: Make more generic to use neuralnetconfs
@DisplayName("Test Leakyrelu Alpha")
void testLeakyreluAlpha() {
// FIXME: Make more generic to use neuralnetconfs
int sizeX = 4;
int scaleX = 10;
System.out.println("Here is a leaky vector..");
INDArray leakyVector = Nd4j.linspace(-1, 1, sizeX, Nd4j.dataType());
leakyVector = leakyVector.mul(scaleX);
System.out.println(leakyVector);
double myAlpha = 0.5;
System.out.println("======================");
System.out.println("Exec and Return: Leaky Relu transformation with alpha = 0.5 ..");
System.out.println("======================");
INDArray outDef = Nd4j.getExecutioner().exec(new LeakyReLU(leakyVector.dup(), myAlpha));
System.out.println(outDef);
String confActivation = "leakyrelu";
Object[] confExtra = {myAlpha};
Object[] confExtra = { myAlpha };
INDArray outMine = Nd4j.getExecutioner().exec(new LeakyReLU(leakyVector.dup(), myAlpha));
System.out.println("======================");
System.out.println("Exec and Return: Leaky Relu transformation with a value via getOpFactory");
System.out.println("======================");
System.out.println(outMine);
//Test equality for ndarray elementwise
//assertArrayEquals(..)
// Test equality for ndarray elementwise
// assertArrayEquals(..)
}
@Test
public void testL1L2ByParam() {
@DisplayName("Test L 1 L 2 By Param")
void testL1L2ByParam() {
double l1 = 0.01;
double l2 = 0.07;
int[] nIns = {4, 3, 3};
int[] nOuts = {3, 3, 3};
int[] nIns = { 4, 3, 3 };
int[] nOuts = { 3, 3, 3 };
int oldScore = 1;
int newScore = 1;
int iteration = 3;
INDArray gradientW = Nd4j.ones(nIns[0], nOuts[0]);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(l1)
.l2(l2).list()
.layer(0, new DenseLayer.Builder().nIn(nIns[0]).nOut(nOuts[0]).build())
.layer(1, new BatchNormalization.Builder().nIn(nIns[1]).nOut(nOuts[1]).l2(0.5).build())
.layer(2, new OutputLayer.Builder().nIn(nIns[2]).nOut(nOuts[2]).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(l1).l2(l2).list().layer(0, new DenseLayer.Builder().nIn(nIns[0]).nOut(nOuts[0]).build()).layer(1, new BatchNormalization.Builder().nIn(nIns[1]).nOut(nOuts[1]).l2(0.5).build()).layer(2, new OutputLayer.Builder().nIn(nIns[2]).nOut(nOuts[2]).lossFunction(LossFunctions.LossFunction.MSE).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
ConvexOptimizer opt = new StochasticGradientDescent(net.getDefaultConfiguration(),
new NegativeDefaultStepFunction(), null, net);
ConvexOptimizer opt = new StochasticGradientDescent(net.getDefaultConfiguration(), new NegativeDefaultStepFunction(), null, net);
assertEquals(l1, TestUtils.getL1(net.getLayer(0).conf().getLayer().getRegularizationByParam("W")), 1e-4);
List<Regularization> r = net.getLayer(0).conf().getLayer().getRegularizationByParam("b");
assertEquals(0, r.size());
r = net.getLayer(1).conf().getLayer().getRegularizationByParam("beta");
assertTrue(r == null || r.isEmpty());
r = net.getLayer(1).conf().getLayer().getRegularizationByParam("gamma");
@ -315,14 +268,10 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest {
}
@Test
public void testLayerPretrainConfig() {
@DisplayName("Test Layer Pretrain Config")
void testLayerPretrainConfig() {
boolean pretrain = true;
VariationalAutoencoder layer = new VariationalAutoencoder.Builder()
.nIn(10).nOut(5).updater(new Sgd(1e-1))
.lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build();
VariationalAutoencoder layer = new VariationalAutoencoder.Builder().nIn(10).nOut(5).updater(new Sgd(1e-1)).lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(42).layer(layer).build();
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.graph;
import org.deeplearning4j.BaseDL4JTest;
@ -30,8 +29,8 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Assert;
import org.junit.Test;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationTanH;
@ -43,194 +42,99 @@ import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.primitives.Pair;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertArrayEquals;
@DisplayName("Element Wise Vertex Test")
class ElementWiseVertexTest extends BaseDL4JTest {
public class ElementWiseVertexTest extends BaseDL4JTest {
@Test
public void testElementWiseVertexNumParams() {
@DisplayName("Test Element Wise Vertex Num Params")
void testElementWiseVertexNumParams() {
/*
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
* from @agibsonccc: check for the basics: like 0 numParams
*/
ElementWiseVertex.Op ops[] = new ElementWiseVertex.Op[] {ElementWiseVertex.Op.Add,
ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Product};
ElementWiseVertex.Op[] ops = new ElementWiseVertex.Op[] { ElementWiseVertex.Op.Add, ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Product };
for (ElementWiseVertex.Op op : ops) {
ElementWiseVertex ewv = new ElementWiseVertex(op);
Assert.assertEquals(0, ewv.numParams(true));
Assert.assertEquals(0, ewv.numParams(false));
Assertions.assertEquals(0, ewv.numParams(true));
Assertions.assertEquals(0, ewv.numParams(false));
}
}
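ElementWiseVertexTest shows the other mechanical rename in this commit: fully-qualified calls move from org.junit.Assert to org.junit.jupiter.api.Assertions, and for the message-less overloads the argument order is unchanged. A sketch (illustrative names):

import org.junit.jupiter.api.Assertions;

class AssertionsRenameSketch {
    void check(int numParams) {
        // JUnit 4: org.junit.Assert.assertEquals(0, numParams);
        // JUnit 5: same arguments, different class:
        Assertions.assertEquals(0, numParams);
    }
}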
@Test
public void testElementWiseVertexForwardAdd() {
@DisplayName("Test Element Wise Vertex Forward Add")
void testElementWiseVertexForwardAdd() {
int batchsz = 24;
int featuresz = 17;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder()
.addInputs("input1", "input2", "input3")
.addLayer("denselayer",
new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY)
.build(),
"input1")
/* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get
* Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
* at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341)
*/
.addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "input1",
"input2", "input3")
.addLayer("Add", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(),
"elementwiseAdd")
.setOutputs("Add", "denselayer").build();
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2", "input3").addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY).build(), "input1").addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "input1", "input2", "input3").addLayer("Add", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "elementwiseAdd").setOutputs("Add", "denselayer").build();
ComputationGraph cg = new ComputationGraph(cgc);
cg.init();
INDArray input1 = Nd4j.rand(batchsz, featuresz);
INDArray input2 = Nd4j.rand(batchsz, featuresz);
INDArray input3 = Nd4j.rand(batchsz, featuresz);
INDArray target = input1.dup().addi(input2).addi(input3);
INDArray output = cg.output(input1, input2, input3)[0];
INDArray squared = output.sub(target.castTo(output.dataType()));
double rms = squared.mul(squared).sumNumber().doubleValue();
Assert.assertEquals(0.0, rms, this.epsilon);
Assertions.assertEquals(0.0, rms, this.epsilon);
}
@Test
public void testElementWiseVertexForwardProduct() {
@DisplayName("Test Element Wise Vertex Forward Product")
void testElementWiseVertexForwardProduct() {
int batchsz = 24;
int featuresz = 17;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder()
.addInputs("input1", "input2", "input3")
.addLayer("denselayer",
new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY)
.build(),
"input1")
/* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get
* Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
* at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341)
*/
.addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "input1",
"input2", "input3")
.addLayer("Product", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(),
"elementwiseProduct")
.setOutputs("Product", "denselayer").build();
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2", "input3").addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY).build(), "input1").addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "input1", "input2", "input3").addLayer("Product", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "elementwiseProduct").setOutputs("Product", "denselayer").build();
ComputationGraph cg = new ComputationGraph(cgc);
cg.init();
INDArray input1 = Nd4j.rand(batchsz, featuresz);
INDArray input2 = Nd4j.rand(batchsz, featuresz);
INDArray input3 = Nd4j.rand(batchsz, featuresz);
INDArray target = input1.dup().muli(input2).muli(input3);
INDArray output = cg.output(input1, input2, input3)[0];
INDArray squared = output.sub(target.castTo(output.dataType()));
double rms = squared.mul(squared).sumNumber().doubleValue();
Assert.assertEquals(0.0, rms, this.epsilon);
Assertions.assertEquals(0.0, rms, this.epsilon);
}
@Test
public void testElementWiseVertexForwardSubtract() {
@DisplayName("Test Element Wise Vertex Forward Subtract")
void testElementWiseVertexForwardSubtract() {
int batchsz = 24;
int featuresz = 17;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder()
.addInputs("input1", "input2")
.addLayer("denselayer",
new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY)
.build(),
"input1")
/* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get
* Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
* at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341)
*/
.addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract),
"input1", "input2")
.addLayer("Subtract", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(),
"elementwiseSubtract")
.setOutputs("Subtract", "denselayer").build();
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2").addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY).build(), "input1").addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), "input1", "input2").addLayer("Subtract", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "elementwiseSubtract").setOutputs("Subtract", "denselayer").build();
ComputationGraph cg = new ComputationGraph(cgc);
cg.init();
INDArray input1 = Nd4j.rand(batchsz, featuresz);
INDArray input2 = Nd4j.rand(batchsz, featuresz);
INDArray target = input1.dup().subi(input2);
INDArray output = cg.output(input1, input2)[0];
INDArray squared = output.sub(target);
double rms = Math.sqrt(squared.mul(squared).sumNumber().doubleValue());
Assert.assertEquals(0.0, rms, this.epsilon);
Assertions.assertEquals(0.0, rms, this.epsilon);
}
@Test
public void testElementWiseVertexFullAdd() {
@DisplayName("Test Element Wise Vertex Full Add")
void testElementWiseVertexFullAdd() {
int batchsz = 24;
int featuresz = 17;
int midsz = 13;
int outputsz = 11;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
.dataType(DataType.DOUBLE)
.biasInit(0.0).updater(new Sgd())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
.addInputs("input1", "input2", "input3")
.addLayer("dense1",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input1")
.addLayer("dense2",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input2")
.addLayer("dense3",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input3")
.addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1",
"dense2", "dense3")
.addLayer("output",
new OutputLayer.Builder().nIn(midsz).nOut(outputsz)
.activation(new ActivationSigmoid())
.lossFunction(LossFunction.MSE).build(),
"elementwiseAdd")
.setOutputs("output").build();
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).biasInit(0.0).updater(new Sgd()).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2", "input3").addLayer("dense1", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input1").addLayer("dense2", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input2").addLayer("dense3", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input3").addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2", "dense3").addLayer("output", new OutputLayer.Builder().nIn(midsz).nOut(outputsz).activation(new ActivationSigmoid()).lossFunction(LossFunction.MSE).build(), "elementwiseAdd").setOutputs("output").build();
ComputationGraph cg = new ComputationGraph(cgc);
cg.init();
INDArray input1 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1));
INDArray input2 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1));
INDArray input3 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1));
INDArray target = nullsafe(Nd4j.rand(new int[] {batchsz, outputsz}, new UniformDistribution(0, 1)));
INDArray input1 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input2 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input3 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray target = nullsafe(Nd4j.rand(new int[] { batchsz, outputsz }, new UniformDistribution(0, 1)));
cg.setInputs(input1, input2, input3);
cg.setLabels(target);
cg.computeGradientAndScore();
// Let's figure out what our params are now.
Map<String, INDArray> params = cg.paramTable();
INDArray dense1_W = nullsafe(params.get("dense1_W"));
@ -241,35 +145,22 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dense3_b = nullsafe(params.get("dense3_b"));
INDArray output_W = nullsafe(params.get("output_W"));
INDArray output_b = nullsafe(params.get("output_b"));
// Now, let's calculate what we expect the output to be.
INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1));
INDArray m = (Transforms.tanh(mh));
INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1));
INDArray n = (Transforms.tanh(nh));
INDArray oh = input3.mmul(dense3_W).addi(dense3_b.repmat(batchsz, 1));
INDArray o = (Transforms.tanh(oh));
INDArray middle = Nd4j.zeros(batchsz, midsz);
middle.addi(m).addi(n).addi(o);
INDArray expect = Nd4j.zeros(batchsz, outputsz);
expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1))));
INDArray output = nullsafe(cg.output(input1, input2, input3)[0]);
Assert.assertEquals(0.0, mse(output, expect), this.epsilon);
Assertions.assertEquals(0.0, mse(output, expect), this.epsilon);
Pair<Gradient, Double> pgd = cg.gradientAndScore();
double score = pgd.getSecond();
Assert.assertEquals(score, mse(output, target), this.epsilon);
Assertions.assertEquals(score, mse(output, target), this.epsilon);
Map<String, INDArray> gradients = pgd.getFirst().gradientForVariable();
/*
* So. Let's say we have inputs a, b, c
@ -305,27 +196,23 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
* dmh/db1 = Nd4j.ones(1, batchsz)
*
*/
INDArray y = output;
INDArray s = middle;
INDArray W4 = output_W;
INDArray dEdy = Nd4j.zeros(target.shape());
dEdy.addi(y).subi(target).muli(2); // This should be of size batchsz x outputsz
dEdy.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output.
INDArray dydyh = y.mul(y.mul(-1).add(1)); // This is of size batchsz x outputsz
// This should be of size batchsz x outputsz
dEdy.addi(y).subi(target).muli(2);
// Why? Because the LossFunction divides by the _element size_ of the output.
dEdy.divi(target.shape()[1]);
// This is of size batchsz x outputsz
INDArray dydyh = y.mul(y.mul(-1).add(1));
INDArray dEdyh = dydyh.mul(dEdy);
INDArray dyhdW4 = s.transpose();
INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh));
INDArray dyhdb4 = Nd4j.ones(1, batchsz);
INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh));
INDArray dyhds = W4.transpose();
INDArray dEds = dEdyh.mmul(dyhds);
INDArray dsdm = Nd4j.ones(batchsz, midsz);
INDArray dEdm = dsdm.mul(dEds);
INDArray dmdmh = (m.mul(m)).mul(-1).add(1);
@ -334,7 +221,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh));
INDArray dmhdb1 = Nd4j.ones(1, batchsz);
INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh));
INDArray dsdn = Nd4j.ones(batchsz, midsz);
INDArray dEdn = dsdn.mul(dEds);
INDArray dndnh = (n.mul(n)).mul(-1).add(1);
@ -343,7 +229,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh));
INDArray dnhdb2 = Nd4j.ones(1, batchsz);
INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh));
INDArray dsdo = Nd4j.ones(batchsz, midsz);
INDArray dEdo = dsdo.mul(dEds);
INDArray dodoh = (o.mul(o)).mul(-1).add(1);
@ -352,61 +237,33 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW3 = nullsafe(dohdW3.mmul(dEdoh));
INDArray dohdb3 = Nd4j.ones(1, batchsz);
INDArray dEdb3 = nullsafe(dohdb3.mmul(dEdoh));
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon);
}
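/*
 * A quick standalone check of the dEdy scaling used above -- a sketch with hypothetical
 * helper names and values, plain doubles only. For E = (1/outputsz) * sum_j (y_j - t_j)^2,
 * the partial derivative is dE/dy_j = 2 * (y_j - t_j) / outputsz, which is exactly the
 * addi(y).subi(target).muli(2) followed by divi(target.shape()[1]) in the test.
 */
static void mseGradientSanityCheck() {
    double[] y = { 0.2, 0.7 };
    double[] t = { 0.0, 1.0 };
    int outputsz = y.length;
    double h = 1e-7;
    for (int j = 0; j < outputsz; j++) {
        double analytic = 2 * (y[j] - t[j]) / outputsz;
        double[] yPlus = y.clone();
        yPlus[j] += h;
        // forward finite difference; should match analytic to ~1e-7
        double numeric = (mseScalar(yPlus, t) - mseScalar(y, t)) / h;
    }
}

static double mseScalar(double[] y, double[] t) {
    double s = 0;
    for (int j = 0; j < y.length; j++) {
        s += (y[j] - t[j]) * (y[j] - t[j]);
    }
    return s / y.length;
}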
@Test
public void testElementWiseVertexFullProduct() {
@DisplayName("Test Element Wise Vertex Full Product")
void testElementWiseVertexFullProduct() {
int batchsz = 24;
int featuresz = 17;
int midsz = 13;
int outputsz = 11;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
.dataType(DataType.DOUBLE)
.biasInit(0.0).updater(new Sgd())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
.addInputs("input1", "input2", "input3")
.addLayer("dense1",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input1")
.addLayer("dense2",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input2")
.addLayer("dense3",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input3")
.addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1",
"dense2", "dense3")
.addLayer("output",
new OutputLayer.Builder().nIn(midsz).nOut(outputsz)
.activation(new ActivationSigmoid())
.lossFunction(LossFunction.MSE).build(),
"elementwiseProduct")
.setOutputs("output").build();
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).biasInit(0.0).updater(new Sgd()).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2", "input3").addLayer("dense1", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input1").addLayer("dense2", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input2").addLayer("dense3", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input3").addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1", "dense2", "dense3").addLayer("output", new OutputLayer.Builder().nIn(midsz).nOut(outputsz).activation(new ActivationSigmoid()).lossFunction(LossFunction.MSE).build(), "elementwiseProduct").setOutputs("output").build();
ComputationGraph cg = new ComputationGraph(cgc);
cg.init();
INDArray input1 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1));
INDArray input2 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1));
INDArray input3 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1));
INDArray target = nullsafe(Nd4j.rand(new int[] {batchsz, outputsz}, new UniformDistribution(0, 1)));
INDArray input1 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input2 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input3 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray target = nullsafe(Nd4j.rand(new int[] { batchsz, outputsz }, new UniformDistribution(0, 1)));
cg.setInputs(input1, input2, input3);
cg.setLabels(target);
cg.computeGradientAndScore();
// Let's figure out what our params are now.
Map<String, INDArray> params = cg.paramTable();
INDArray dense1_W = nullsafe(params.get("dense1_W"));
@ -417,35 +274,22 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dense3_b = nullsafe(params.get("dense3_b"));
INDArray output_W = nullsafe(params.get("output_W"));
INDArray output_b = nullsafe(params.get("output_b"));
// Now, let's calculate what we expect the output to be.
INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1));
INDArray m = (Transforms.tanh(mh));
INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1));
INDArray n = (Transforms.tanh(nh));
INDArray oh = input3.mmul(dense3_W).addi(dense3_b.repmat(batchsz, 1));
INDArray o = (Transforms.tanh(oh));
INDArray middle = Nd4j.ones(batchsz, midsz);
middle.muli(m).muli(n).muli(o);
INDArray expect = Nd4j.zeros(batchsz, outputsz);
expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1))));
INDArray output = nullsafe(cg.output(input1, input2, input3)[0]);
Assert.assertEquals(0.0, mse(output, expect), this.epsilon);
Assertions.assertEquals(0.0, mse(output, expect), this.epsilon);
Pair<Gradient, Double> pgd = cg.gradientAndScore();
double score = pgd.getSecond();
Assert.assertEquals(score, mse(output, target), this.epsilon);
Assertions.assertEquals(score, mse(output, target), this.epsilon);
Map<String, INDArray> gradients = pgd.getFirst().gradientForVariable();
/*
* So. Let's say we have inputs a, b, c
@ -481,27 +325,23 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
* dmh/db1 = Nd4j.ones(1, batchsz)
*
*/
INDArray y = output;
INDArray s = middle;
INDArray W4 = output_W;
INDArray dEdy = Nd4j.zeros(target.shape());
dEdy.addi(y).subi(target).muli(2); // This should be of size batchsz x outputsz
dEdy.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output.
INDArray dydyh = y.mul(y.mul(-1).add(1)); // This is of size batchsz x outputsz
// This should be of size batchsz x outputsz
dEdy.addi(y).subi(target).muli(2);
// Why? Because the LossFunction divides by the _element size_ of the output.
dEdy.divi(target.shape()[1]);
// This is of size batchsz x outputsz
INDArray dydyh = y.mul(y.mul(-1).add(1));
INDArray dEdyh = dydyh.mul(dEdy);
INDArray dyhdW4 = s.transpose();
INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh));
INDArray dyhdb4 = Nd4j.ones(1, batchsz);
INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh));
INDArray dyhds = W4.transpose();
INDArray dEds = dEdyh.mmul(dyhds);
INDArray dsdm = Nd4j.ones(batchsz, midsz).muli(n).muli(o);
INDArray dEdm = dsdm.mul(dEds);
INDArray dmdmh = (m.mul(m)).mul(-1).add(1);
@ -510,7 +350,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh));
INDArray dmhdb1 = Nd4j.ones(1, batchsz);
INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh));
INDArray dsdn = Nd4j.ones(batchsz, midsz).muli(m).muli(o);
INDArray dEdn = dsdn.mul(dEds);
INDArray dndnh = (n.mul(n)).mul(-1).add(1);
@ -519,7 +358,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh));
INDArray dnhdb2 = Nd4j.ones(1, batchsz);
INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh));
INDArray dsdo = Nd4j.ones(batchsz, midsz).muli(m).muli(n);
INDArray dEdo = dsdo.mul(dEds);
INDArray dodoh = (o.mul(o)).mul(-1).add(1);
@ -528,55 +366,32 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW3 = nullsafe(dohdW3.mmul(dEdoh));
INDArray dohdb3 = Nd4j.ones(1, batchsz);
INDArray dEdb3 = nullsafe(dohdb3.mmul(dEdoh));
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_W")), dEdW3), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense3_b")), dEdb3), this.epsilon);
}
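/*
 * The only new ingredient relative to the Add case is the product rule: with
 * s = m (*) n (*) o elementwise, ds/dm = n (*) o, ds/dn = m (*) o and ds/do = m (*) n.
 * A minimal sketch with hypothetical values (assumes the Nd4j imports already in this file):
 */
static void productRuleSketch() {
    INDArray m = Nd4j.create(new double[] { 0.5, -0.25 });
    INDArray n = Nd4j.create(new double[] { 2.0, 4.0 });
    INDArray o = Nd4j.create(new double[] { 1.0, -1.0 });
    // same value the test builds as Nd4j.ones(batchsz, midsz).muli(n).muli(o)
    INDArray dsdm = n.mul(o); // elementwise { 2.0, -4.0 }
}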
@Test
public void testElementWiseVertexFullSubtract() {
@DisplayName("Test Element Wise Vertex Full Subtract")
void testElementWiseVertexFullSubtract() {
int batchsz = 24;
int featuresz = 17;
int midsz = 13;
int outputsz = 11;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
.dataType(DataType.DOUBLE)
.biasInit(0.0).updater(new Sgd())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
.addInputs("input1", "input2")
.addLayer("dense1",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input1")
.addLayer("dense2",
new DenseLayer.Builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(),
"input2")
.addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract),
"dense1", "dense2")
.addLayer("output",
new OutputLayer.Builder().nIn(midsz).nOut(outputsz)
.activation(new ActivationSigmoid())
.lossFunction(LossFunction.MSE).build(),
"elementwiseSubtract")
.setOutputs("output").build();
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).biasInit(0.0).updater(new Sgd()).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input1", "input2").addLayer("dense1", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input1").addLayer("dense2", new DenseLayer.Builder().nIn(featuresz).nOut(midsz).activation(new ActivationTanH()).build(), "input2").addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), "dense1", "dense2").addLayer("output", new OutputLayer.Builder().nIn(midsz).nOut(outputsz).activation(new ActivationSigmoid()).lossFunction(LossFunction.MSE).build(), "elementwiseSubtract").setOutputs("output").build();
ComputationGraph cg = new ComputationGraph(cgc);
cg.init();
INDArray input1 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1));
INDArray input2 = Nd4j.rand(new int[] {batchsz, featuresz}, new UniformDistribution(-1, 1));
INDArray target = nullsafe(Nd4j.rand(new int[] {batchsz, outputsz}, new UniformDistribution(0, 1)));
INDArray input1 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray input2 = Nd4j.rand(new int[] { batchsz, featuresz }, new UniformDistribution(-1, 1));
INDArray target = nullsafe(Nd4j.rand(new int[] { batchsz, outputsz }, new UniformDistribution(0, 1)));
cg.setInputs(input1, input2);
cg.setLabels(target);
cg.computeGradientAndScore();
// Let's figure out what our params are now.
Map<String, INDArray> params = cg.paramTable();
INDArray dense1_W = nullsafe(params.get("dense1_W"));
@ -585,32 +400,20 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dense2_b = nullsafe(params.get("dense2_b"));
INDArray output_W = nullsafe(params.get("output_W"));
INDArray output_b = nullsafe(params.get("output_b"));
// Now, let's calculate what we expect the output to be.
INDArray mh = input1.mmul(dense1_W).addi(dense1_b.repmat(batchsz, 1));
INDArray m = (Transforms.tanh(mh));
INDArray nh = input2.mmul(dense2_W).addi(dense2_b.repmat(batchsz, 1));
INDArray n = (Transforms.tanh(nh));
INDArray middle = Nd4j.zeros(batchsz, midsz);
middle.addi(m).subi(n);
INDArray expect = Nd4j.zeros(batchsz, outputsz);
expect.addi(Transforms.sigmoid(middle.mmul(output_W).addi(output_b.repmat(batchsz, 1))));
INDArray output = nullsafe(cg.output(input1, input2)[0]);
Assert.assertEquals(0.0, mse(output, expect), this.epsilon);
Assertions.assertEquals(0.0, mse(output, expect), this.epsilon);
Pair<Gradient, Double> pgd = cg.gradientAndScore();
double score = pgd.getSecond();
Assert.assertEquals(score, mse(output, target), this.epsilon);
Assertions.assertEquals(score, mse(output, target), this.epsilon);
Map<String, INDArray> gradients = pgd.getFirst().gradientForVariable();
/*
* So. Let's say we have inputs a, b, c
@ -644,27 +447,23 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
* dmh/db1 = Nd4j.ones(1, batchsz)
*
*/
INDArray y = output;
INDArray s = middle;
INDArray W4 = output_W;
INDArray dEdy = Nd4j.zeros(target.shape());
dEdy.addi(y).subi(target).muli(2); // This should be of size batchsz x outputsz
dEdy.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output.
INDArray dydyh = y.mul(y.mul(-1).add(1)); // This is of size batchsz x outputsz
// This should be of size batchsz x outputsz
dEdy.addi(y).subi(target).muli(2);
// Why? Because the LossFunction divides by the _element size_ of the output.
dEdy.divi(target.shape()[1]);
// This is of size batchsz x outputsz
INDArray dydyh = y.mul(y.mul(-1).add(1));
INDArray dEdyh = dydyh.mul(dEdy);
INDArray dyhdW4 = s.transpose();
INDArray dEdW4 = nullsafe(dyhdW4.mmul(dEdyh));
INDArray dyhdb4 = Nd4j.ones(1, batchsz);
INDArray dEdb4 = nullsafe(dyhdb4.mmul(dEdyh));
INDArray dyhds = W4.transpose();
INDArray dEds = dEdyh.mmul(dyhds);
INDArray dsdm = Nd4j.ones(batchsz, midsz);
INDArray dEdm = dsdm.mul(dEds);
INDArray dmdmh = (m.mul(m)).mul(-1).add(1);
@ -673,7 +472,6 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW1 = nullsafe(dmhdW1.mmul(dEdmh));
INDArray dmhdb1 = Nd4j.ones(1, batchsz);
INDArray dEdb1 = nullsafe(dmhdb1.mmul(dEdmh));
INDArray dsdn = Nd4j.ones(batchsz, midsz).muli(-1);
INDArray dEdn = dsdn.mul(dEds);
INDArray dndnh = (n.mul(n)).mul(-1).add(1);
@ -682,20 +480,16 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
INDArray dEdW2 = nullsafe(dnhdW2.mmul(dEdnh));
INDArray dnhdb2 = Nd4j.ones(1, batchsz);
INDArray dEdb2 = nullsafe(dnhdb2.mmul(dEdnh));
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_W")), dEdW4), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("output_b")), dEdb4), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_W")), dEdW1), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense1_b")), dEdb1), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_W")), dEdW2), this.epsilon);
Assertions.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
}
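/*
 * For Subtract, no product-rule factors appear: s = m - n is linear in both inputs,
 * so ds/dm = +1 and ds/dn = -1 everywhere. A one-line sketch of the dsdn construction
 * used above (hypothetical 3x2 shape, assumes the Nd4j import already in this file):
 */
static void subtractRuleSketch() {
    INDArray dsdm = Nd4j.ones(3, 2);          // d(m - n)/dm = +1 elementwise
    INDArray dsdn = Nd4j.ones(3, 2).muli(-1); // d(m - n)/dn = -1 elementwise
}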
private static double mse(INDArray output, INDArray target) {
double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue()
/ (output.columns() * output.rows());
double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue() / (output.columns() * output.rows());
return mse_expect;
}
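/*
 * Worked example of mse(...) above, with hypothetical 1x2 arrays:
 * output = { 0.5, 0.5 }, target = { 0.0, 1.0 }
 * squared errors = { 0.25, 0.25 }, sum = 0.5, element count = 2, so mse = 0.25.
 */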

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.graph;
import org.deeplearning4j.BaseDL4JTest;
@ -30,8 +29,8 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Assert;
import org.junit.Test;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
@ -42,86 +41,70 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.common.primitives.Pair;
import java.util.Map;
import java.util.TreeMap;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Shift Vertex Test")
class ShiftVertexTest extends BaseDL4JTest {
public class ShiftVertexTest extends BaseDL4JTest {
@Test
public void testShiftVertexNumParamsTrue() {
@DisplayName("Test Shift Vertex Num Params True")
void testShiftVertexNumParamsTrue() {
/*
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
* from @agibsonccc: check for the basics: like 0 numParams
*/
ShiftVertex sv = new ShiftVertex(0.7); // The 0.7 doesn't really matter.
Assert.assertEquals(0, sv.numParams(true));
}
@Test
public void testShiftVertexNumParamsFalse() {
/*
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
* from @agibsonccc: check for the basics: like 0 numParams
*/
ShiftVertex sv = new ShiftVertex(0.7); // The 0.7 doesn't really matter.
Assert.assertEquals(0, sv.numParams(false));
}
@Test
public void testGet() {
// The 0.7 doesn't really matter.
ShiftVertex sv = new ShiftVertex(0.7);
Assert.assertEquals(0.7, sv.getShiftFactor(), this.epsilon);
Assertions.assertEquals(0, sv.numParams(true));
}
@Test
public void testSimple() {
@DisplayName("Test Shift Vertex Num Params False")
void testShiftVertexNumParamsFalse() {
/*
* https://github.com/eclipse/deeplearning4j/pull/3514#issuecomment-307754386
* from @agibsonccc: check for the basics: like 0 numParams
*/
// The 0.7 doesn't really matter.
ShiftVertex sv = new ShiftVertex(0.7);
Assertions.assertEquals(0, sv.numParams(false));
}
@Test
@DisplayName("Test Get")
void testGet() {
ShiftVertex sv = new ShiftVertex(0.7);
Assertions.assertEquals(0.7, sv.getShiftFactor(), this.epsilon);
}
@Test
@DisplayName("Test Simple")
void testSimple() {
/*
* This function _simply_ tests whether ShiftVertex is _in fact_ adding the shift value to its inputs.
*/
// Just first n primes / 10.
INDArray input = Nd4j
.create(new double[][] {{0.2, 0.3, 0.5}, {0.7, 1.1, 1.3}, {1.7, 1.9, 2.3}, {2.9, 3.1, 3.7}});
INDArray input = Nd4j.create(new double[][] { { 0.2, 0.3, 0.5 }, { 0.7, 1.1, 1.3 }, { 1.7, 1.9, 2.3 }, { 2.9, 3.1, 3.7 } });
double sf = 4.1;
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input")
.addLayer("denselayer",
new DenseLayer.Builder().nIn(input.columns()).nOut(1)
.activation(Activation.IDENTITY).build(),
"input")
/* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get
* Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
* at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820)
* at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409)
* at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341)
*/
.addLayer("identityinputactivation",
new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "input")
.addVertex("shiftvertex", new ShiftVertex(sf), "identityinputactivation")
.addLayer("identityshiftvertex",
new ActivationLayer.Builder().activation(Activation.IDENTITY).build(),
"shiftvertex")
.setOutputs("identityshiftvertex", "denselayer").build();
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input").addLayer("denselayer", new DenseLayer.Builder().nIn(input.columns()).nOut(1).activation(Activation.IDENTITY).build(), "input").addLayer("identityinputactivation", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "input").addVertex("shiftvertex", new ShiftVertex(sf), "identityinputactivation").addLayer("identityshiftvertex", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(), "shiftvertex").setOutputs("identityshiftvertex", "denselayer").build();
ComputationGraph cg = new ComputationGraph(cgc);
cg.init();
// We can call outputSingle because we only have a single output layer; it has nothing to do with minibatches.
INDArray output = cg.output(true, input)[0];
INDArray target = Nd4j.zeros(input.shape());
target.addi(input);
target.addi(sf);
INDArray squared = output.sub(target);
double rms = squared.mul(squared).sumNumber().doubleValue();
Assert.assertEquals(0.0, rms, this.epsilon);
Assertions.assertEquals(0.0, rms, this.epsilon);
}
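// In short, the forward pass verified above is output[i][j] = input[i][j] + sf;
// e.g. with sf = 4.1, the input entry 0.2 comes out as 4.3.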
@Test
public void testComprehensive() {
@DisplayName("Test Comprehensive")
void testComprehensive() {
/*
* This function tests ShiftVertex more comprehensively. Specifically, it verifies that the loss function works as
* expected on a ComputationGraph _with_ a ShiftVertex and it verifies that the derivatives produced by
@ -130,29 +113,12 @@ public class ShiftVertexTest extends BaseDL4JTest {
BaseActivationFunction a1 = new ActivationTanH();
BaseActivationFunction a2 = new ActivationSigmoid();
// Just first n primes / 10.
INDArray input = Nd4j
.create(new double[][] {{0.2, 0.3, 0.5}, {0.7, 1.1, 1.3}, {1.7, 1.9, 2.3}, {2.9, 3.1, 3.7}});
INDArray input = Nd4j.create(new double[][] { { 0.2, 0.3, 0.5 }, { 0.7, 1.1, 1.3 }, { 1.7, 1.9, 2.3 }, { 2.9, 3.1, 3.7 } });
double sf = 4.1;
// Actually, given that I'm using a sigmoid on the output,
// these should really be between 0 and 1
INDArray target = Nd4j.create(new double[][] {{0.05, 0.10, 0.15, 0.20, 0.25}, {0.30, 0.35, 0.40, 0.45, 0.50},
{0.55, 0.60, 0.65, 0.70, 0.75}, {0.80, 0.85, 0.90, 0.95, 0.99}});
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
.dataType(DataType.DOUBLE)
.updater(new Sgd(0.01))
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
.addInputs("input")
.addLayer("denselayer",
new DenseLayer.Builder().nIn(input.columns()).nOut(input.columns())
.activation(a1).build(),
"input")
.addVertex("shiftvertex", new ShiftVertex(sf), "denselayer")
.addLayer("output",
new OutputLayer.Builder().nIn(input.columns()).nOut(target.columns())
.activation(a2).lossFunction(LossFunction.MSE).build(),
"shiftvertex")
.setOutputs("output").build();
INDArray target = Nd4j.create(new double[][] { { 0.05, 0.10, 0.15, 0.20, 0.25 }, { 0.30, 0.35, 0.40, 0.45, 0.50 }, { 0.55, 0.60, 0.65, 0.70, 0.75 }, { 0.80, 0.85, 0.90, 0.95, 0.99 } });
ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).dataType(DataType.DOUBLE).updater(new Sgd(0.01)).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder().addInputs("input").addLayer("denselayer", new DenseLayer.Builder().nIn(input.columns()).nOut(input.columns()).activation(a1).build(), "input").addVertex("shiftvertex", new ShiftVertex(sf), "denselayer").addLayer("output", new OutputLayer.Builder().nIn(input.columns()).nOut(target.columns()).activation(a2).lossFunction(LossFunction.MSE).build(), "shiftvertex").setOutputs("output").build();
ComputationGraph cg = new ComputationGraph(cgc);
cg.init();
cg.setInput(0, input);
@ -163,26 +129,23 @@ public class ShiftVertexTest extends BaseDL4JTest {
Gradient g = cg.gradient();
Map<String, INDArray> gradients = g.gradientForVariable();
Map<String, INDArray> manual_gradients = new TreeMap<String, INDArray>();
INDArray W = nullsafe(weights.get("denselayer_W"));
INDArray b = nullsafe(weights.get("denselayer_b"));
INDArray V = nullsafe(weights.get("output_W"));
INDArray c = nullsafe(weights.get("output_b"));
Map<String, INDArray> manual_weights = new TreeMap<String, INDArray>();
manual_weights.put("denselayer_W", W);
manual_weights.put("denselayer_b", b);
manual_weights.put("output_W", V);
manual_weights.put("output_b", c);
// First things first, let's calculate the score.
long batchsz = input.shape()[0];
INDArray z = input.castTo(W.dataType()).mmul(W).add(b.repmat(batchsz, 1));
INDArray a = a1.getActivation(z.dup(), true).add(sf); // activation modifies its input!!
// activation modifies its input!!
INDArray a = a1.getActivation(z.dup(), true).add(sf);
INDArray q = a.mmul(V).add(c.repmat(batchsz, 1));
INDArray o = nullsafe(a2.getActivation(q.dup(), true));
double score_manual = sum_errors(o, target) / (o.columns() * o.rows());
/*
* So. We have
* z5 = input1 * W15 + input2 * W25 + input3 * W35 + b5
@ -197,12 +160,15 @@ public class ShiftVertexTest extends BaseDL4JTest {
* dq1/dv11 = a1 dq2/dV12 = a1 dq3/dV13 = a1 ...
* dq1/dv21 = a2 dq2...
*/
INDArray dEdo = target.like(); //Nd4j.zeros(target.shape());
dEdo.addi(o.castTo(dEdo.dataType())).subi(target).muli(2); // This should be of size batchsz x outputsz
dEdo.divi(target.shape()[1]); // Why? Because the LossFunction divides by the _element size_ of the output.
// Nd4j.zeros(target.shape());
INDArray dEdo = target.like();
// This should be of size batchsz x outputsz
dEdo.addi(o.castTo(dEdo.dataType())).subi(target).muli(2);
// Why? Because the LossFunction divides by the _element size_ of the output.
dEdo.divi(target.shape()[1]);
Pair<INDArray, INDArray> derivs2 = a2.backprop(q, dEdo);
INDArray dEdq = derivs2.getFirst(); // This should be of size batchsz x outputsz (dE/do * do/dq) this _should_ be o * (1-o) * dE/do for Sigmoid.
// This should be of size batchsz x outputsz (dE/do * do/dq) this _should_ be o * (1-o) * dE/do for Sigmoid.
INDArray dEdq = derivs2.getFirst();
// Should be o = q^3 do/dq = 3 q^2 for Cube.
/*
INDArray dodq = q.mul(q).mul(3);
@ -213,26 +179,23 @@ public class ShiftVertexTest extends BaseDL4JTest {
System.err.println(tbv);
System.err.println(dEdq);
*/
INDArray dqdc = Nd4j.ones(1, batchsz);
INDArray dEdc = dqdc.mmul(dEdq); // This should be of size 1 x outputsz
// This should be of size 1 x outputsz
INDArray dEdc = dqdc.mmul(dEdq);
INDArray dEdV = a.transpose().mmul(dEdq);
INDArray dEda = dEdq.mmul(V.transpose()); // This should be dEdo * dodq * dqda
// This should be dEdo * dodq * dqda
INDArray dEda = dEdq.mmul(V.transpose());
Pair<INDArray, INDArray> derivs1 = a1.backprop(z, dEda);
INDArray dEdz = derivs1.getFirst();
INDArray dzdb = Nd4j.ones(1, batchsz);
INDArray dEdb = dzdb.mmul(dEdz);
INDArray dEdW = input.transpose().mmul(dEdz);
manual_gradients.put("output_b", dEdc);
manual_gradients.put("output_W", dEdV);
manual_gradients.put("denselayer_b", dEdb);
manual_gradients.put("denselayer_W", dEdW);
double summse = Math.pow((score_manual - score_dl4j), 2);
int denominator = 1;
for (Map.Entry<String, INDArray> mesi : gradients.entrySet()) {
String name = mesi.getKey();
INDArray dl4j_gradient = nullsafe(mesi.getValue());
@ -241,9 +204,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
summse += se;
denominator += dl4j_gradient.columns() * dl4j_gradient.rows();
}
Assert.assertEquals(0.0, summse / denominator, this.epsilon);
Assertions.assertEquals(0.0, summse / denominator, this.epsilon);
}
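/*
 * Sanity sketch for the "o * (1 - o) * dE/do" comment above: for the logistic sigmoid
 * o = 1 / (1 + exp(-q)), the derivative is do/dq = o * (1 - o). Hypothetical values,
 * plain doubles only:
 */
static void sigmoidDerivativeSketch() {
    double q = 0.3;
    double o = 1.0 / (1.0 + Math.exp(-q));
    double analytic = o * (1.0 - o);
    double h = 1e-7;
    // forward finite difference; agrees with analytic to ~1e-7
    double numeric = (1.0 / (1.0 + Math.exp(-(q + h))) - o) / h;
}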
private static double sum_errors(INDArray a, INDArray b) {

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -25,7 +24,7 @@ import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
@ -34,45 +33,62 @@ import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.io.*;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
* @author Jeffrey Tang.
*/
public class LayerBuilderTest extends BaseDL4JTest {
@DisplayName("Layer Builder Test")
class LayerBuilderTest extends BaseDL4JTest {
final double DELTA = 1e-15;
int numIn = 10;
int numOut = 5;
double drop = 0.3;
IActivation act = new ActivationSoftmax();
PoolingType poolType = PoolingType.MAX;
int[] kernelSize = new int[] {2, 2};
int[] stride = new int[] {2, 2};
int[] padding = new int[] {1, 1};
int[] kernelSize = new int[] { 2, 2 };
int[] stride = new int[] { 2, 2 };
int[] padding = new int[] { 1, 1 };
int k = 1;
Convolution.Type convType = Convolution.Type.VALID;
LossFunction loss = LossFunction.MCXENT;
WeightInit weight = WeightInit.XAVIER;
double corrupt = 0.4;
double sparsity = 0.3;
double corruptionLevel = 0.5;
double dropOut = 0.1;
IUpdater updater = new AdaGrad();
GradientNormalization gradNorm = GradientNormalization.ClipL2PerParamType;
double gradNormThreshold = 8;
@Test
public void testLayer() throws Exception {
DenseLayer layer = new DenseLayer.Builder().activation(act).weightInit(weight).dropOut(dropOut)
.updater(updater).gradientNormalization(gradNorm)
.gradientNormalizationThreshold(gradNormThreshold).build();
@DisplayName("Test Layer")
void testLayer() throws Exception {
DenseLayer layer = new DenseLayer.Builder().activation(act).weightInit(weight).dropOut(dropOut).updater(updater).gradientNormalization(gradNorm).gradientNormalizationThreshold(gradNormThreshold).build();
checkSerialization(layer);
assertEquals(act, layer.getActivationFn());
assertEquals(weight.getWeightInitFunction(), layer.getWeightInitFn());
assertEquals(new Dropout(dropOut), layer.getIDropout());
@ -82,34 +98,30 @@ public class LayerBuilderTest extends BaseDL4JTest {
}
@Test
public void testFeedForwardLayer() throws Exception {
@DisplayName("Test Feed Forward Layer")
void testFeedForwardLayer() throws Exception {
DenseLayer ff = new DenseLayer.Builder().nIn(numIn).nOut(numOut).build();
checkSerialization(ff);
assertEquals(numIn, ff.getNIn());
assertEquals(numOut, ff.getNOut());
}
@Test
public void testConvolutionLayer() throws Exception {
@DisplayName("Test Convolution Layer")
void testConvolutionLayer() throws Exception {
ConvolutionLayer conv = new ConvolutionLayer.Builder(kernelSize, stride, padding).build();
checkSerialization(conv);
// assertEquals(convType, conv.getConvolutionType());
// assertEquals(convType, conv.getConvolutionType());
assertArrayEquals(kernelSize, conv.getKernelSize());
assertArrayEquals(stride, conv.getStride());
assertArrayEquals(padding, conv.getPadding());
}
@Test
public void testSubsamplingLayer() throws Exception {
SubsamplingLayer sample =
new SubsamplingLayer.Builder(poolType, stride).kernelSize(kernelSize).padding(padding).build();
@DisplayName("Test Subsampling Layer")
void testSubsamplingLayer() throws Exception {
SubsamplingLayer sample = new SubsamplingLayer.Builder(poolType, stride).kernelSize(kernelSize).padding(padding).build();
checkSerialization(sample);
assertArrayEquals(padding, sample.getPadding());
assertArrayEquals(kernelSize, sample.getKernelSize());
assertEquals(poolType, sample.getPoolingType());
@ -117,36 +129,33 @@ public class LayerBuilderTest extends BaseDL4JTest {
}
@Test
public void testOutputLayer() throws Exception {
@DisplayName("Test Output Layer")
void testOutputLayer() throws Exception {
OutputLayer out = new OutputLayer.Builder(loss).build();
checkSerialization(out);
}
@Test
public void testRnnOutputLayer() throws Exception {
@DisplayName("Test Rnn Output Layer")
void testRnnOutputLayer() throws Exception {
RnnOutputLayer out = new RnnOutputLayer.Builder(loss).build();
checkSerialization(out);
}
@Test
public void testAutoEncoder() throws Exception {
@DisplayName("Test Auto Encoder")
void testAutoEncoder() throws Exception {
AutoEncoder enc = new AutoEncoder.Builder().corruptionLevel(corruptionLevel).sparsity(sparsity).build();
checkSerialization(enc);
assertEquals(corruptionLevel, enc.getCorruptionLevel(), DELTA);
assertEquals(sparsity, enc.getSparsity(), DELTA);
}
@Test
public void testGravesLSTM() throws Exception {
GravesLSTM glstm = new GravesLSTM.Builder().forgetGateBiasInit(1.5).activation(Activation.TANH).nIn(numIn)
.nOut(numOut).build();
@DisplayName("Test Graves LSTM")
void testGravesLSTM() throws Exception {
GravesLSTM glstm = new GravesLSTM.Builder().forgetGateBiasInit(1.5).activation(Activation.TANH).nIn(numIn).nOut(numOut).build();
checkSerialization(glstm);
assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0);
assertEquals(glstm.nIn, numIn);
assertEquals(glstm.nOut, numOut);
@ -154,12 +163,10 @@ public class LayerBuilderTest extends BaseDL4JTest {
}
@Test
public void testGravesBidirectionalLSTM() throws Exception {
final GravesBidirectionalLSTM glstm = new GravesBidirectionalLSTM.Builder().forgetGateBiasInit(1.5)
.activation(Activation.TANH).nIn(numIn).nOut(numOut).build();
@DisplayName("Test Graves Bidirectional LSTM")
void testGravesBidirectionalLSTM() throws Exception {
final GravesBidirectionalLSTM glstm = new GravesBidirectionalLSTM.Builder().forgetGateBiasInit(1.5).activation(Activation.TANH).nIn(numIn).nOut(numOut).build();
checkSerialization(glstm);
assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0);
assertEquals(glstm.nIn, numIn);
assertEquals(glstm.nOut, numOut);
@ -167,21 +174,19 @@ public class LayerBuilderTest extends BaseDL4JTest {
}
@Test
public void testEmbeddingLayer() throws Exception {
@DisplayName("Test Embedding Layer")
void testEmbeddingLayer() throws Exception {
EmbeddingLayer el = new EmbeddingLayer.Builder().nIn(10).nOut(5).build();
checkSerialization(el);
assertEquals(10, el.getNIn());
assertEquals(5, el.getNOut());
}
@Test
public void testBatchNormLayer() throws Exception {
BatchNormalization bN = new BatchNormalization.Builder().nIn(numIn).nOut(numOut).gamma(2).beta(1).decay(0.5)
.lockGammaBeta(true).build();
@DisplayName("Test Batch Norm Layer")
void testBatchNormLayer() throws Exception {
BatchNormalization bN = new BatchNormalization.Builder().nIn(numIn).nOut(numOut).gamma(2).beta(1).decay(0.5).lockGammaBeta(true).build();
checkSerialization(bN);
assertEquals(numIn, bN.nIn);
assertEquals(numOut, bN.nOut);
assertEquals(true, bN.isLockGammaBeta());
@ -191,42 +196,38 @@ public class LayerBuilderTest extends BaseDL4JTest {
}
@Test
public void testActivationLayer() throws Exception {
@DisplayName("Test Activation Layer")
void testActivationLayer() throws Exception {
ActivationLayer activationLayer = new ActivationLayer.Builder().activation(act).build();
checkSerialization(activationLayer);
assertEquals(act, activationLayer.activationFn);
}
private void checkSerialization(Layer layer) throws Exception {
NeuralNetConfiguration confExpected = new NeuralNetConfiguration.Builder().layer(layer).build();
NeuralNetConfiguration confActual;
// check Java serialization
byte[] data;
try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutput out = new ObjectOutputStream(bos)) {
try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutput out = new ObjectOutputStream(bos)) {
out.writeObject(confExpected);
data = bos.toByteArray();
}
try (ByteArrayInputStream bis = new ByteArrayInputStream(data); ObjectInput in = new ObjectInputStream(bis)) {
try (ByteArrayInputStream bis = new ByteArrayInputStream(data);
ObjectInput in = new ObjectInputStream(bis)) {
confActual = (NeuralNetConfiguration) in.readObject();
}
assertEquals("unequal Java serialization", confExpected.getLayer(), confActual.getLayer());
assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal Java serialization");
// check JSON
String json = confExpected.toJson();
confActual = NeuralNetConfiguration.fromJson(json);
assertEquals("unequal JSON serialization", confExpected.getLayer(), confActual.getLayer());
assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal JSON serialization");
// check YAML
String yaml = confExpected.toYaml();
confActual = NeuralNetConfiguration.fromYaml(yaml);
assertEquals("unequal YAML serialization", confExpected.getLayer(), confActual.getLayer());
assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal YAML serialization");
// check the layer's use of callSuper on equals method
confActual.getLayer().setIDropout(new Dropout(new java.util.Random().nextDouble()));
assertNotEquals("broken equals method (missing callSuper?)", confExpected.getLayer(), confActual.getLayer());
assertNotEquals(confExpected.getLayer(), confActual.getLayer(), "broken equals method (missing callSuper?)");
}
}
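/*
 * The checkSerialization changes above also show the second JUnit 5 argument-order rule:
 * the failure message moves from the first parameter to the last. A minimal sketch with
 * placeholder values:
 *
 *     // JUnit 4
 *     Assert.assertEquals("unequal JSON serialization", expected, actual);
 *     // JUnit 5
 *     Assertions.assertEquals(expected, actual, "unequal JSON serialization");
 */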

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -30,7 +29,7 @@ import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.Adam;
@ -38,89 +37,170 @@ import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/*
@Test
public void testLearningRatePolicyExponential() {
double lr = 2;
double lrDecayRate = 5;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr)
.updater(Updater.SGD)
.learningRateDecayPolicy(LearningRatePolicy.Exponential).lrPolicyDecayRate(lrDecayRate).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
public class LayerConfigTest extends BaseDL4JTest {
assertEquals(LearningRatePolicy.Exponential, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Exponential, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
}
@Test
public void testLayerName() {
public void testLearningRatePolicyInverse() {
double lr = 2;
double lrDecayRate = 5;
double power = 3;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(lrDecayRate)
.lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(LearningRatePolicy.Inverse, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Inverse, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
assertEquals(power, conf.getConf(0).getLrPolicyPower(), 0.0);
assertEquals(power, conf.getConf(1).getLrPolicyPower(), 0.0);
}
@Test
public void testLearningRatePolicySteps() {
double lr = 2;
double lrDecayRate = 5;
double steps = 4;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
.learningRateDecayPolicy(LearningRatePolicy.Step).lrPolicyDecayRate(lrDecayRate)
.lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(LearningRatePolicy.Step, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Step, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
assertEquals(steps, conf.getConf(0).getLrPolicySteps(), 0.0);
assertEquals(steps, conf.getConf(1).getLrPolicySteps(), 0.0);
}
@Test
public void testLearningRatePolicyPoly() {
double lr = 2;
double lrDecayRate = 5;
double power = 3;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
.learningRateDecayPolicy(LearningRatePolicy.Poly).lrPolicyDecayRate(lrDecayRate)
.lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(LearningRatePolicy.Poly, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Poly, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
assertEquals(power, conf.getConf(0).getLrPolicyPower(), 0.0);
assertEquals(power, conf.getConf(1).getLrPolicyPower(), 0.0);
}
@Test
public void testLearningRatePolicySigmoid() {
double lr = 2;
double lrDecayRate = 5;
double steps = 4;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr)
.learningRateDecayPolicy(LearningRatePolicy.Sigmoid).lrPolicyDecayRate(lrDecayRate)
.lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(LearningRatePolicy.Sigmoid, conf.getConf(0).getLearningRatePolicy());
assertEquals(LearningRatePolicy.Sigmoid, conf.getConf(1).getLearningRatePolicy());
assertEquals(lrDecayRate, conf.getConf(0).getLrPolicyDecayRate(), 0.0);
assertEquals(lrDecayRate, conf.getConf(1).getLrPolicyDecayRate(), 0.0);
assertEquals(steps, conf.getConf(0).getLrPolicySteps(), 0.0);
assertEquals(steps, conf.getConf(1).getLrPolicySteps(), 0.0);
}
*/
@DisplayName("Layer Config Test")
class LayerConfigTest extends BaseDL4JTest {
@Test
@DisplayName("Test Layer Name")
void testLayerName() {
String name1 = "genisys";
String name2 = "bill";
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).name(name1).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).name(name2).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).name(name1).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).name(name2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(name1, conf.getConf(0).getLayer().getLayerName());
assertEquals(name2, conf.getConf(1).getLayer().getLayerName());
}
@Test
public void testActivationLayerwiseOverride() {
//Without layerwise override:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
@DisplayName("Test Activation Layerwise Override")
void testActivationLayerwiseOverride() {
// Without layerwise override:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString());
assertEquals("relu", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString());
//With
conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).activation(Activation.TANH).build()).build();
assertEquals(((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString(), "relu");
assertEquals(((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString(), "relu");
// With
conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).activation(Activation.TANH).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString());
assertEquals("tanh", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString());
assertEquals(((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString(), "relu");
assertEquals(((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString(), "tanh");
}
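// The same precedence holds for every override test in this class: a value set on the
// NeuralNetConfiguration.Builder is the network-wide default, and a value set on an
// individual layer's builder replaces it for that layer only.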
@Test
public void testWeightBiasInitLayerwiseOverride() {
//Without layerwise override:
@DisplayName("Test Weight Bias Init Layerwise Override")
void testWeightBiasInitLayerwiseOverride() {
// Without layerwise override:
final Distribution defaultDistribution = new NormalDistribution(0, 1.0);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dist(defaultDistribution).biasInit(1).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dist(defaultDistribution).biasInit(1).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn());
assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(1, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
//With:
// With:
final Distribution overriddenDistribution = new UniformDistribution(0, 1);
conf = new NeuralNetConfiguration.Builder()
.dist(defaultDistribution).biasInit(1).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1,
new DenseLayer.Builder().nIn(2).nOut(2)
.dist(overriddenDistribution).biasInit(0).build())
.build();
conf = new NeuralNetConfiguration.Builder().dist(defaultDistribution).biasInit(1).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).dist(overriddenDistribution).biasInit(0).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn());
assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn());
assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0);
assertEquals(0, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0);
}
@ -176,101 +256,65 @@ public class LayerConfigTest extends BaseDL4JTest {
assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0);
assertEquals(0.8, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0);
}*/
@Test
public void testDropoutLayerwiseOverride() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
@DisplayName("Test Dropout Layerwise Override")
void testDropoutLayerwiseOverride() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new Dropout(1.0), conf.getConf(0).getLayer().getIDropout());
assertEquals(new Dropout(1.0), conf.getConf(1).getLayer().getIDropout());
conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).dropOut(2.0).build()).build();
conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).dropOut(2.0).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(new Dropout(1.0), conf.getConf(0).getLayer().getIDropout());
assertEquals(new Dropout(2.0), conf.getConf(1).getLayer().getIDropout());
}
@Test
public void testMomentumLayerwiseOverride() {
@DisplayName("Test Momentum Layerwise Override")
void testMomentumLayerwiseOverride() {
Map<Integer, Double> testMomentumAfter = new HashMap<>();
testMomentumAfter.put(0, 0.1);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter)))
.list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
assertEquals(0.1, ((Nesterovs) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0);
assertEquals(0.1, ((Nesterovs) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0);
Map<Integer, Double> testMomentumAfter2 = new HashMap<>();
testMomentumAfter2.put(0, 0.2);
conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter) ))
.list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder()
.nIn(2).nOut(2).updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter2))).build())
.build();
conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter2))).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
assertEquals(0.2, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0);
assertEquals(0.1, ((Nesterovs) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0);
assertEquals(0.2, ((Nesterovs) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0, 0), 0.0);
}
@Test
public void testUpdaterRhoRmsDecayLayerwiseOverride() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new AdaDelta(0.5, 0.9)).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.01,0.9)).build()).build();
@DisplayName("Test Updater Rho Rms Decay Layerwise Override")
void testUpdaterRhoRmsDecayLayerwiseOverride() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new AdaDelta(0.5, 0.9)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.01, 0.9)).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta);
assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
assertEquals(0.5, ((AdaDelta)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
assertEquals(0.01, ((AdaDelta)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.5,AdaDelta.DEFAULT_ADADELTA_EPSILON)).build())
.build();
assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0);
assertEquals(0.01, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.5, AdaDelta.DEFAULT_ADADELTA_EPSILON)).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp);
assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta);
assertEquals(1.0, ((RmsProp) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0);
assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0);
}
@Test
public void testUpdaterAdamParamsLayerwiseOverride() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new Adam(1.0, 0.5, 0.5, 1e-8))
.list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Adam(1.0, 0.6, 0.7, 1e-8)).build())
.build();
@DisplayName("Test Updater Adam Params Layerwise Override")
void testUpdaterAdamParamsLayerwiseOverride() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(1.0, 0.5, 0.5, 1e-8)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Adam(1.0, 0.6, 0.7, 1e-8)).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
assertEquals(0.6, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
@ -278,45 +322,25 @@ public class LayerConfigTest extends BaseDL4JTest {
}
@Test
public void testGradientNormalizationLayerwiseOverride() {
//Gradient normalization without layerwise override:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(10).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
@DisplayName("Test Gradient Normalization Layerwise Override")
void testGradientNormalizationLayerwiseOverride() {
// Gradient normalization without layerwise override:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization());
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization());
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization());
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization());
assertEquals(10, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalizationThreshold(), 0.0);
assertEquals(10, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalizationThreshold(), 0.0);
//With:
conf = new NeuralNetConfiguration.Builder()
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(10).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2)
.gradientNormalization(GradientNormalization.None)
.gradientNormalizationThreshold(2.5).build())
.build();
// With:
conf = new NeuralNetConfiguration.Builder().gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(10).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).gradientNormalization(GradientNormalization.None).gradientNormalizationThreshold(2.5).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization());
assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization());
assertEquals(GradientNormalization.None, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization());
assertEquals(10, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalizationThreshold(), 0.0);
assertEquals(2.5, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalizationThreshold(), 0.0);
}
/*
@Test
public void testLearningRatePolicyExponential() {
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -35,8 +34,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
@ -44,107 +43,89 @@ import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
public class LayerConfigValidationTest extends BaseDL4JTest {
@DisplayName("Layer Config Validation Test")
class LayerConfigValidationTest extends BaseDL4JTest {
@Test
public void testDropConnect() {
@DisplayName("Test Drop Connect")
void testDropConnect() {
// Warning thrown only since some layers may not have l1 or l2
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).weightNoise(new DropConnect(0.5))
.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).weightNoise(new DropConnect(0.5)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test
public void testL1L2NotSet() {
@DisplayName("Test L 1 L 2 Not Set")
void testL1L2NotSet() {
// Warning thrown only since some layers may not have l1 or l2
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3))
.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test(expected = IllegalStateException.class)
@Ignore //Old assumption: throw exception on l1 but no regularization. Current design: warn, not exception
public void testRegNotSetL1Global() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).l1(0.5).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test(expected = IllegalStateException.class)
@Ignore //Old assumption: throw exception on l1 but no regularization. Current design: warn, not exception
public void testRegNotSetL2Local() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test
public void testWeightInitDistNotSet() {
@Disabled
@DisplayName("Test Reg Not Set L 1 Global")
void testRegNotSetL1Global() {
assertThrows(IllegalStateException.class, () -> {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).l1(0.5).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
});
}
@Test
@Disabled
@DisplayName("Test Reg Not Set L 2 Local")
void testRegNotSetL2Local() {
assertThrows(IllegalStateException.class, () -> {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
});
}
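The two assertThrows tests above capture the exception-handling side of the migration: JUnit 4's @Test(expected = IllegalStateException.class) plus @Ignore becomes a plain @Test with @Disabled and an explicit assertThrows around the code expected to fail. A self-contained sketch of that pattern, with illustrative names that are not from the DL4J codebase:

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Assert Throws Sketch")
class AssertThrowsSketchTest {

    @Test
    @DisplayName("Rejects Negative Input")
    void rejectsNegativeInput() {
        // assertThrows returns the thrown exception, so follow-up
        // assertions are possible, which @Test(expected = ...) never allowed.
        IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> checkPositive(-1));
        assertEquals("must be positive: -1", e.getMessage());
    }

    private static int checkPositive(int x) {
        if (x < 0)
            throw new IllegalArgumentException("must be positive: " + x);
        return x;
    }
}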
@Test
@DisplayName("Test Weight Init Dist Not Set")
void testWeightInitDistNotSet() {
// Warning thrown only since global dist can be set with a different weight init locally
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).dist(new GaussianDistribution(1e-3, 2))
.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).dist(new GaussianDistribution(1e-3, 2)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test
public void testNesterovsNotSetGlobal() {
@DisplayName("Test Nesterovs Not Set Global")
void testNesterovsNotSetGlobal() {
// Warnings only thrown
Map<Integer, Double> testMomentumAfter = new HashMap<>();
testMomentumAfter.put(0, 0.1);
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
}
@Test
public void testCompGraphNullLayer() {
ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.01))
.seed(42).miniBatch(false).l1(0.2).l2(0.2)
/* Graph Builder */
.updater(Updater.RMSPROP).graphBuilder().addInputs("in")
.addLayer("L" + 1,
new GravesLSTM.Builder().nIn(20).updater(Updater.RMSPROP).nOut(10)
.weightInit(WeightInit.XAVIER)
.dropOut(0.4).l1(0.3).activation(Activation.SIGMOID).build(),
"in")
.addLayer("output",
new RnnOutputLayer.Builder().nIn(20).nOut(10).activation(Activation.SOFTMAX)
.weightInit(WeightInit.RELU_UNIFORM).build(),
"L" + 1)
.setOutputs("output");
@DisplayName("Test Comp Graph Null Layer")
void testCompGraphNullLayer() {
ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.01)).seed(42).miniBatch(false).l1(0.2).l2(0.2).updater(Updater.RMSPROP).graphBuilder().addInputs("in").addLayer("L" + 1, new GravesLSTM.Builder().nIn(20).updater(Updater.RMSPROP).nOut(10).weightInit(WeightInit.XAVIER).dropOut(0.4).l1(0.3).activation(Activation.SIGMOID).build(), "in").addLayer("output", new RnnOutputLayer.Builder().nIn(20).nOut(10).activation(Activation.SOFTMAX).weightInit(WeightInit.RELU_UNIFORM).build(), "L" + 1).setOutputs("output");
ComputationGraphConfiguration conf = gb.build();
ComputationGraph cg = new ComputationGraph(conf);
cg.init();
}
@Test
public void testPredefinedConfigValues() {
@DisplayName("Test Predefined Config Values")
void testPredefinedConfigValues() {
double expectedMomentum = 0.9;
double expectedAdamMeanDecay = 0.9;
double expectedAdamVarDecay = 0.999;
@ -152,59 +133,38 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
Distribution expectedDist = new NormalDistribution(0, 1);
double expectedL1 = 0.0;
double expectedL2 = 0.0;
// Nesterovs Updater
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(0.9))
.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(0.3, 0.4)).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(0.9)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(0.3, 0.4)).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
BaseLayer layerConf = (BaseLayer) net.getLayer(0).conf().getLayer();
assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3);
assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
BaseLayer layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer();
assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3);
// Adam Updater
conf = new NeuralNetConfiguration.Builder().updater(new Adam(0.3))
.weightInit(new WeightInitDistribution(expectedDist)).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).l1(0.3).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
conf = new NeuralNetConfiguration.Builder().updater(new Adam(0.3)).weightInit(new WeightInitDistribution(expectedDist)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).l1(0.3).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
layerConf = (BaseLayer) net.getLayer(0).conf().getLayer();
assertEquals(0.3, TestUtils.getL1(layerConf), 1e-3);
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer();
assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3);
assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3);
assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn());
assertNull(TestUtils.getL1Reg(layerConf1.getRegularization()));
assertNull(TestUtils.getL2Reg(layerConf1.getRegularization()));
//RMSProp Updater
conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(0.3)).list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
.layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(0.3, 0.4, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()).build();
// RMSProp Updater
conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(0.3)).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(0.3, 0.4, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()).build();
net = new MultiLayerNetwork(conf);
net.init();
layerConf = (BaseLayer) net.getLayer(0).conf().getLayer();
assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3);
assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
assertNull(TestUtils.getL2Reg(layerConf.getRegularization()));
layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer();
assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3);
}
}
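Every test in this class exercises the same precedence rule: a value set on the NeuralNetConfiguration.Builder is the global default, and a value set on an individual layer overrides it for that layer only. Reduced to a minimal sketch, using the builder API and imports already present in this file (the concrete values are illustrative, not one of the tests above):

// Global updater Sgd(0.3); layer 1 overrides it locally with Adam.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .updater(new Sgd(0.3)) // global default
        .list()
        .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) // inherits Sgd(0.3)
        .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Adam(0.01)).build()) // per-layer override wins
        .build();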
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.preprocessor;
import org.deeplearning4j.BaseDL4JTest;
@ -28,7 +27,7 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -36,29 +35,33 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
**/
*/
@DisplayName("Cnn Processor Test")
class CNNProcessorTest extends BaseDL4JTest {
public class CNNProcessorTest extends BaseDL4JTest {
private static int rows = 28;
private static int cols = 28;
private static INDArray in2D = Nd4j.create(DataType.FLOAT, 1, 784);
private static INDArray in3D = Nd4j.create(DataType.FLOAT, 20, 784, 7);
private static INDArray in4D = Nd4j.create(DataType.FLOAT, 20, 1, 28, 28);
@Test
public void testFeedForwardToCnnPreProcessor() {
@DisplayName("Test Feed Forward To Cnn Pre Processor")
void testFeedForwardToCnnPreProcessor() {
FeedForwardToCnnPreProcessor convProcessor = new FeedForwardToCnnPreProcessor(rows, cols, 1);
INDArray check2to4 = convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
int val2to4 = check2to4.shape().length;
assertTrue(val2to4 == 4);
assertEquals(Nd4j.create(DataType.FLOAT, 1, 1, 28, 28), check2to4);
INDArray check4to4 = convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces());
int val4to4 = check4to4.shape().length;
assertTrue(val4to4 == 4);
@ -66,42 +69,41 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
public void testFeedForwardToCnnPreProcessor2() {
int[] nRows = {1, 5, 20};
int[] nCols = {1, 5, 20};
int[] nDepth = {1, 3};
int[] nMiniBatchSize = {1, 5};
@DisplayName("Test Feed Forward To Cnn Pre Processor 2")
void testFeedForwardToCnnPreProcessor2() {
int[] nRows = { 1, 5, 20 };
int[] nCols = { 1, 5, 20 };
int[] nDepth = { 1, 3 };
int[] nMiniBatchSize = { 1, 5 };
for (int rows : nRows) {
for (int cols : nCols) {
for (int d : nDepth) {
FeedForwardToCnnPreProcessor convProcessor = new FeedForwardToCnnPreProcessor(rows, cols, d);
for (int miniBatch : nMiniBatchSize) {
long[] ffShape = new long[] {miniBatch, rows * cols * d};
long[] ffShape = new long[] { miniBatch, rows * cols * d };
INDArray rand = Nd4j.rand(ffShape);
INDArray ffInput_c = Nd4j.create(DataType.FLOAT, ffShape, 'c');
INDArray ffInput_f = Nd4j.create(DataType.FLOAT, ffShape, 'f');
ffInput_c.assign(rand);
ffInput_f.assign(rand);
assertEquals(ffInput_c, ffInput_f);
//Test forward pass:
// Test forward pass:
INDArray convAct_c = convProcessor.preProcess(ffInput_c, -1, LayerWorkspaceMgr.noWorkspaces());
INDArray convAct_f = convProcessor.preProcess(ffInput_f, -1, LayerWorkspaceMgr.noWorkspaces());
long[] convShape = {miniBatch, d, rows, cols};
long[] convShape = { miniBatch, d, rows, cols };
assertArrayEquals(convShape, convAct_c.shape());
assertArrayEquals(convShape, convAct_f.shape());
assertEquals(convAct_c, convAct_f);
//Check values:
//CNN reshaping (for each example) takes a 1d vector and converts it to 3d
// Check values:
// CNN reshaping (for each example) takes a 1d vector and converts it to 3d
// (4d total, for minibatch data)
//1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc
// 1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc
for (int ex = 0; ex < miniBatch; ex++) {
for (int r = 0; r < rows; r++) {
for (int c = 0; c < cols; c++) {
for (int depth = 0; depth < d; depth++) {
int origPosition = depth * (rows * cols) + r * cols + c; //pos in vector
// pos in vector
int origPosition = depth * (rows * cols) + r * cols + c;
double vecValue = ffInput_c.getDouble(ex, origPosition);
double convValue = convAct_c.getDouble(ex, depth, r, c);
assertEquals(vecValue, convValue, 0.0);
@ -109,9 +111,8 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
}
}
//Test backward pass:
//Idea is that backward pass should do opposite to forward pass
// Test backward pass:
// Idea is that backward pass should do opposite to forward pass
INDArray epsilon4_c = Nd4j.create(DataType.FLOAT, convShape, 'c');
INDArray epsilon4_f = Nd4j.create(DataType.FLOAT, convShape, 'f');
epsilon4_c.assign(convAct_c);
@ -126,12 +127,11 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
}
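The nested loops above pin down the flattening convention the pre-processor assumes: channels-first and row-major, so a 3d position maps to the 1d index depth * (rows * cols) + r * cols + c. A worked instance with small illustrative sizes:

// rows = 2, cols = 3: element (depth = 1, r = 1, c = 2) lands at
// 1 * (2 * 3) + 1 * 3 + 2 = 11 in the flattened vector.
int rows = 2, cols = 3;
int origPosition = 1 * (rows * cols) + 1 * cols + 2; // == 11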
@Test
public void testFeedForwardToCnnPreProcessorBackprop() {
@DisplayName("Test Feed Forward To Cnn Pre Processor Backprop")
void testFeedForwardToCnnPreProcessorBackprop() {
FeedForwardToCnnPreProcessor convProcessor = new FeedForwardToCnnPreProcessor(rows, cols, 1);
convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
INDArray check2to2 = convProcessor.backprop(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
int val2to2 = check2to2.shape().length;
assertTrue(val2to2 == 2);
@ -139,14 +139,13 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
public void testCnnToFeedForwardProcessor() {
@DisplayName("Test Cnn To Feed Forward Processor")
void testCnnToFeedForwardProcessor() {
CnnToFeedForwardPreProcessor convProcessor = new CnnToFeedForwardPreProcessor(rows, cols, 1);
INDArray check2to4 = convProcessor.backprop(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
int val2to4 = check2to4.shape().length;
assertTrue(val2to4 == 4);
assertEquals(Nd4j.create(DataType.FLOAT, 1, 1, 28, 28), check2to4);
INDArray check4to4 = convProcessor.backprop(in4D, -1, LayerWorkspaceMgr.noWorkspaces());
int val4to4 = check4to4.shape().length;
assertTrue(val4to4 == 4);
@ -154,15 +153,14 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
public void testCnnToFeedForwardPreProcessorBackprop() {
@DisplayName("Test Cnn To Feed Forward Pre Processor Backprop")
void testCnnToFeedForwardPreProcessorBackprop() {
CnnToFeedForwardPreProcessor convProcessor = new CnnToFeedForwardPreProcessor(rows, cols, 1);
convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces());
INDArray check2to2 = convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces());
int val2to2 = check2to2.shape().length;
assertTrue(val2to2 == 2);
assertEquals(Nd4j.create(DataType.FLOAT, 1, 784), check2to2);
INDArray check4to2 = convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces());
int val4to2 = check4to2.shape().length;
assertTrue(val4to2 == 2);
@ -170,42 +168,41 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
public void testCnnToFeedForwardPreProcessor2() {
int[] nRows = {1, 5, 20};
int[] nCols = {1, 5, 20};
int[] nDepth = {1, 3};
int[] nMiniBatchSize = {1, 5};
@DisplayName("Test Cnn To Feed Forward Pre Processor 2")
void testCnnToFeedForwardPreProcessor2() {
int[] nRows = { 1, 5, 20 };
int[] nCols = { 1, 5, 20 };
int[] nDepth = { 1, 3 };
int[] nMiniBatchSize = { 1, 5 };
for (int rows : nRows) {
for (int cols : nCols) {
for (int d : nDepth) {
CnnToFeedForwardPreProcessor convProcessor = new CnnToFeedForwardPreProcessor(rows, cols, d);
for (int miniBatch : nMiniBatchSize) {
long[] convActShape = new long[] {miniBatch, d, rows, cols};
long[] convActShape = new long[] { miniBatch, d, rows, cols };
INDArray rand = Nd4j.rand(convActShape);
INDArray convInput_c = Nd4j.create(DataType.FLOAT, convActShape, 'c');
INDArray convInput_f = Nd4j.create(DataType.FLOAT, convActShape, 'f');
convInput_c.assign(rand);
convInput_f.assign(rand);
assertEquals(convInput_c, convInput_f);
//Test forward pass:
// Test forward pass:
INDArray ffAct_c = convProcessor.preProcess(convInput_c, -1, LayerWorkspaceMgr.noWorkspaces());
INDArray ffAct_f = convProcessor.preProcess(convInput_f, -1, LayerWorkspaceMgr.noWorkspaces());
long[] ffActShape = {miniBatch, d * rows * cols};
long[] ffActShape = { miniBatch, d * rows * cols };
assertArrayEquals(ffActShape, ffAct_c.shape());
assertArrayEquals(ffActShape, ffAct_f.shape());
assertEquals(ffAct_c, ffAct_f);
//Check values:
//CNN reshaping (for each example) takes a 1d vector and converts it to 3d
// Check values:
// CNN reshaping (for each example) takes a 1d vector and converts it to 3d
// (4d total, for minibatch data)
//1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc
// 1d vector is assumed to be rows from channels 0 concatenated, followed by channels 1, etc
for (int ex = 0; ex < miniBatch; ex++) {
for (int r = 0; r < rows; r++) {
for (int c = 0; c < cols; c++) {
for (int depth = 0; depth < d; depth++) {
int vectorPosition = depth * (rows * cols) + r * cols + c; //pos in vector after reshape
// pos in vector after reshape
int vectorPosition = depth * (rows * cols) + r * cols + c;
double vecValue = ffAct_c.getDouble(ex, vectorPosition);
double convValue = convInput_c.getDouble(ex, depth, r, c);
assertEquals(convValue, vecValue, 0.0);
@ -213,9 +210,8 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
}
}
//Test backward pass:
//Idea is that backward pass should do opposite to forward pass
// Test backward pass:
// Idea is that backward pass should do opposite to forward pass
INDArray epsilon2_c = Nd4j.create(DataType.FLOAT, ffActShape, 'c');
INDArray epsilon2_f = Nd4j.create(DataType.FLOAT, ffActShape, 'f');
epsilon2_c.assign(ffAct_c);
@ -231,79 +227,32 @@ public class CNNProcessorTest extends BaseDL4JTest {
}
@Test
public void testInvalidInputShape(){
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
.seed(123)
.miniBatch(true)
.cacheMode(CacheMode.DEVICE)
.updater(new Nesterovs(0.9))
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
int[] kernelArray = new int[]{3,3};
int[] strideArray = new int[]{1,1};
int[] zeroPaddingArray = new int[]{0,0};
@DisplayName("Test Invalid Input Shape")
void testInvalidInputShape() {
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123).miniBatch(true).cacheMode(CacheMode.DEVICE).updater(new Nesterovs(0.9)).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
int[] kernelArray = new int[] { 3, 3 };
int[] strideArray = new int[] { 1, 1 };
int[] zeroPaddingArray = new int[] { 0, 0 };
int processWidth = 4;
NeuralNetConfiguration.ListBuilder listBuilder = builder.list(); // Building the DL4J network
listBuilder = listBuilder.layer(0, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray)
.name("cnn1")
.convolutionMode(ConvolutionMode.Strict)
.nIn(2) // 2 input channels
.nOut(processWidth)
.weightInit(WeightInit.XAVIER_UNIFORM)
.activation(Activation.RELU)
.biasInit(1e-2).build());
listBuilder = listBuilder.layer(1, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray)
.name("cnn2")
.convolutionMode(ConvolutionMode.Strict)
.nOut(processWidth)
.weightInit(WeightInit.XAVIER_UNIFORM)
.activation(Activation.RELU)
.biasInit(1e-2)
.build());
listBuilder = listBuilder.layer(2, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray)
.name("cnn3")
.convolutionMode(ConvolutionMode.Strict)
.nOut(processWidth)
.weightInit(WeightInit.XAVIER_UNIFORM)
.activation(Activation.RELU).build());
listBuilder = listBuilder.layer(3, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray)
.name("cnn4")
.convolutionMode(ConvolutionMode.Strict)
.nOut(processWidth)
.weightInit(WeightInit.XAVIER_UNIFORM)
.activation(Activation.RELU).build());
listBuilder = listBuilder
.layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.name("output")
.nOut(1)
.activation(Activation.TANH)
.build());
MultiLayerConfiguration conf = listBuilder
.setInputType(InputType.convolutional(20, 10, 2))
.build();
// Building the DL4J network
NeuralNetConfiguration.ListBuilder listBuilder = builder.list();
listBuilder = listBuilder.layer(0, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn1").convolutionMode(ConvolutionMode.Strict).nIn(2).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).biasInit(1e-2).build()); // nIn = 2 input channels
listBuilder = listBuilder.layer(1, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn2").convolutionMode(ConvolutionMode.Strict).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).biasInit(1e-2).build());
listBuilder = listBuilder.layer(2, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn3").convolutionMode(ConvolutionMode.Strict).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).build());
listBuilder = listBuilder.layer(3, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray).name("cnn4").convolutionMode(ConvolutionMode.Strict).nOut(processWidth).weightInit(WeightInit.XAVIER_UNIFORM).activation(Activation.RELU).build());
listBuilder = listBuilder.layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).name("output").nOut(1).activation(Activation.TANH).build());
MultiLayerConfiguration conf = listBuilder.setInputType(InputType.convolutional(20, 10, 2)).build();
// For some reason, this model works
MultiLayerNetwork niceModel = new MultiLayerNetwork(conf);
niceModel.init();
niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 20, 10)); //Valid
// Valid
niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 20, 10));
try {
niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 10, 20));
fail("Expected exception");
} catch (IllegalStateException e){
//OK
} catch (IllegalStateException e) {
// OK
}
}
}
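testInvalidInputShape above still uses the JUnit 4-era try/catch-and-fail idiom. Since this file already imports org.junit.jupiter.api.Assertions.* with a wildcard, the same check could be a one-liner; a sketch, assuming the initialized niceModel from the test:

// Equivalent JUnit 5 formulation of the try/catch block above.
assertThrows(IllegalStateException.class, () -> niceModel.output(Nd4j.create(DataType.FLOAT, 1, 2, 10, 20)));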
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.preprocessor;
import org.deeplearning4j.BaseDL4JTest;
@ -27,44 +26,33 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.custom.MyCustomPreprocessor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass;
import org.nd4j.shade.jackson.databind.jsontype.NamedType;
import java.util.Collection;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class CustomPreprocessorTest extends BaseDL4JTest {
@DisplayName("Custom Preprocessor Test")
class CustomPreprocessorTest extends BaseDL4JTest {
@Test
public void testCustomPreprocessor() {
//Second: let's create a MultiLayerConfiguration with one, and check JSON and YAML config actually works...
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().list()
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10)
.activation(Activation.SOFTMAX).nOut(10).build())
.inputPreProcessor(0, new MyCustomPreprocessor())
.build();
@DisplayName("Test Custom Preprocessor")
void testCustomPreprocessor() {
// Second: let's create a MultiLayerConfiguration with one, and check JSON and YAML config actually works...
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10).activation(Activation.SOFTMAX).nOut(10).build()).inputPreProcessor(0, new MyCustomPreprocessor()).build();
String json = conf.toJson();
String yaml = conf.toYaml();
// System.out.println(json);
// System.out.println(json);
MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, confFromJson);
MultiLayerConfiguration confFromYaml = MultiLayerConfiguration.fromYaml(yaml);
assertEquals(conf, confFromYaml);
assertTrue(confFromJson.getInputPreProcess(0) instanceof MyCustomPreprocessor);
}
}
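The final instanceof assertion above is the load-bearing one for a custom component: it confirms the preprocessor survives polymorphic JSON deserialization as its concrete type rather than merely comparing equal. The round trip, reduced to its essential steps (a sketch over the conf built in the test, using this file's imports):

String json = conf.toJson();
MultiLayerConfiguration restored = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, restored); // value equality after the round trip
assertTrue(restored.getInputPreProcess(0) instanceof MyCustomPreprocessor); // concrete type preserved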
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -35,7 +34,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationRationalTanh;
@ -46,31 +45,27 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.List;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
*/
public class ActivationLayerTest extends BaseDL4JTest {
@DisplayName("Activation Layer Test")
class ActivationLayerTest extends BaseDL4JTest {
@Override
public DataType getDataType(){
public DataType getDataType() {
return DataType.FLOAT;
}
@Test
public void testInputTypes() {
org.deeplearning4j.nn.conf.layers.ActivationLayer l =
new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU)
.build();
@DisplayName("Test Input Types")
void testInputTypes() {
org.deeplearning4j.nn.conf.layers.ActivationLayer l = new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU).build();
InputType in1 = InputType.feedForward(20);
InputType in2 = InputType.convolutional(28, 28, 1);
assertEquals(in1, l.getOutputType(0, in1));
assertEquals(in2, l.getOutputType(0, in2));
assertNull(l.getPreProcessorForInputType(in1));
@ -78,252 +73,132 @@ public class ActivationLayerTest extends BaseDL4JTest {
}
@Test
public void testDenseActivationLayer() throws Exception {
@DisplayName("Test Dense Activation Layer")
void testDenseActivationLayer() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(2, 2);
DataSet next = iter.next();
// Run without separate activation layer
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX).nIn(10).nOut(10).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
network.fit(next);
// Run with separate activation layer
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.IDENTITY)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder()
.activation(Activation.RELU).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10)
.build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
MultiLayerNetwork network2 = new MultiLayerNetwork(conf2);
network2.init();
network2.fit(next);
// check parameters
assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W"));
assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W"));
assertEquals(network.getLayer(0).getParam("b"), network2.getLayer(0).getParam("b"));
assertEquals(network.getLayer(1).getParam("b"), network2.getLayer(2).getParam("b"));
// check activations
network.init();
network.setInput(next.getFeatures());
List<INDArray> activations = network.feedForward(true);
network2.init();
network2.setInput(next.getFeatures());
List<INDArray> activations2 = network2.feedForward(true);
assertEquals(activations.get(1).reshape(activations2.get(2).shape()), activations2.get(2));
assertEquals(activations.get(2), activations2.get(3));
}
@Test
public void testAutoEncoderActivationLayer() throws Exception {
@DisplayName("Test Auto Encoder Activation Layer")
void testAutoEncoderActivationLayer() throws Exception {
int minibatch = 3;
int nIn = 5;
int layerSize = 5;
int nOut = 3;
INDArray next = Nd4j.rand(new int[] {minibatch, nIn});
INDArray next = Nd4j.rand(new int[] { minibatch, nIn });
INDArray labels = Nd4j.zeros(minibatch, nOut);
for (int i = 0; i < minibatch; i++) {
labels.putScalar(i, i % nOut, 1.0);
}
// Run without separate activation layer
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.list()
.layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0)
.activation(Activation.SIGMOID).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0).activation(Activation.SIGMOID).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build()).build();
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
network.fit(next, labels); //Labels are necessary for this test: the layer activation function affects pretraining results otherwise
// Labels are necessary for this test: the layer activation function affects pretraining results otherwise
network.fit(next, labels);
// Run with separate activation layer
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.list()
.layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0)
.activation(Activation.IDENTITY).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder()
.activation(Activation.SIGMOID).build())
.layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0).activation(Activation.IDENTITY).build()).layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.SIGMOID).build()).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build()).build();
MultiLayerNetwork network2 = new MultiLayerNetwork(conf2);
network2.init();
network2.fit(next, labels);
// check parameters
assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W"));
assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W"));
assertEquals(network.getLayer(0).getParam("b"), network2.getLayer(0).getParam("b"));
assertEquals(network.getLayer(1).getParam("b"), network2.getLayer(2).getParam("b"));
// check activations
network.init();
network.setInput(next);
List<INDArray> activations = network.feedForward(true);
network2.init();
network2.setInput(next);
List<INDArray> activations2 = network2.feedForward(true);
assertEquals(activations.get(1).reshape(activations2.get(2).shape()), activations2.get(2));
assertEquals(activations.get(2), activations2.get(3));
}
@Test
public void testCNNActivationLayer() throws Exception {
@DisplayName("Test CNN Activation Layer")
void testCNNActivationLayer() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(2, 2);
DataSet next = iter.next();
// Run without separate activation layer
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.list()
.layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20)
.activation(Activation.RELU).weightInit(WeightInit.XAVIER).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX).nOut(10).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
network.fit(next);
// Run with separate activation layer
MultiLayerConfiguration conf2 =
new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.seed(123).list()
.layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20)
.activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER)
.build())
.layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder()
.activation(Activation.RELU).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
.nOut(10).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder().activation(Activation.RELU).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
MultiLayerNetwork network2 = new MultiLayerNetwork(conf2);
network2.init();
network2.fit(next);
// check parameters
assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W"));
assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W"));
assertEquals(network.getLayer(0).getParam("b"), network2.getLayer(0).getParam("b"));
// check activations
network.init();
network.setInput(next.getFeatures());
List<INDArray> activations = network.feedForward(true);
network2.init();
network2.setInput(next.getFeatures());
List<INDArray> activations2 = network2.feedForward(true);
assertEquals(activations.get(1).reshape(activations2.get(2).shape()), activations2.get(2));
assertEquals(activations.get(2), activations2.get(3));
}
@Test
public void testActivationInheritance() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.weightInit(WeightInit.XAVIER)
.activation(Activation.RATIONALTANH)
.list()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
.layer(new ActivationLayer())
.layer(new ActivationLayer.Builder().build())
.layer(new ActivationLayer.Builder().activation(Activation.ELU).build())
.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(10).nOut(10).build())
.build();
@DisplayName("Test Activation Inheritance")
void testActivationInheritance() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).weightInit(WeightInit.XAVIER).activation(Activation.RATIONALTANH).list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(new ActivationLayer()).layer(new ActivationLayer.Builder().build()).layer(new ActivationLayer.Builder().activation(Activation.ELU).build()).layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
assertNotNull(((ActivationLayer)network.getLayer(1).conf().getLayer()).getActivationFn());
assertTrue(((DenseLayer)network.getLayer(0).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer)network.getLayer(1).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer)network.getLayer(2).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer)network.getLayer(3).conf().getLayer()).getActivationFn() instanceof ActivationELU);
assertTrue(((OutputLayer)network.getLayer(4).conf().getLayer()).getActivationFn() instanceof ActivationSoftmax);
assertNotNull(((ActivationLayer) network.getLayer(1).conf().getLayer()).getActivationFn());
assertTrue(((DenseLayer) network.getLayer(0).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer) network.getLayer(1).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer) network.getLayer(2).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer) network.getLayer(3).conf().getLayer()).getActivationFn() instanceof ActivationELU);
assertTrue(((OutputLayer) network.getLayer(4).conf().getLayer()).getActivationFn() instanceof ActivationSoftmax);
}
@Test
public void testActivationInheritanceCG() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.weightInit(WeightInit.XAVIER)
.activation(Activation.RATIONALTANH)
.graphBuilder()
.addInputs("in")
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
.addLayer("1", new ActivationLayer(), "0")
.addLayer("2", new ActivationLayer.Builder().build(), "1")
.addLayer("3", new ActivationLayer.Builder().activation(Activation.ELU).build(), "2")
.addLayer("4", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "3")
.setOutputs("4")
.build();
@DisplayName("Test Activation Inheritance CG")
void testActivationInheritanceCG() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).weightInit(WeightInit.XAVIER).activation(Activation.RATIONALTANH).graphBuilder().addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1", new ActivationLayer(), "0").addLayer("2", new ActivationLayer.Builder().build(), "1").addLayer("3", new ActivationLayer.Builder().activation(Activation.ELU).build(), "2").addLayer("4", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "3").setOutputs("4").build();
ComputationGraph network = new ComputationGraph(conf);
network.init();
assertNotNull(((ActivationLayer)network.getLayer("1").conf().getLayer()).getActivationFn());
assertTrue(((DenseLayer)network.getLayer("0").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer)network.getLayer("1").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer)network.getLayer("2").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer)network.getLayer("3").conf().getLayer()).getActivationFn() instanceof ActivationELU);
assertTrue(((OutputLayer)network.getLayer("4").conf().getLayer()).getActivationFn() instanceof ActivationSoftmax);
assertNotNull(((ActivationLayer) network.getLayer("1").conf().getLayer()).getActivationFn());
assertTrue(((DenseLayer) network.getLayer("0").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer) network.getLayer("1").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer) network.getLayer("2").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh);
assertTrue(((ActivationLayer) network.getLayer("3").conf().getLayer()).getActivationFn() instanceof ActivationELU);
assertTrue(((OutputLayer) network.getLayer("4").conf().getLayer()).getActivationFn() instanceof ActivationSoftmax);
}
}
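All three activation-layer comparisons in this class use one equivalence technique: the first network fuses the activation into its layer, the second uses Activation.IDENTITY followed by an explicit ActivationLayer, and the test asserts that weights and feed-forward activations agree. The index arithmetic is the subtle part: the extra ActivationLayer shifts the second network's later layers by one, which is why layer 1 of network is compared with layer 2 of network2. The comparison step as a sketch, assuming the two fitted networks from the tests above:

// Layer 0 is identical in both nets; the explicit ActivationLayer in
// network2 shifts subsequent layer indices by one (1 -> 2).
assertEquals(network.getLayer(0).getParam("W"), network2.getLayer(0).getParam("W"));
assertEquals(network.getLayer(1).getParam("W"), network2.getLayer(2).getParam("W"));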
View File
@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -31,49 +30,30 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class AutoEncoderTest extends BaseDL4JTest {
@DisplayName("Auto Encoder Test")
class AutoEncoderTest extends BaseDL4JTest {
@Test
public void sanityCheckIssue5662(){
@DisplayName("Sanity Check Issue 5662")
void sanityCheckIssue5662() {
int mergeSize = 50;
int encdecSize = 25;
int in1Size = 20;
int in2Size = 15;
int hiddenSize = 10;
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.XAVIER)
.graphBuilder()
.addInputs("in1", "in2")
.addLayer("1", new DenseLayer.Builder().nOut(mergeSize).build(), "in1")
.addLayer("2", new DenseLayer.Builder().nOut(mergeSize).build(), "in2")
.addVertex("merge", new MergeVertex(), "1", "2")
.addLayer("e",new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(),"merge")
.addLayer("hidden",new AutoEncoder.Builder().nOut(hiddenSize).build(),"e")
.addLayer("decoder",new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(),"hidden")
.addLayer("L4", new DenseLayer.Builder().nOut(mergeSize).build(), "decoder")
.addLayer("out1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in1Size).build(),"L4")
.addLayer("out2",new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in2Size).build(),"L4")
.setOutputs("out1","out2")
.setInputTypes(InputType.feedForward(in1Size), InputType.feedForward(in2Size))
.build();
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in1", "in2").addLayer("1", new DenseLayer.Builder().nOut(mergeSize).build(), "in1").addLayer("2", new DenseLayer.Builder().nOut(mergeSize).build(), "in2").addVertex("merge", new MergeVertex(), "1", "2").addLayer("e", new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(), "merge").addLayer("hidden", new AutoEncoder.Builder().nOut(hiddenSize).build(), "e").addLayer("decoder", new AutoEncoder.Builder().nOut(encdecSize).corruptionLevel(0.2).build(), "hidden").addLayer("L4", new DenseLayer.Builder().nOut(mergeSize).build(), "decoder").addLayer("out1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in1Size).build(), "L4").addLayer("out2", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(in2Size).build(), "L4").setOutputs("out1", "out2").setInputTypes(InputType.feedForward(in1Size), InputType.feedForward(in2Size)).build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(
new INDArray[]{Nd4j.create(1, in1Size), Nd4j.create(1, in2Size)},
new INDArray[]{Nd4j.create(1, in1Size), Nd4j.create(1, in2Size)});
MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { Nd4j.create(1, in1Size), Nd4j.create(1, in2Size) }, new INDArray[] { Nd4j.create(1, in1Size), Nd4j.create(1, in2Size) });
net.summary(InputType.feedForward(in1Size), InputType.feedForward(in2Size));
net.fit(new SingletonMultiDataSetIterator(mds));
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import lombok.val;
@ -29,46 +28,47 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
@DisplayName("Base Layer Test")
class BaseLayerTest extends BaseDL4JTest {
public class BaseLayerTest extends BaseDL4JTest {
protected INDArray weight = Nd4j.create(new double[] { 0.10, -0.20, -0.15, 0.05 }, new int[] { 2, 2 });
protected INDArray bias = Nd4j.create(new double[] { 0.5, 0.5 }, new int[] { 1, 2 });
protected INDArray weight = Nd4j.create(new double[] {0.10, -0.20, -0.15, 0.05}, new int[] {2, 2});
protected INDArray bias = Nd4j.create(new double[] {0.5, 0.5}, new int[] {1, 2});
protected Map<String, INDArray> paramTable;
@Before
public void doBefore() {
@BeforeEach
void doBefore() {
paramTable = new HashMap<>();
paramTable.put("W", weight);
paramTable.put("b", bias);
}
@Test
public void testSetExistingParamsConvolutionSingleLayer() {
@DisplayName("Test Set Existing Params Convolution Single Layer")
void testSetExistingParamsConvolutionSingleLayer() {
Layer layer = configureSingleLayer();
assertNotEquals(paramTable, layer.paramTable());
layer.setParamTable(paramTable);
assertEquals(paramTable, layer.paramTable());
}
@Test
public void testSetExistingParamsDenseMultiLayer() {
@DisplayName("Test Set Existing Params Dense Multi Layer")
void testSetExistingParamsDenseMultiLayer() {
MultiLayerNetwork net = configureMultiLayer();
for (Layer layer : net.getLayers()) {
assertNotEquals(paramTable, layer.paramTable());
layer.setParamTable(paramTable);
@ -76,31 +76,21 @@ public class BaseLayerTest extends BaseDL4JTest {
}
}
public Layer configureSingleLayer() {
int nIn = 2;
int nOut = 2;
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(new ConvolutionLayer.Builder().nIn(nIn).nOut(nOut).build()).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new ConvolutionLayer.Builder().nIn(nIn).nOut(nOut).build()).build();
val numParams = conf.getLayer().initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams);
return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType());
}
public MultiLayerNetwork configureMultiLayer() {
int nIn = 2;
int nOut = 2;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
.layer(0, new DenseLayer.Builder().nIn(nIn).nOut(nOut).build())
.layer(1, new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()).layer(1, new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
return net;
}
}
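The @Before → @BeforeEach change in BaseLayerTest is one instance of the JUnit 4 → JUnit 5 lifecycle renames. The full mapping, as a hypothetical sketch (class and method names are illustrative):

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

class LifecycleSketch {

    @BeforeAll              // was @BeforeClass in JUnit 4; must be static
    static void initOnce() { }

    @BeforeEach             // was @Before
    void init() { }

    @Test
    void runs() { }

    @AfterEach              // was @After
    void tearDown() { }

    @AfterAll               // was @AfterClass; must be static
    static void tearDownOnce() { }
}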

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -28,77 +27,58 @@ import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
public class CacheModeTest extends BaseDL4JTest {
@DisplayName("Cache Mode Test")
class CacheModeTest extends BaseDL4JTest {
@Test
public void testConvCacheModeSimple(){
@DisplayName("Test Conv Cache Mode Simple")
void testConvCacheModeSimple() {
MultiLayerConfiguration conf1 = getConf(CacheMode.NONE);
MultiLayerConfiguration conf2 = getConf(CacheMode.DEVICE);
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
INDArray in = Nd4j.rand(3, 28*28);
INDArray in = Nd4j.rand(3, 28 * 28);
INDArray labels = TestUtils.randomOneHot(3, 10);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
assertEquals(net1.params(), net2.params());
net1.fit(in, labels);
net2.fit(in, labels);
assertEquals(net1.params(), net2.params());
}
private static MultiLayerConfiguration getConf(CacheMode cacheMode){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.inferenceWorkspaceMode(WorkspaceMode.ENABLED)
.trainingWorkspaceMode(WorkspaceMode.ENABLED)
.seed(12345)
.cacheMode(cacheMode)
.list()
.layer(new ConvolutionLayer.Builder().nOut(3).build())
.layer(new ConvolutionLayer.Builder().nOut(3).build())
.layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.build();
private static MultiLayerConfiguration getConf(CacheMode cacheMode) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).inferenceWorkspaceMode(WorkspaceMode.ENABLED).trainingWorkspaceMode(WorkspaceMode.ENABLED).seed(12345).cacheMode(cacheMode).list().layer(new ConvolutionLayer.Builder().nOut(3).build()).layer(new ConvolutionLayer.Builder().nOut(3).build()).layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
return conf;
}
@Test
public void testLSTMCacheModeSimple(){
for(boolean graves : new boolean[]{true, false}) {
@DisplayName("Test LSTM Cache Mode Simple")
void testLSTMCacheModeSimple() {
for (boolean graves : new boolean[] { true, false }) {
MultiLayerConfiguration conf1 = getConfLSTM(CacheMode.NONE, graves);
MultiLayerConfiguration conf2 = getConfLSTM(CacheMode.DEVICE, graves);
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
INDArray in = Nd4j.rand(new int[]{3, 3, 10});
INDArray in = Nd4j.rand(new int[] { 3, 3, 10 });
INDArray labels = TestUtils.randomOneHotTimeSeries(3, 10, 10);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
assertEquals(net1.params(), net2.params());
net1.fit(in, labels);
net2.fit(in, labels);
@ -106,68 +86,33 @@ public class CacheModeTest extends BaseDL4JTest {
}
}
private static MultiLayerConfiguration getConfLSTM(CacheMode cacheMode, boolean graves){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.inferenceWorkspaceMode(WorkspaceMode.ENABLED)
.trainingWorkspaceMode(WorkspaceMode.ENABLED)
.seed(12345)
.cacheMode(cacheMode)
.list()
.layer(graves ?
new GravesLSTM.Builder().nIn(3).nOut(3).build() :
new LSTM.Builder().nIn(3).nOut(3).build())
.layer(graves ?
new GravesLSTM.Builder().nIn(3).nOut(3).build() :
new LSTM.Builder().nIn(3).nOut(3).build())
.layer(new RnnOutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build())
.build();
private static MultiLayerConfiguration getConfLSTM(CacheMode cacheMode, boolean graves) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).inferenceWorkspaceMode(WorkspaceMode.ENABLED).trainingWorkspaceMode(WorkspaceMode.ENABLED).seed(12345).cacheMode(cacheMode).list().layer(graves ? new GravesLSTM.Builder().nIn(3).nOut(3).build() : new LSTM.Builder().nIn(3).nOut(3).build()).layer(graves ? new GravesLSTM.Builder().nIn(3).nOut(3).build() : new LSTM.Builder().nIn(3).nOut(3).build()).layer(new RnnOutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()).build();
return conf;
}
@Test
public void testConvCacheModeSimpleCG(){
@DisplayName("Test Conv Cache Mode Simple CG")
void testConvCacheModeSimpleCG() {
ComputationGraphConfiguration conf1 = getConfCG(CacheMode.NONE);
ComputationGraphConfiguration conf2 = getConfCG(CacheMode.DEVICE);
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
INDArray in = Nd4j.rand(3, 28*28);
INDArray in = Nd4j.rand(3, 28 * 28);
INDArray labels = TestUtils.randomOneHot(3, 10);
INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in);
assertEquals(out1, out2);
assertEquals(net1.params(), net2.params());
net1.fit(new DataSet(in, labels));
net2.fit(new DataSet(in, labels));
assertEquals(net1.params(), net2.params());
}
private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.inferenceWorkspaceMode(WorkspaceMode.ENABLED)
.trainingWorkspaceMode(WorkspaceMode.ENABLED)
.seed(12345)
.cacheMode(cacheMode)
.graphBuilder()
.addInputs("in")
.layer("0", new ConvolutionLayer.Builder().nOut(3).build(), "in")
.layer("1", new ConvolutionLayer.Builder().nOut(3).build(), "0")
.layer("2", new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build(), "1")
.setOutputs("2")
.setInputTypes(InputType.convolutionalFlat(28, 28, 1))
.build();
private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode) {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).inferenceWorkspaceMode(WorkspaceMode.ENABLED).trainingWorkspaceMode(WorkspaceMode.ENABLED).seed(12345).cacheMode(cacheMode).graphBuilder().addInputs("in").layer("0", new ConvolutionLayer.Builder().nOut(3).build(), "in").layer("1", new ConvolutionLayer.Builder().nOut(3).build(), "0").layer("2", new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build(), "1").setOutputs("2").setInputTypes(InputType.convolutionalFlat(28, 28, 1)).build();
return conf;
}
}
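CacheModeTest also switches its static imports from org.junit.Assert to org.junit.jupiter.api.Assertions. The overloads line up, except that the optional failure message moves from the first parameter to the last. A hedged sketch (the method and values are illustrative):

import static org.junit.jupiter.api.Assertions.assertEquals;

class AssertionMessageOrderSketch {

    void compareScores(double score1, double score2) {
        // JUnit 4: assertEquals("cache modes should not change the score", score1, score2, 1e-6);
        // JUnit 5: the message (or a Supplier<String>) comes last.
        assertEquals(score1, score2, 1e-6, "cache modes should not change the score");
    }
}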

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -34,8 +33,8 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -44,73 +43,40 @@ import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertNotEquals;
public class CenterLossOutputLayerTest extends BaseDL4JTest {
@DisplayName("Center Loss Output Layer Test")
class CenterLossOutputLayerTest extends BaseDL4JTest {
private ComputationGraph getGraph(int numLabels, double lambda) {
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new NormalDistribution(0, 1)).updater(new NoOp())
.graphBuilder().addInputs("input1")
.addLayer("l1", new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.RELU).build(),
"input1")
.addLayer("lossLayer", new CenterLossOutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(numLabels)
.lambda(lambda).activation(Activation.SOFTMAX).build(), "l1")
.setOutputs("lossLayer").build();
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).dist(new NormalDistribution(0, 1)).updater(new NoOp()).graphBuilder().addInputs("input1").addLayer("l1", new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.RELU).build(), "input1").addLayer("lossLayer", new CenterLossOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(numLabels).lambda(lambda).activation(Activation.SOFTMAX).build(), "l1").setOutputs("lossLayer").build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
return graph;
}
public ComputationGraph getCNNMnistConfig() {
int nChannels = 1; // Number of input channels
int outputNum = 10; // The number of possible outcomes
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) // Training iterations as above
.l2(0.0005).weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(0.01, 0.9))
.graphBuilder().addInputs("input")
.setInputTypes(InputType.convolutionalFlat(28, 28, 1))
.addLayer("0", new ConvolutionLayer.Builder(5, 5)
//nIn and nOut specify channels. nIn here is the nChannels and nOut is the number of filters to be applied
.nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build(),
"input")
.addLayer("1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build(), "0")
.addLayer("2", new ConvolutionLayer.Builder(5, 5)
//Note that nIn need not be specified in later layers
.stride(1, 1).nOut(50).activation(Activation.IDENTITY).build(), "1")
.addLayer("3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build(), "2")
.addLayer("4", new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build(), "3")
.addLayer("output",
new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder(
LossFunction.MCXENT).nOut(outputNum)
.activation(Activation.SOFTMAX).build(),
"4")
.setOutputs("output").build();
// Number of input channels
int nChannels = 1;
// The number of possible outcomes
int outputNum = 10;
ComputationGraphConfiguration conf = // Training iterations as above
new NeuralNetConfiguration.Builder().seed(12345).l2(0.0005).weightInit(WeightInit.XAVIER).updater(new Nesterovs(0.01, 0.9)).graphBuilder().addInputs("input").setInputTypes(InputType.convolutionalFlat(28, 28, 1)).addLayer("0", new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build(), "input").addLayer("1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build(), "0").addLayer("2", new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build(), "1").addLayer("3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build(), "2").addLayer("4", new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build(), "3").addLayer("output", new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder(LossFunction.MCXENT).nOut(outputNum).activation(Activation.SOFTMAX).build(), "4").setOutputs("output").build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
return graph;
}
@Test
public void testLambdaConf() {
double[] lambdas = new double[] {0.1, 0.01};
@DisplayName("Test Lambda Conf")
void testLambdaConf() {
double[] lambdas = new double[] { 0.1, 0.01 };
double[] results = new double[2];
int numClasses = 2;
INDArray input = Nd4j.rand(150, 4);
INDArray labels = Nd4j.zeros(150, numClasses);
Random r = new Random(12345);
@ -118,7 +84,6 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {
labels.putScalar(i, r.nextInt(numClasses), 1.0);
}
ComputationGraph graph;
for (int i = 0; i < lambdas.length; i++) {
graph = getGraph(numClasses, lambdas[i]);
graph.setInput(0, input);
@ -126,27 +91,23 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest {
graph.computeGradientAndScore();
results[i] = graph.score();
}
assertNotEquals(results[0], results[1]);
}
@Test
@Ignore //Should be run manually
public void testMNISTConfig() throws Exception {
int batchSize = 64; // Test batch size
@Disabled
@DisplayName("Test MNIST Config")
void testMNISTConfig() throws Exception {
// Test batch size
int batchSize = 64;
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
ComputationGraph net = getCNNMnistConfig();
net.init();
net.setListeners(new ScoreIterationListener(1));
for (int i = 0; i < 50; i++) {
net.fit(mnistTrain.next());
Thread.sleep(1000);
}
Thread.sleep(100000);
}
}
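The @Ignore plus "Should be run manually" comment above becomes a bare @Disabled with the reason demoted to a comment. @Disabled can also carry the reason as its value, where test reports will surface it. A hypothetical sketch:

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

class DisabledReasonSketch {

    @Test
    @Disabled("Should be run manually") // was @Ignore in JUnit 4; the reason string is optional
    @DisplayName("Long Running Manual Test")
    void longRunningManualTest() {
        // not executed while @Disabled is present
    }
}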

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -36,7 +35,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -44,30 +43,30 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
*/
public class DropoutLayerTest extends BaseDL4JTest {
@DisplayName("Dropout Layer Test")
class DropoutLayerTest extends BaseDL4JTest {
@Override
public DataType getDataType(){
public DataType getDataType() {
return DataType.FLOAT;
}
@Test
public void testInputTypes() {
@DisplayName("Test Input Types")
void testInputTypes() {
DropoutLayer config = new DropoutLayer.Builder(0.5).build();
InputType in1 = InputType.feedForward(20);
InputType in2 = InputType.convolutional(28, 28, 1);
assertEquals(in1, config.getOutputType(0, in1));
assertEquals(in2, config.getOutputType(0, in2));
assertNull(config.getPreProcessorForInputType(in1));
@ -75,58 +74,30 @@ public class DropoutLayerTest extends BaseDL4JTest {
}
@Test
public void testDropoutLayerWithoutTraining() throws Exception {
MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(3648)
.list().layer(0,
new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1).dropOut(0.25)
.activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER).dropOut(0.25)
.nOut(4).build())
.setInputType(InputType.convolutionalFlat(2, 2, 1)).build();
@DisplayName("Test Dropout Layer Without Training")
void testDropoutLayerWithoutTraining() throws Exception {
MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(3648).list().layer(0, new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1).dropOut(0.25).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).dropOut(0.25).nOut(4).build()).setInputType(InputType.convolutionalFlat(2, 2, 1)).build();
MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated);
netIntegrated.init();
netIntegrated.getLayer(0).setParam("W", Nd4j.eye(1));
netIntegrated.getLayer(0).setParam("b", Nd4j.zeros(1, 1));
netIntegrated.getLayer(1).setParam("W", Nd4j.eye(4));
netIntegrated.getLayer(1).setParam("b", Nd4j.zeros(4, 1));
MultiLayerConfiguration confSeparate =
new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.seed(3648)
.list().layer(0,
new DropoutLayer.Builder(0.25)
.build())
.layer(1, new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1)
.activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER)
.build())
.layer(2, new DropoutLayer.Builder(0.25).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
.nOut(4).build())
.setInputType(InputType.convolutionalFlat(2, 2, 1)).build();
MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(3648).list().layer(0, new DropoutLayer.Builder(0.25).build()).layer(1, new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1).activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER).build()).layer(2, new DropoutLayer.Builder(0.25).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(4).build()).setInputType(InputType.convolutionalFlat(2, 2, 1)).build();
MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
netSeparate.init();
netSeparate.getLayer(1).setParam("W", Nd4j.eye(1));
netSeparate.getLayer(1).setParam("b", Nd4j.zeros(1, 1));
netSeparate.getLayer(3).setParam("W", Nd4j.eye(4));
netSeparate.getLayer(3).setParam("b", Nd4j.zeros(4, 1));
//Disable input modification for this test:
for(Layer l : netIntegrated.getLayers()){
// Disable input modification for this test:
for (Layer l : netIntegrated.getLayers()) {
l.allowInputModification(false);
}
for(Layer l : netSeparate.getLayers()){
for (Layer l : netSeparate.getLayers()) {
l.allowInputModification(false);
}
INDArray in = Nd4j.arange(1, 5).reshape(1,4);
INDArray in = Nd4j.arange(1, 5).reshape(1, 4);
Nd4j.getRandom().setSeed(12345);
List<INDArray> actTrainIntegrated = netIntegrated.feedForward(in.dup(), true);
Nd4j.getRandom().setSeed(12345);
@ -135,15 +106,10 @@ public class DropoutLayerTest extends BaseDL4JTest {
List<INDArray> actTestIntegrated = netIntegrated.feedForward(in.dup(), false);
Nd4j.getRandom().setSeed(12345);
List<INDArray> actTestSeparate = netSeparate.feedForward(in.dup(), false);
//Check masks:
INDArray maskIntegrated = ((Dropout)netIntegrated.getLayer(0).conf().getLayer().getIDropout()).getMask();
INDArray maskSeparate = ((Dropout)netSeparate.getLayer(0).conf().getLayer().getIDropout()).getMask();
// Check masks:
INDArray maskIntegrated = ((Dropout) netIntegrated.getLayer(0).conf().getLayer().getIDropout()).getMask();
INDArray maskSeparate = ((Dropout) netSeparate.getLayer(0).conf().getLayer().getIDropout()).getMask();
assertEquals(maskIntegrated, maskSeparate);
assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(2));
assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(4));
assertEquals(actTestIntegrated.get(1), actTestSeparate.get(2));
@ -151,68 +117,41 @@ public class DropoutLayerTest extends BaseDL4JTest {
}
@Test
public void testDropoutLayerWithDenseMnist() throws Exception {
@DisplayName("Test Dropout Layer With Dense Mnist")
void testDropoutLayerWithDenseMnist() throws Exception {
DataSetIterator iter = new MnistDataSetIterator(2, 2);
DataSet next = iter.next();
// Run without separate activation layer
MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10)
.activation(Activation.RELU).weightInit(
WeightInit.XAVIER)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.25)
.nIn(10).nOut(10).build())
.build();
MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.25).nIn(10).nOut(10).build()).build();
MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated);
netIntegrated.init();
netIntegrated.fit(next);
// Run with separate activation layer
MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
.list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new DropoutLayer.Builder(0.25).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10)
.build())
.build();
MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123).list().layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU).weightInit(WeightInit.XAVIER).build()).layer(1, new DropoutLayer.Builder(0.25).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
netSeparate.init();
netSeparate.fit(next);
//Disable input modification for this test:
for(Layer l : netIntegrated.getLayers()){
// Disable input modification for this test:
for (Layer l : netIntegrated.getLayers()) {
l.allowInputModification(false);
}
for(Layer l : netSeparate.getLayers()){
for (Layer l : netSeparate.getLayers()) {
l.allowInputModification(false);
}
// check parameters
assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W"));
assertEquals(netIntegrated.getLayer(0).getParam("b"), netSeparate.getLayer(0).getParam("b"));
assertEquals(netIntegrated.getLayer(1).getParam("W"), netSeparate.getLayer(2).getParam("W"));
assertEquals(netIntegrated.getLayer(1).getParam("b"), netSeparate.getLayer(2).getParam("b"));
// check activations
netIntegrated.setInput(next.getFeatures());
netSeparate.setInput(next.getFeatures());
Nd4j.getRandom().setSeed(12345);
List<INDArray> actTrainIntegrated = netIntegrated.feedForward(true);
Nd4j.getRandom().setSeed(12345);
List<INDArray> actTrainSeparate = netSeparate.feedForward(true);
assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1));
assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3));
Nd4j.getRandom().setSeed(12345);
List<INDArray> actTestIntegrated = netIntegrated.feedForward(false);
Nd4j.getRandom().setSeed(12345);
@ -222,77 +161,49 @@ public class DropoutLayerTest extends BaseDL4JTest {
}
@Test
public void testDropoutLayerWithConvMnist() throws Exception {
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); //Set to double datatype - MKL-DNN is then not used for CPU (otherwise strides differ due to DL4J impl permutes)
@DisplayName("Test Dropout Layer With Conv Mnist")
void testDropoutLayerWithConvMnist() throws Exception {
// Set to double datatype - MKL-DNN is then not used for CPU (otherwise strides differ due to DL4J impl permutes)
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
DataSetIterator iter = new MnistDataSetIterator(2, 2);
DataSet next = iter.next();
// Run without separate activation layer
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(123)
.list().layer(0,
new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20)
.activation(Activation.TANH).weightInit(WeightInit.XAVIER)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.5)
.nOut(10).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.5).nOut(10).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
// Run with separate activation layer
Nd4j.getRandom().setSeed(12345);
//Manually configure preprocessors
//This is necessary, otherwise CnnToFeedForwardPreprocessor will be in different locations
//i.e., dropout on 4d activations in latter, and dropout on 2d activations in former
// Manually configure preprocessors
// This is necessary, otherwise CnnToFeedForwardPreprocessor will be in different locations
// i.e., dropout on 4d activations in latter, and dropout on 2d activations in former
Map<Integer, InputPreProcessor> preProcessorMap = new HashMap<>();
preProcessorMap.put(1, new CnnToFeedForwardPreProcessor(13, 13, 20));
MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().seed(123).list()
.layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20)
.activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())
.layer(1, new DropoutLayer.Builder(0.5).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build())
.inputPreProcessors(preProcessorMap)
.setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().seed(123).list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new DropoutLayer.Builder(0.5).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()).inputPreProcessors(preProcessorMap).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated);
netIntegrated.init();
Nd4j.getRandom().setSeed(12345);
MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
netSeparate.init();
assertEquals(netIntegrated.params(), netSeparate.params());
Nd4j.getRandom().setSeed(12345);
netIntegrated.fit(next);
Nd4j.getRandom().setSeed(12345);
netSeparate.fit(next);
assertEquals(netIntegrated.params(), netSeparate.params());
// check parameters
assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W"));
assertEquals(netIntegrated.getLayer(0).getParam("b"), netSeparate.getLayer(0).getParam("b"));
assertEquals(netIntegrated.getLayer(1).getParam("W"), netSeparate.getLayer(2).getParam("W"));
assertEquals(netIntegrated.getLayer(1).getParam("b"), netSeparate.getLayer(2).getParam("b"));
// check activations
netIntegrated.setInput(next.getFeatures().dup());
netSeparate.setInput(next.getFeatures().dup());
Nd4j.getRandom().setSeed(12345);
List<INDArray> actTrainIntegrated = netIntegrated.feedForward(true);
Nd4j.getRandom().setSeed(12345);
List<INDArray> actTrainSeparate = netSeparate.feedForward(true);
assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1));
assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3));
netIntegrated.setInput(next.getFeatures().dup());
netSeparate.setInput(next.getFeatures().dup());
Nd4j.getRandom().setSeed(12345);
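The @DisplayName annotations added throughout these files affect reporting only: JUnit 5 shows the annotation value instead of the class or method name in IDE and build output, with no effect on execution. A minimal sketch (names are illustrative):

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("Dropout Layer Test")      // reported instead of the class name
class DisplayNameSketch {

    @Test
    @DisplayName("Test Input Types")    // reported instead of the method name
    void testInputTypes() {
        // display names do not change how the test runs
    }
}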

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import lombok.extern.slf4j.Slf4j;
@ -31,116 +30,69 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
public class FrozenLayerTest extends BaseDL4JTest {
@DisplayName("Frozen Layer Test")
class FrozenLayerTest extends BaseDL4JTest {
/*
A model with a few frozen layers ==
a model of the non-frozen layers fed the output of the forward pass of the frozen layers
*/
@Test
public void testFrozen() {
@DisplayName("Test Frozen")
void testFrozen() {
DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1))
.activation(Activation.IDENTITY);
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY);
FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build();
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
.layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build())
.layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build())
.build());
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build());
modelToFineTune.init();
List<INDArray> ff = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false);
INDArray asFrozenFeatures = ff.get(2);
MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(finetune)
.setFeatureExtractor(1).build();
INDArray paramsLastTwoLayers =
Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params());
MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list()
.layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build())
.build(), paramsLastTwoLayers);
// assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf()); //Equal, other than names
// assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf()); //Equal, other than names
//Check: forward pass
MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(finetune).setFeatureExtractor(1).build();
INDArray paramsLastTwoLayers = Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params());
MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build(), paramsLastTwoLayers);
// assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf()); //Equal, other than names
// assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf()); //Equal, other than names
// Check: forward pass
INDArray outNow = modelNow.output(randomData.getFeatures());
INDArray outNotFrozen = notFrozen.output(asFrozenFeatures);
assertEquals(outNow, outNotFrozen);
for (int i = 0; i < 5; i++) {
notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
modelNow.fit(randomData);
}
INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(),
notFrozen.params());
INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), notFrozen.params());
INDArray act = modelNow.params();
assertEquals(expected, act);
}
@Test
public void cloneMLNFrozen() {
@DisplayName("Clone MLN Frozen")
void cloneMLNFrozen() {
DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1))
.activation(Activation.IDENTITY);
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
.layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build())
.layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build())
.build());
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY);
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build());
modelToFineTune.init();
INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2);
MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).setFeatureExtractor(1).build();
MultiLayerNetwork clonedModel = modelNow.clone();
//Check json
// Check json
assertEquals(modelNow.getLayerWiseConfigurations().toJson(), clonedModel.getLayerWiseConfigurations().toJson());
//Check params
// Check params
assertEquals(modelNow.params(), clonedModel.params());
MultiLayerNetwork notFrozen = new MultiLayerNetwork(
overallConf.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build())
.build(),
Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()));
MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()).build(), Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()));
int i = 0;
while (i < 5) {
notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
@ -148,112 +100,49 @@ public class FrozenLayerTest extends BaseDL4JTest {
clonedModel.fit(randomData);
i++;
}
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(),
modelToFineTune.getLayer(1).params(), notFrozen.params());
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), notFrozen.params());
assertEquals(expectedParams, modelNow.params());
assertEquals(expectedParams, clonedModel.params());
}
@Test
public void testFrozenCompGraph() {
@DisplayName("Test Frozen Comp Graph")
void testFrozenCompGraph() {
DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1))
.activation(Activation.IDENTITY);
ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In")
.addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0")
.addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1")
.addLayer("layer3",
new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build(),
"layer2")
.setOutputs("layer3").build());
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY);
ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2").setOutputs("layer3").build());
modelToFineTune.init();
INDArray asFrozenFeatures = modelToFineTune.feedForward(randomData.getFeatures(), false).get("layer1");
ComputationGraph modelNow =
new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build();
ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In")
.addLayer("layer1",
new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build(),
"layer0")
.setOutputs("layer1").build());
ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build();
ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In").addLayer("layer1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0").setOutputs("layer1").build());
notFrozen.init();
notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(),
modelToFineTune.getLayer("layer3").params()));
notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), modelToFineTune.getLayer("layer3").params()));
int i = 0;
while (i < 5) {
notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
modelNow.fit(randomData);
i++;
}
assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(),
modelToFineTune.getLayer("layer1").params(), notFrozen.params()), modelNow.params());
assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), modelToFineTune.getLayer("layer1").params(), notFrozen.params()), modelNow.params());
}
@Test
public void cloneCompGraphFrozen() {
@DisplayName("Clone Comp Graph Frozen")
void cloneCompGraphFrozen() {
DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1))
.activation(Activation.IDENTITY);
ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In")
.addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0")
.addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1")
.addLayer("layer3",
new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build(),
"layer2")
.setOutputs("layer3").build());
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).activation(Activation.IDENTITY);
ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In").addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0").addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1").addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2").setOutputs("layer3").build());
modelToFineTune.init();
INDArray asFrozenFeatures = modelToFineTune.feedForward(randomData.getFeatures(), false).get("layer1");
ComputationGraph modelNow =
new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build();
ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune).setFeatureExtractor("layer1").build();
ComputationGraph clonedModel = modelNow.clone();
//Check json
// Check json
assertEquals(clonedModel.getConfiguration().toJson(), modelNow.getConfiguration().toJson());
//Check params
// Check params
assertEquals(modelNow.params(), clonedModel.params());
ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
.addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In")
.addLayer("layer1",
new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build(),
"layer0")
.setOutputs("layer1").build());
ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In").addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In").addLayer("layer1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0").setOutputs("layer1").build());
notFrozen.init();
notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(),
modelToFineTune.getLayer("layer3").params()));
notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), modelToFineTune.getLayer("layer3").params()));
int i = 0;
while (i < 5) {
notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
@ -261,117 +150,54 @@ public class FrozenLayerTest extends BaseDL4JTest {
clonedModel.fit(randomData);
i++;
}
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(),
modelToFineTune.getLayer("layer1").params(), notFrozen.params());
INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), modelToFineTune.getLayer("layer1").params(), notFrozen.params());
assertEquals(expectedParams, modelNow.params());
assertEquals(expectedParams, clonedModel.params());
}
@Test
public void testFrozenLayerInstantiation() {
//We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
@DisplayName("Test Frozen Layer Instantiation")
void testFrozenLayerInstantiation() {
// We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
// they were initialized via the builder
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list()
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10)
.activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()))
.layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(
new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build()))
.layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build())
.build();
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
assertEquals(net1.params(), net2.params());
String json = conf2.toJson();
MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
assertEquals(conf2, fromJson);
MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson);
net3.init();
INDArray input = Nd4j.rand(10, 10);
INDArray out2 = net2.output(input);
INDArray out3 = net3.output(input);
assertEquals(out2, out3);
}
@Test
public void testFrozenLayerInstantiationCompGraph() {
//We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
@DisplayName("Test Frozen Layer Instantiation Comp Graph")
void testFrozenLayerInstantiationCompGraph() {
// We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
// they were initialized via the builder
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
.addInputs("in")
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build(), "in")
.addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build(), "0")
.addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build(),
"1")
.setOutputs("2").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
.addInputs("in")
.addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.build(), "in")
.addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder()
.layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.build(), "0")
.addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build(),
"1")
.setOutputs("2").build();
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "in").addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "0").addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder().layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).build(), "in").addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder().layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).build(), "0").addLayer("2", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build();
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
assertEquals(net1.params(), net2.params());
String json = conf2.toJson();
ComputationGraphConfiguration fromJson = ComputationGraphConfiguration.fromJson(json);
assertEquals(conf2, fromJson);
ComputationGraph net3 = new ComputationGraph(fromJson);
net3.init();
INDArray input = Nd4j.rand(10, 10);
INDArray out2 = net2.outputSingle(input);
INDArray out3 = net3.outputSingle(input);
assertEquals(out2, out3);
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import lombok.extern.slf4j.Slf4j;
@ -34,363 +33,194 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
@DisplayName("Frozen Layer With Backprop Test")
class FrozenLayerWithBackpropTest extends BaseDL4JTest {
@Test
public void testFrozenWithBackpropLayerInstantiation() {
//We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if
@DisplayName("Test Frozen With Backprop Layer Instantiation")
void testFrozenWithBackpropLayerInstantiation() {
// We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
// they were initialized via the builder
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list()
.layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build())
.layer(2, new OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10)
.activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()))
.layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build()))
.layer(2, new OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build())
.build();
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
assertEquals(net1.params(), net2.params());
String json = conf2.toJson();
MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
assertEquals(conf2, fromJson);
MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson);
net3.init();
INDArray input = Nd4j.rand(10, 10);
INDArray out2 = net2.output(input);
INDArray out3 = net3.output(input);
assertEquals(out2, out3);
}
@Test
public void testFrozenLayerInstantiationCompGraph() {
//We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if
@DisplayName("Test Frozen Layer Instantiation Comp Graph")
void testFrozenLayerInstantiationCompGraph() {
// We need to be able to instantiate frozen layers from JSON etc, and have them be the same as if
// they were initialized via the builder
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
.addInputs("in")
.addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build(), "in")
.addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build(), "0")
.addLayer("2", new OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build(),
"1")
.setOutputs("2").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
.addInputs("in")
.addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build()), "in")
.addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH)
.weightInit(WeightInit.XAVIER).build()), "0")
.addLayer("2", new OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
.nOut(10).build(),
"1")
.setOutputs("2").build();
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "in").addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build(), "0").addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder().addInputs("in").addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()), "in").addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()), "0").addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1").setOutputs("2").build();
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
ComputationGraph net2 = new ComputationGraph(conf2);
net2.init();
assertEquals(net1.params(), net2.params());
String json = conf2.toJson();
ComputationGraphConfiguration fromJson = ComputationGraphConfiguration.fromJson(json);
assertEquals(conf2, fromJson);
ComputationGraph net3 = new ComputationGraph(fromJson);
net3.init();
INDArray input = Nd4j.rand(10, 10);
INDArray out2 = net2.outputSingle(input);
INDArray out3 = net3.outputSingle(input);
assertEquals(out2, out3);
}
@Test
public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() {
@DisplayName("Test Multi Layer Network Frozen Layer Params After Backprop")
void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() {
Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(2))
.list()
.layer(new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(3).nOut(4).build()))
.layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(4).nOut(2).build()))
.layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
.build();
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Sgd(2)).list().layer(new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build())).layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build())).layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())).build();
MultiLayerNetwork network = new MultiLayerNetwork(conf1);
network.init();
INDArray unfrozenLayerParams = network.getLayer(0).params().dup();
INDArray frozenLayerParams1 = network.getLayer(1).params().dup();
INDArray frozenLayerParams2 = network.getLayer(2).params().dup();
INDArray frozenOutputLayerParams = network.getLayer(3).params().dup();
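// params() returns a view backed by the network's parameter storage, hence the dup() snapshots
// above; comparing against these copies after fitting shows which layers training actually updated.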
for (int i = 0; i < 100; i++) {
network.fit(randomData);
}
assertNotEquals(unfrozenLayerParams, network.getLayer(0).params());
assertEquals(frozenLayerParams1, network.getLayer(1).params());
assertEquals(frozenLayerParams2, network.getLayer(2).params());
assertEquals(frozenOutputLayerParams, network.getLayer(3).params());
}
@Test
public void testComputationGraphFrozenLayerParamsAfterBackprop() {
@DisplayName("Test Computation Graph Frozen Layer Params After Backprop")
void testComputationGraphFrozenLayerParamsAfterBackprop() {
Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
String frozenBranchName = "B1-";
String unfrozenBranchName = "B2-";
String initialLayer = "initial";
String frozenBranchUnfrozenLayer0 = frozenBranchName + "0";
String frozenBranchFrozenLayer1 = frozenBranchName + "1";
String frozenBranchFrozenLayer2 = frozenBranchName + "2";
String frozenBranchOutput = frozenBranchName + "Output";
String unfrozenLayer0 = unfrozenBranchName + "0";
String unfrozenLayer1 = unfrozenBranchName + "1";
String unfrozenBranch2 = unfrozenBranchName + "Output";
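// Topology under test: input -> initial, which feeds two branches:
// B1: B1-0 (trainable) -> B1-1 (frozen) -> B1-2 (frozen), and B2: B2-0 -> B2-1 -> B2-Output;
// both branches meet at "merge", which feeds the frozen output layer B1-Output.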
ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder()
.updater(new Sgd(2.0))
.seed(12345)
.graphBuilder()
.addInputs("input")
.addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
.addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer)
.addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
.addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
.addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
.addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
.addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
.addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
.setOutputs(frozenBranchOutput)
.build();
ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder().updater(new Sgd(2.0)).seed(12345).graphBuilder().addInputs("input").addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input").addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer).addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()), frozenBranchUnfrozenLayer0).addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()), frozenBranchFrozenLayer1).addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer).addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0).addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1).addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2).addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()), "merge").setOutputs(frozenBranchOutput).build();
ComputationGraph computationGraph = new ComputationGraph(computationGraphConf);
computationGraph.init();
INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).params().dup();
for (int i = 0; i < 100; i++) {
computationGraph.fit(randomData);
}
assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).params());
assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).params());
assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).params());
}
/**
* A frozen layer should produce the same results as a layer whose Sgd updater has its learning rate set to 0.
*/
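// A plain SGD step is w <- w - lr * dL/dw, so lr = 0.0 leaves every parameter untouched;
// the frozen network should therefore track the Sgd(0.0) network exactly, update for update.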
@Test
public void testFrozenLayerVsSgd() {
@DisplayName("Test Frozen Layer Vs Sgd")
void testFrozenLayerVsSgd() {
Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(2))
.list()
.layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build())
.layer(2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build())
.layer(3,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build())
.build();
MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder()
.seed(12345)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(2))
.list()
.layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()))
.layer(2,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()))
.layer(3,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
.build();
MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Sgd(2)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build()).layer(2, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build()).layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build()).build();
MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Sgd(2)).list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()).layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build())).layer(2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build())).layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build())).build();
MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
frozenNetwork.init();
INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).params().dup();
INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).params().dup();
INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).params().dup();
INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).params().dup();
MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd);
sgdNetwork.init();
INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).params().dup();
INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).params().dup();
INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).params().dup();
INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).params().dup();
for (int i = 0; i < 100; i++) {
frozenNetwork.fit(randomData);
}
for (int i = 0; i < 100; i++) {
sgdNetwork.fit(randomData);
}
assertEquals(frozenNetwork.getLayer(0).params(), sgdNetwork.getLayer(0).params());
assertEquals(frozenNetwork.getLayer(1).params(), sgdNetwork.getLayer(1).params());
assertEquals(frozenNetwork.getLayer(2).params(), sgdNetwork.getLayer(2).params());
assertEquals(frozenNetwork.getLayer(3).params(), sgdNetwork.getLayer(3).params());
}
@Test
public void testComputationGraphVsSgd() {
@DisplayName("Test Computation Graph Vs Sgd")
void testComputationGraphVsSgd() {
Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
String frozenBranchName = "B1-";
String unfrozenBranchName = "B2-";
String initialLayer = "initial";
String frozenBranchUnfrozenLayer0 = frozenBranchName + "0";
String frozenBranchFrozenLayer1 = frozenBranchName + "1";
String frozenBranchFrozenLayer2 = frozenBranchName + "2";
String frozenBranchOutput = frozenBranchName + "Output";
String unfrozenLayer0 = unfrozenBranchName + "0";
String unfrozenLayer1 = unfrozenBranchName + "1";
String unfrozenBranch2 = unfrozenBranchName + "Output";
ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder()
.updater(new Sgd(2.0))
.seed(12345)
.graphBuilder()
.addInputs("input")
.addLayer(initialLayer,new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
.addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
.addLayer(frozenBranchFrozenLayer1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
.addLayer(frozenBranchFrozenLayer2,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
.addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
.addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
.addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
.addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
.setOutputs(frozenBranchOutput)
.build();
ComputationGraphConfiguration computationSgdGraphConf = new NeuralNetConfiguration.Builder()
.updater(new Sgd(2.0))
.seed(12345)
.graphBuilder()
.addInputs("input")
.addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
.addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer)
.addLayer(frozenBranchFrozenLayer1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(),frozenBranchUnfrozenLayer0)
.addLayer(frozenBranchFrozenLayer2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(),frozenBranchFrozenLayer1)
.addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
.addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
.addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
.addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(),"merge")
.setOutputs(frozenBranchOutput)
.build();
ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder().updater(new Sgd(2.0)).seed(12345).graphBuilder().addInputs("input").addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input").addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer).addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()), frozenBranchUnfrozenLayer0).addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()), frozenBranchFrozenLayer1).addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer).addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0).addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1).addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2).addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()), "merge").setOutputs(frozenBranchOutput).build();
ComputationGraphConfiguration computationSgdGraphConf = new NeuralNetConfiguration.Builder().updater(new Sgd(2.0)).seed(12345).graphBuilder().addInputs("input").addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(), "input").addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer).addLayer(frozenBranchFrozenLayer1, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(), frozenBranchUnfrozenLayer0).addLayer(frozenBranchFrozenLayer2, new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(), frozenBranchFrozenLayer1).addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(), initialLayer).addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(), unfrozenLayer0).addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(), unfrozenLayer1).addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2).addLayer(frozenBranchOutput, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(), "merge").setOutputs(frozenBranchOutput).build();
ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf);
frozenComputationGraph.init();
INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).params().dup();
ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf);
sgdComputationGraph.init();
INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup();
INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup();
INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup();
INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).params().dup();
for (int i = 0; i < 100; i++) {
frozenComputationGraph.fit(randomData);
}
for (int i = 0; i < 100; i++) {
sgdComputationGraph.fit(randomData);
}
assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params());
assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params());
assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params());
assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).params(), sgdComputationGraph.getLayer(frozenBranchOutput).params());
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import lombok.extern.slf4j.Slf4j;
@ -36,7 +35,7 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -46,123 +45,88 @@ import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.util.Collections;
import java.util.Random;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
public class OutputLayerTest extends BaseDL4JTest {
@DisplayName("Output Layer Test")
class OutputLayerTest extends BaseDL4JTest {
@Test
public void testSetParams() {
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
.updater(new Sgd(1e-1))
.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3)
.weightInit(WeightInit.ZERO).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
@DisplayName("Test Set Params")
void testSetParams() {
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).updater(new Sgd(1e-1)).layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.ZERO).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()).build();
long numParams = conf.getLayer().initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams);
OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf,
Collections.<TrainingListener>singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType());
OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf, Collections.<TrainingListener>singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType());
params = l.params();
l.setParams(params);
assertEquals(params, l.params());
}
@Test
public void testOutputLayersRnnForwardPass() {
//Test output layer with RNNs (
//Expect all outputs etc. to be 2d
@DisplayName("Test Output Layers Rnn Forward Pass")
void testOutputLayersRnnForwardPass() {
// Test output layer with RNNs.
// Expect all outputs etc. to be 2d
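// RnnToFeedForwardPreProcessor reshapes the LSTM's [miniBatch, layerSize, timeSeriesLength]
// activations to [miniBatch * timeSeriesLength, layerSize], so the OutputLayer and everything
// downstream of it sees 2d data.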
int nIn = 2;
int nOut = 5;
int layerSize = 4;
int timeSeriesLength = 6;
int miniBatchSize = 3;
Random r = new Random(12345L);
INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < nIn; j++) {
for (int k = 0; k < timeSeriesLength; k++) {
input.putScalar(new int[] {i, j, k}, r.nextDouble() - 0.5);
input.putScalar(new int[] { i, j, k }, r.nextDouble() - 0.5);
}
}
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list()
.layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.dist(new NormalDistribution(0, 1)).activation(Activation.TANH)
.updater(new NoOp()).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.dist(new NormalDistribution(0, 1))
.updater(new NoOp()).build())
.inputPreProcessor(1, new RnnToFeedForwardPreProcessor()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).inputPreProcessor(1, new RnnToFeedForwardPreProcessor()).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
INDArray out2d = mln.feedForward(input).get(2);
assertArrayEquals(out2d.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut});
assertArrayEquals(out2d.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
INDArray out = mln.output(input);
assertArrayEquals(out.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut});
assertArrayEquals(out.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
INDArray preout = mln.output(input);
assertArrayEquals(preout.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut});
//As above, but for RnnOutputLayer. Expect all activations etc. to be 3d
MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list()
.layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.dist(new NormalDistribution(0, 1)).activation(Activation.TANH)
.updater(new NoOp()).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.dist(new NormalDistribution(0, 1))
.updater(new NoOp()).build())
.build();
assertArrayEquals(preout.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
// As above, but for RnnOutputLayer. Expect all activations etc. to be 3d
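// RnnOutputLayer consumes the time series directly and emits [miniBatch, nOut, timeSeriesLength],
// so no preprocessor is needed here.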
MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).build();
MultiLayerNetwork mlnRnn = new MultiLayerNetwork(confRnn);
mlnRnn.init();
INDArray out3d = mlnRnn.feedForward(input).get(2);
assertArrayEquals(out3d.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});
assertArrayEquals(out3d.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
INDArray outRnn = mlnRnn.output(input);
assertArrayEquals(outRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});
assertArrayEquals(outRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
INDArray preoutRnn = mlnRnn.output(input);
assertArrayEquals(preoutRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});
assertArrayEquals(preoutRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
}
@Test
public void testRnnOutputLayerIncEdgeCases() {
//Basic test + test edge cases: timeSeriesLength==1, miniBatchSize==1, both
int[] tsLength = {5, 1, 5, 1};
int[] miniBatch = {7, 7, 1, 1};
@DisplayName("Test Rnn Output Layer Inc Edge Cases")
void testRnnOutputLayerIncEdgeCases() {
// Basic test + test edge cases: timeSeriesLength==1, miniBatchSize==1, both
int[] tsLength = { 5, 1, 5, 1 };
int[] miniBatch = { 7, 7, 1, 1 };
int nIn = 3;
int nOut = 6;
int layerSize = 4;
FeedForwardToRnnPreProcessor proc = new FeedForwardToRnnPreProcessor();
for (int t = 0; t < tsLength.length; t++) {
Nd4j.getRandom().setSeed(12345);
int timeSeriesLength = tsLength[t];
int miniBatchSize = miniBatch[t];
Random r = new Random(12345L);
INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < nIn; j++) {
for (int k = 0; k < timeSeriesLength; k++) {
input.putScalar(new int[] {i, j, k}, r.nextDouble() - 0.5);
input.putScalar(new int[] { i, j, k }, r.nextDouble() - 0.5);
}
}
}
@ -170,406 +134,200 @@ public class OutputLayerTest extends BaseDL4JTest {
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int idx = r.nextInt(nOut);
labels3d.putScalar(new int[] {i, idx, j}, 1.0f);
labels3d.putScalar(new int[] { i, idx, j }, 1.0f);
}
}
INDArray labels2d = proc.backprop(labels3d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces());
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list()
.layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.dist(new NormalDistribution(0, 1))
.activation(Activation.TANH).updater(new NoOp()).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.dist(new NormalDistribution(0, 1))
.updater(new NoOp()).build())
.inputPreProcessor(1, new RnnToFeedForwardPreProcessor())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).inputPreProcessor(1, new RnnToFeedForwardPreProcessor()).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
INDArray out2d = mln.feedForward(input).get(2);
INDArray out3d = proc.preProcess(out2d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces());
MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list()
.layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.dist(new NormalDistribution(0, 1))
.activation(Activation.TANH).updater(new NoOp()).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut)
.dist(new NormalDistribution(0, 1))
.updater(new NoOp()).build())
.build();
MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dist(new NormalDistribution(0, 1)).activation(Activation.TANH).updater(new NoOp()).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).dist(new NormalDistribution(0, 1)).updater(new NoOp()).build()).build();
MultiLayerNetwork mlnRnn = new MultiLayerNetwork(confRnn);
mlnRnn.init();
INDArray outRnn = mlnRnn.feedForward(input).get(2);
mln.setLabels(labels2d);
mlnRnn.setLabels(labels3d);
mln.computeGradientAndScore();
mlnRnn.computeGradientAndScore();
//score is average over all examples.
//However: OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping)
//RnnOutputLayer has miniBatch examples
//Hence: expect difference in scores by factor of timeSeriesLength
// score is average over all examples.
// However: OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping)
// RnnOutputLayer has miniBatch examples
// Hence: expect difference in scores by factor of timeSeriesLength
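// Worked example: with miniBatchSize = 7 and timeSeriesLength = 5 the reshaped OutputLayer
// averages the loss over 7 * 5 = 35 rows while RnnOutputLayer averages over 7 examples,
// so mln.score() * timeSeriesLength should match mlnRnn.score() up to numerical error.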
double score = mln.score() * timeSeriesLength;
double scoreRNN = mlnRnn.score();
assertTrue(!Double.isNaN(score));
assertTrue(!Double.isNaN(scoreRNN));
double relError = Math.abs(score - scoreRNN) / (Math.abs(score) + Math.abs(scoreRNN));
System.out.println(relError);
assertTrue(relError < 1e-6);
//Check labels and inputs for output layer:
// Check labels and inputs for output layer:
OutputLayer ol = (OutputLayer) mln.getOutputLayer();
assertArrayEquals(ol.getInput().shape(), new long[] {miniBatchSize * timeSeriesLength, layerSize});
assertArrayEquals(ol.getLabels().shape(), new long[] {miniBatchSize * timeSeriesLength, nOut});
assertArrayEquals(ol.getInput().shape(), new long[] { miniBatchSize * timeSeriesLength, layerSize });
assertArrayEquals(ol.getLabels().shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
RnnOutputLayer rnnol = (RnnOutputLayer) mlnRnn.getOutputLayer();
//assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength});
//Input may be set by BaseLayer methods. Thus input may end up as reshaped 2d version instead of original 3d version.
//Not ideal, but everything else works.
assertArrayEquals(rnnol.getLabels().shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});
//Check shapes of output for both:
assertArrayEquals(out2d.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut});
// assertArrayEquals(rnnol.getInput().shape(),new int[]{miniBatchSize,layerSize,timeSeriesLength});
// Input may be set by BaseLayer methods, so it may end up as the reshaped 2d version instead of the original 3d version.
// Not ideal, but everything else works.
assertArrayEquals(rnnol.getLabels().shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
// Check shapes of output for both:
assertArrayEquals(out2d.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
INDArray out = mln.output(input);
assertArrayEquals(out.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut});
assertArrayEquals(out.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
INDArray preout = mln.output(input);
assertArrayEquals(preout.shape(), new long[] {miniBatchSize * timeSeriesLength, nOut});
assertArrayEquals(preout.shape(), new long[] { miniBatchSize * timeSeriesLength, nOut });
INDArray outFFRnn = mlnRnn.feedForward(input).get(2);
assertArrayEquals(outFFRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});
assertArrayEquals(outFFRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
INDArray outRnn2 = mlnRnn.output(input);
assertArrayEquals(outRnn2.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});
assertArrayEquals(outRnn2.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
INDArray preoutRnn = mlnRnn.output(input);
assertArrayEquals(preoutRnn.shape(), new long[] {miniBatchSize, nOut, timeSeriesLength});
assertArrayEquals(preoutRnn.shape(), new long[] { miniBatchSize, nOut, timeSeriesLength });
}
}
@Test
public void testCompareRnnOutputRnnLoss(){
@DisplayName("Test Compare Rnn Output Rnn Loss")
void testCompareRnnOutputRnnLoss() {
Nd4j.getRandom().setSeed(12345);
int timeSeriesLength = 4;
int nIn = 5;
int layerSize = 6;
int nOut = 6;
int miniBatchSize = 3;
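// Premise of this test: a Dense layer with IDENTITY activation followed by RnnLossLayer(softmax)
// computes the same per-time-step function as a single RnnOutputLayer(softmax) with the same
// W and b, which is also why mln2.setParams(mln.params()) below is shape-compatible.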
MultiLayerConfiguration conf1 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.list()
.layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new DenseLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.IDENTITY).build())
.layer(new RnnLossLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.build())
.build();
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new DenseLayer.Builder().nIn(layerSize).nOut(nOut).activation(Activation.IDENTITY).build()).layer(new RnnLossLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).build()).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf1);
mln.init();
MultiLayerConfiguration conf2 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.list()
.layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nIn(layerSize).nOut(nOut)
.build())
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).list().layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build()).build();
MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2);
mln2.init();
mln2.setParams(mln.params());
INDArray in = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength});
INDArray in = Nd4j.rand(new int[] { miniBatchSize, nIn, timeSeriesLength });
INDArray out1 = mln.output(in);
INDArray out2 = mln2.output(in);
assertEquals(out1, out2);
Random r = new Random(12345);
INDArray labels = Nd4j.create(miniBatchSize, nOut, timeSeriesLength);
for( int i=0; i<miniBatchSize; i++ ){
for( int j=0; j<timeSeriesLength; j++ ){
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
labels.putScalar(i, r.nextInt(nOut), j, 1.0);
}
}
mln.setInput(in);
mln.setLabels(labels);
mln2.setInput(in);
mln2.setLabels(labels);
mln.computeGradientAndScore();
mln2.computeGradientAndScore();
assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
assertEquals(mln.score(), mln2.score(), 1e-6);
TestUtils.testModelSerialization(mln);
}
@Test
public void testCnnLossLayer(){
for(WorkspaceMode ws : WorkspaceMode.values()) {
@DisplayName("Test Cnn Loss Layer")
void testCnnLossLayer() {
for (WorkspaceMode ws : WorkspaceMode.values()) {
log.info("*** Testing workspace: " + ws);
for (Activation a : new Activation[]{Activation.TANH, Activation.SELU}) {
//Check that (A+identity) is equal to (identity+A), for activation A
//i.e., should get same output and weight gradients for both
MultiLayerConfiguration conf1 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.inferenceWorkspaceMode(ws)
.trainingWorkspaceMode(ws)
.list()
.layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY)
.kernelSize(2, 2).stride(1, 1)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new CnnLossLayer.Builder(LossFunction.MSE)
.activation(a)
.build())
.build();
MultiLayerConfiguration conf2 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.inferenceWorkspaceMode(ws)
.trainingWorkspaceMode(ws)
.list()
.layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(a)
.kernelSize(2, 2).stride(1, 1)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new CnnLossLayer.Builder(LossFunction.MSE)
.activation(Activation.IDENTITY)
.build())
.build();
for (Activation a : new Activation[] { Activation.TANH, Activation.SELU }) {
// Check that (A+identity) is equal to (identity+A), for activation A
// i.e., should get same output and weight gradients for both
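// Both stacks compute activation(conv(x)) before the MSE loss; only where the nonlinearity is
// attached (conv layer vs. loss layer) differs, so once the nets share parameters their outputs,
// scores and weight gradients must agree.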
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).inferenceWorkspaceMode(ws).trainingWorkspaceMode(ws).list().layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY).kernelSize(2, 2).stride(1, 1).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new CnnLossLayer.Builder(LossFunction.MSE).activation(a).build()).build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).inferenceWorkspaceMode(ws).trainingWorkspaceMode(ws).list().layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(a).kernelSize(2, 2).stride(1, 1).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new CnnLossLayer.Builder(LossFunction.MSE).activation(Activation.IDENTITY).build()).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf1);
mln.init();
MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2);
mln2.init();
mln2.setParams(mln.params());
INDArray in = Nd4j.rand(new int[]{3, 3, 5, 5});
INDArray in = Nd4j.rand(new int[] { 3, 3, 5, 5 });
INDArray out1 = mln.output(in);
INDArray out2 = mln2.output(in);
assertEquals(out1, out2);
INDArray labels = Nd4j.rand(out1.shape());
mln.setInput(in);
mln.setLabels(labels);
mln2.setInput(in);
mln2.setLabels(labels);
mln.computeGradientAndScore();
mln2.computeGradientAndScore();
assertEquals(mln.score(), mln2.score(), 1e-6);
assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
//Also check computeScoreForExamples
INDArray in2a = Nd4j.rand(new int[]{1, 3, 5, 5});
INDArray labels2a = Nd4j.rand(new int[]{1, 4, 5, 5});
// Also check computeScoreForExamples
INDArray in2a = Nd4j.rand(new int[] { 1, 3, 5, 5 });
INDArray labels2a = Nd4j.rand(new int[] { 1, 4, 5, 5 });
INDArray in2 = Nd4j.concat(0, in2a, in2a);
INDArray labels2 = Nd4j.concat(0, labels2a, labels2a);
INDArray s = mln.scoreExamples(new DataSet(in2, labels2), false);
assertArrayEquals(new long[]{2, 1}, s.shape());
assertArrayEquals(new long[] { 2, 1 }, s.shape());
assertEquals(s.getDouble(0), s.getDouble(1), 1e-6);
TestUtils.testModelSerialization(mln);
}
}
}
@Test
public void testCnnLossLayerCompGraph(){
for(WorkspaceMode ws : WorkspaceMode.values()) {
@DisplayName("Test Cnn Loss Layer Comp Graph")
void testCnnLossLayerCompGraph() {
for (WorkspaceMode ws : WorkspaceMode.values()) {
log.info("*** Testing workspace: " + ws);
for (Activation a : new Activation[]{Activation.TANH, Activation.SELU}) {
//Check that (A+identity) is equal to (identity+A), for activation A
//i.e., should get same output and weight gradients for both
ComputationGraphConfiguration conf1 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.inferenceWorkspaceMode(ws)
.trainingWorkspaceMode(ws)
.graphBuilder()
.addInputs("in")
.addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY)
.kernelSize(2, 2).stride(1, 1)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build(), "in")
.addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE)
.activation(a)
.build(), "0")
.setOutputs("1")
.build();
ComputationGraphConfiguration conf2 =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.inferenceWorkspaceMode(ws)
.trainingWorkspaceMode(ws)
.graphBuilder()
.addInputs("in")
.addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(a)
.kernelSize(2, 2).stride(1, 1)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build(), "in")
.addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE)
.activation(Activation.IDENTITY)
.build(), "0")
.setOutputs("1")
.build();
for (Activation a : new Activation[] { Activation.TANH, Activation.SELU }) {
// Check that (A+identity) is equal to (identity+A), for activation A
// i.e., should get same output and weight gradients for both
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).inferenceWorkspaceMode(ws).trainingWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY).kernelSize(2, 2).stride(1, 1).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build(), "in").addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE).activation(a).build(), "0").setOutputs("1").build();
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).inferenceWorkspaceMode(ws).trainingWorkspaceMode(ws).graphBuilder().addInputs("in").addLayer("0", new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(a).kernelSize(2, 2).stride(1, 1).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build(), "in").addLayer("1", new CnnLossLayer.Builder(LossFunction.MSE).activation(Activation.IDENTITY).build(), "0").setOutputs("1").build();
ComputationGraph graph = new ComputationGraph(conf1);
graph.init();
ComputationGraph graph2 = new ComputationGraph(conf2);
graph2.init();
graph2.setParams(graph.params());
INDArray in = Nd4j.rand(new int[]{3, 3, 5, 5});
INDArray in = Nd4j.rand(new int[] { 3, 3, 5, 5 });
INDArray out1 = graph.outputSingle(in);
INDArray out2 = graph2.outputSingle(in);
assertEquals(out1, out2);
INDArray labels = Nd4j.rand(out1.shape());
graph.setInput(0,in);
graph.setInput(0, in);
graph.setLabels(labels);
graph2.setInput(0,in);
graph2.setInput(0, in);
graph2.setLabels(labels);
graph.computeGradientAndScore();
graph2.computeGradientAndScore();
assertEquals(graph.score(), graph2.score(), 1e-6);
assertEquals(graph.gradient().gradient(), graph2.gradient().gradient());
//Also check computeScoreForExamples
INDArray in2a = Nd4j.rand(new int[]{1, 3, 5, 5});
INDArray labels2a = Nd4j.rand(new int[]{1, 4, 5, 5});
// Also check computeScoreForExamples
INDArray in2a = Nd4j.rand(new int[] { 1, 3, 5, 5 });
INDArray labels2a = Nd4j.rand(new int[] { 1, 4, 5, 5 });
INDArray in2 = Nd4j.concat(0, in2a, in2a);
INDArray labels2 = Nd4j.concat(0, labels2a, labels2a);
INDArray s = graph.scoreExamples(new DataSet(in2, labels2), false);
assertArrayEquals(new long[]{2, 1}, s.shape());
assertArrayEquals(new long[] { 2, 1 }, s.shape());
assertEquals(s.getDouble(0), s.getDouble(1), 1e-6);
TestUtils.testModelSerialization(graph);
}
}
}
@Test
public void testCnnOutputLayerSoftmax(){
//Check that softmax is applied channels-wise
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().seed(12345L)
.updater(new NoOp())
.convolutionMode(ConvolutionMode.Same)
.list()
.layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new CnnLossLayer.Builder(LossFunction.MSE)
.activation(Activation.SOFTMAX)
.build())
.build();
@DisplayName("Test Cnn Output Layer Softmax")
void testCnnOutputLayerSoftmax() {
// Check that softmax is applied channels-wise
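// Shape check: input is [2, 3, 4, 5] and the conv layer has nOut = 4 with ConvolutionMode.Same,
// so the output is [2, 4, 4, 5]; channels-wise softmax means every length-4 channel vector
// sums to 1, hence out.sum(1) below should equal ones of shape [2, 4, 5].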
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).updater(new NoOp()).convolutionMode(ConvolutionMode.Same).list().layer(new ConvolutionLayer.Builder().nIn(3).nOut(4).activation(Activation.IDENTITY).dist(new NormalDistribution(0, 1.0)).updater(new NoOp()).build()).layer(new CnnLossLayer.Builder(LossFunction.MSE).activation(Activation.SOFTMAX).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(new int[]{2,3,4,5});
INDArray in = Nd4j.rand(new int[] { 2, 3, 4, 5 });
INDArray out = net.output(in);
double min = out.minNumber().doubleValue();
double max = out.maxNumber().doubleValue();
assertTrue(min >= 0 && max <= 1.0);
INDArray sum = out.sum(1);
assertEquals(Nd4j.ones(DataType.FLOAT,2,4,5), sum);
assertEquals(Nd4j.ones(DataType.FLOAT, 2, 4, 5), sum);
}
@Test
public void testOutputLayerDefaults(){
new NeuralNetConfiguration.Builder().list()
.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(10).nOut(10).build())
.build();
new NeuralNetConfiguration.Builder().list()
.layer(new org.deeplearning4j.nn.conf.layers.LossLayer.Builder().build())
.build();
new NeuralNetConfiguration.Builder().list()
.layer(new org.deeplearning4j.nn.conf.layers.CnnLossLayer.Builder().build())
.build();
new NeuralNetConfiguration.Builder().list()
.layer(new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder().build())
.build();
@DisplayName("Test Output Layer Defaults")
void testOutputLayerDefaults() {
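// Each builder below omits explicit loss and activation settings; the test passes if the
// layer defaults let the configuration build without throwing.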
new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(10).nOut(10).build()).build();
new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.LossLayer.Builder().build()).build();
new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.CnnLossLayer.Builder().build()).build();
new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder().build()).build();
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -26,47 +25,41 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class RepeatVectorTest extends BaseDL4JTest {
@DisplayName("Repeat Vector Test")
class RepeatVectorTest extends BaseDL4JTest {
private int REPEAT = 4;
private Layer getRepeatVectorLayer() {
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123)
.dataType(DataType.DOUBLE)
.layer(new RepeatVector.Builder(REPEAT).build()).build();
return conf.getLayer().instantiate(conf, null, 0,
null, false, DataType.DOUBLE);
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).dataType(DataType.DOUBLE).layer(new RepeatVector.Builder(REPEAT).build()).build();
return conf.getLayer().instantiate(conf, null, 0, null, false, DataType.DOUBLE);
}
@Test
public void testRepeatVector() {
double[] arr = new double[] {1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.};
INDArray expectedOut = Nd4j.create(arr, new long[] {1, 3, REPEAT}, 'f');
INDArray input = Nd4j.create(new double[] {1., 2., 3.}, new long[] {1, 3});
@DisplayName("Test Repeat Vector")
void testRepeatVector() {
double[] arr = new double[] { 1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3. };
INDArray expectedOut = Nd4j.create(arr, new long[] { 1, 3, REPEAT }, 'f');
INDArray input = Nd4j.create(new double[] { 1., 2., 3. }, new long[] { 1, 3 });
Layer layer = getRepeatVectorLayer();
INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
assertTrue(Arrays.equals(expectedOut.shape(), output.shape()));
assertEquals(expectedOut, output);
INDArray epsilon = Nd4j.ones(1,3,4);
INDArray epsilon = Nd4j.ones(1, 3, 4);
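// Backprop through a repeat is a sum over the repeated axis: each of the 3 input units
// receives REPEAT (= 4) copies of epsilon 1.0, giving the expected epsilon of [4., 4., 4.].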
Pair<Gradient, INDArray> out = layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
INDArray outEpsilon = out.getSecond();
INDArray expectedEpsilon = Nd4j.create(new double[] {4., 4., 4.}, new long[] {1, 3});
INDArray expectedEpsilon = Nd4j.create(new double[] { 4., 4., 4. }, new long[] { 1, 3 });
assertEquals(expectedEpsilon, outEpsilon);
}
}
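One aside on the assertion this hunk keeps unchanged, assertTrue(Arrays.equals(expectedOut.shape(), output.shape())): JUnit 5 also provides assertArrayEquals for long arrays, which reports the first differing index on failure instead of a bare "expected true". A minimal sketch, not part of this commit, using hypothetical shape values:

// Illustrative only: an equivalent, more informative JUnit 5 shape check.
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;

class ShapeAssertionSketch {

    @Test
    void shapesMatch() {
        // Hypothetical values standing in for expectedOut.shape() and output.shape().
        long[] expectedShape = { 1, 3, 4 };
        long[] actualShape = { 1, 3, 4 };
        // On mismatch this reports the differing index and values.
        assertArrayEquals(expectedShape, actualShape);
    }
}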

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers;
import org.deeplearning4j.BaseDL4JTest;
@ -25,45 +24,41 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
*/
public class SeedTest extends BaseDL4JTest {
@DisplayName("Seed Test")
class SeedTest extends BaseDL4JTest {
private DataSetIterator irisIter = new IrisDataSetIterator(50, 50);
private DataSet data = irisIter.next();
@Test
public void testAutoEncoderSeed() {
AutoEncoder layerType = new AutoEncoder.Builder().nIn(4).nOut(3).corruptionLevel(0.0)
.activation(Activation.SIGMOID).build();
NeuralNetConfiguration conf =
new NeuralNetConfiguration.Builder().layer(layerType).seed(123).build();
@DisplayName("Test Auto Encoder Seed")
void testAutoEncoderSeed() {
AutoEncoder layerType = new AutoEncoder.Builder().nIn(4).nOut(3).corruptionLevel(0.0).activation(Activation.SIGMOID).build();
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(layerType).seed(123).build();
long numParams = conf.getLayer().initializer().numParams(conf);
INDArray params = Nd4j.create(1, numParams);
Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType());
layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces());
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
double score = layer.score();
INDArray parameters = layer.params();
layer.setParams(parameters);
layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces());
double score2 = layer.score();
assertEquals(parameters, layer.params());
assertEquals(score, score2, 1e-4);
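For delta-based double comparisons like the assertEquals(score, score2, 1e-4) above, only the static import moves from org.junit.Assert to org.junit.jupiter.api.Assertions; the (expected, actual, delta) argument order is identical in both JUnit versions. A brief illustrative sketch with hypothetical names:

// Illustrative only: the delta overload is unchanged across the migration.
// JUnit 4: import static org.junit.Assert.assertEquals;
// JUnit 5:
import static org.junit.jupiter.api.Assertions.assertEquals;

class DeltaEqualitySketch {
    void check(double score, double score2) {
        // Passes when the two doubles differ by at most 1e-4.
        assertEquals(score, score2, 1e-4);
    }
}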

View File

@ -17,11 +17,9 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers.capsule;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.IOException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
@ -35,64 +33,44 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Ignore("AB - ignored due to excessive runtime. Keep for manual debugging when required")
public class CapsNetMNISTTest extends BaseDL4JTest {
@Disabled("AB - ignored due to excessive runtime. Keep for manual debugging when required")
@DisplayName("Caps Net MNIST Test")
class CapsNetMNISTTest extends BaseDL4JTest {
@Override
public DataType getDataType(){
public DataType getDataType() {
return DataType.FLOAT;
}
@Test
public void testCapsNetOnMNIST(){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.updater(new Adam())
.list()
.layer(new ConvolutionLayer.Builder()
.nOut(16)
.kernelSize(9, 9)
.stride(3, 3)
.build())
.layer(new PrimaryCapsules.Builder(8, 8)
.kernelSize(7, 7)
.stride(2, 2)
.build())
.layer(new CapsuleLayer.Builder(10, 16, 3).build())
.layer(new CapsuleStrengthLayer.Builder().build())
.layer(new ActivationLayer.Builder(new ActivationSoftmax()).build())
.layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.build();
@DisplayName("Test Caps Net On MNIST")
void testCapsNetOnMNIST() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).updater(new Adam()).list().layer(new ConvolutionLayer.Builder().nOut(16).kernelSize(9, 9).stride(3, 3).build()).layer(new PrimaryCapsules.Builder(8, 8).kernelSize(7, 7).stride(2, 2).build()).layer(new CapsuleLayer.Builder(10, 16, 3).build()).layer(new CapsuleStrengthLayer.Builder().build()).layer(new ActivationLayer.Builder(new ActivationSoftmax()).build()).layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
int rngSeed = 12345;
try {
MnistDataSetIterator mnistTrain = new MnistDataSetIterator(64, true, rngSeed);
MnistDataSetIterator mnistTest = new MnistDataSetIterator(64, false, rngSeed);
for (int i = 0; i < 2; i++) {
model.fit(mnistTrain);
}
Evaluation eval = model.evaluate(mnistTest);
assertTrue("Accuracy not over 95%", eval.accuracy() > 0.95);
assertTrue("Precision not over 95%", eval.precision() > 0.95);
assertTrue("Recall not over 95%", eval.recall() > 0.95);
assertTrue("F1-score not over 95%", eval.f1() > 0.95);
} catch (IOException e){
assertTrue(eval.accuracy() > 0.95, "Accuracy not over 95%");
assertTrue(eval.precision() > 0.95, "Precision not over 95%");
assertTrue(eval.recall() > 0.95, "Recall not over 95%");
assertTrue(eval.f1() > 0.95, "F1-score not over 95%");
} catch (IOException e) {
System.out.println("Could not load MNIST.");
}
}
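The catch-block hunk above also shows the one signature change that silently breaks migrated assertions if missed: JUnit 4 takes the failure message as the first argument, JUnit 5 takes it as the last, optionally as a Supplier<String> so the message is only built on failure. An illustrative sketch, not part of the commit:

// Illustrative only: message-argument order in the two APIs.
// JUnit 4: org.junit.Assert.assertTrue("Accuracy not over 95%", accuracy > 0.95);
// JUnit 5: message last, as below.
import static org.junit.jupiter.api.Assertions.assertTrue;

class MessageOrderSketch {
    void check(double accuracy) {
        assertTrue(accuracy > 0.95, "Accuracy not over 95%");
        // Lazy variant: the supplier runs only if the assertion fails.
        assertTrue(accuracy > 0.95, () -> "Accuracy was " + accuracy);
    }
}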

View File

@ -17,84 +17,71 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers.capsule;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.CapsuleLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class CapsuleLayerTest extends BaseDL4JTest {
@DisplayName("Capsule Layer Test")
class CapsuleLayerTest extends BaseDL4JTest {
@Override
public DataType getDataType(){
public DataType getDataType() {
return DataType.FLOAT;
}
@Test
public void testOutputType(){
@DisplayName("Test Output Type")
void testOutputType() {
CapsuleLayer layer = new CapsuleLayer.Builder(10, 16, 5).build();
InputType in1 = InputType.recurrent(5, 8);
assertEquals(InputType.recurrent(10, 16), layer.getOutputType(0, in1));
}
@Test
public void testInputType(){
@DisplayName("Test Input Type")
void testInputType() {
CapsuleLayer layer = new CapsuleLayer.Builder(10, 16, 5).build();
InputType in1 = InputType.recurrent(5, 8);
layer.setNIn(in1, true);
assertEquals(5, layer.getInputCapsules());
assertEquals(8, layer.getInputCapsuleDimensions());
}
@Test
public void testConfig(){
@DisplayName("Test Config")
void testConfig() {
CapsuleLayer layer1 = new CapsuleLayer.Builder(10, 16, 5).build();
assertEquals(10, layer1.getCapsules());
assertEquals(16, layer1.getCapsuleDimensions());
assertEquals(5, layer1.getRoutings());
assertFalse(layer1.isHasBias());
CapsuleLayer layer2 = new CapsuleLayer.Builder(10, 16, 5).hasBias(true).build();
assertTrue(layer2.isHasBias());
}
@Test
public void testLayer(){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.list()
.layer(new CapsuleLayer.Builder(10, 16, 3).build())
.setInputType(InputType.recurrent(10, 8))
.build();
@DisplayName("Test Layer")
void testLayer() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).list().layer(new CapsuleLayer.Builder(10, 16, 3).build()).setInputType(InputType.recurrent(10, 8)).build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
INDArray emptyFeatures = Nd4j.zeros(64, 10, 8);
long[] shape = model.output(emptyFeatures).shape();
assertArrayEquals(new long[]{64, 10, 16}, shape);
assertArrayEquals(new long[] { 64, 10, 16 }, shape);
}
}
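Finally, the @Ignore seen on CapsNetMNISTTest earlier in this diff becomes JUnit 5's @Disabled, which likewise accepts a reason string and can annotate a whole class or a single method. A standalone sketch with hypothetical names:

// Illustrative only: @Ignore -> @Disabled, applied at class level.
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@Disabled("Excessive runtime; keep for manual debugging")
@DisplayName("Slow Integration Test")
class SlowIntegrationTestSketch {

    @Test
    void runsOnlyWhenManuallyEnabled() {
        // Skipped by the class-level @Disabled; body intentionally trivial.
    }
}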

Some files were not shown because too many files have changed in this diff