commit
1930d99908
|
@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L;
|
return 120_000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
||||||
.dataSource(ds, dsP)
|
.dataSource(ds, dsP)
|
||||||
.modelSaver(new FileModelSaver(modelSave))
|
.modelSaver(new FileModelSaver(modelSave))
|
||||||
.scoreFunction(new TestSetLossScoreFunction())
|
.scoreFunction(new TestSetLossScoreFunction())
|
||||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(3))
|
new MaxCandidatesCondition(3))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 45000L;
|
return 120_000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
||||||
.dataSource(ds, dsP)
|
.dataSource(ds, dsP)
|
||||||
.modelSaver(new FileModelSaver(modelSave))
|
.modelSaver(new FileModelSaver(modelSave))
|
||||||
.scoreFunction(new TestSetLossScoreFunction())
|
.scoreFunction(new TestSetLossScoreFunction())
|
||||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(10))
|
new MaxCandidatesCondition(3))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));
|
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.datavec.api.records.reader.impl;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.apache.commons.io.LineIterator;
|
import org.apache.commons.io.LineIterator;
|
||||||
import org.datavec.api.conf.Configuration;
|
import org.datavec.api.conf.Configuration;
|
||||||
|
@ -43,6 +44,7 @@ import java.util.*;
|
||||||
*
|
*
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class LineRecordReader extends BaseRecordReader {
|
public class LineRecordReader extends BaseRecordReader {
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,6 +60,13 @@ public class LineRecordReader extends BaseRecordReader {
|
||||||
@Override
|
@Override
|
||||||
public void initialize(InputSplit split) throws IOException, InterruptedException {
|
public void initialize(InputSplit split) throws IOException, InterruptedException {
|
||||||
super.initialize(split);
|
super.initialize(split);
|
||||||
|
if(!(inputSplit instanceof StringSplit || inputSplit instanceof InputStreamInputSplit)){
|
||||||
|
final ArrayList<URI> uris = new ArrayList<>();
|
||||||
|
final Iterator<URI> uriIterator = inputSplit.locationsIterator();
|
||||||
|
while(uriIterator.hasNext()) uris.add(uriIterator.next());
|
||||||
|
|
||||||
|
this.locations = uris.toArray(new URI[0]);
|
||||||
|
}
|
||||||
this.iter = getIterator(0);
|
this.iter = getIterator(0);
|
||||||
this.initialized = true;
|
this.initialized = true;
|
||||||
}
|
}
|
||||||
|
@ -66,7 +75,6 @@ public class LineRecordReader extends BaseRecordReader {
|
||||||
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
|
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
initialize(split);
|
initialize(split);
|
||||||
this.initialized = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -89,7 +97,7 @@ public class LineRecordReader extends BaseRecordReader {
|
||||||
iter = getIterator(splitIndex);
|
iter = getIterator(splitIndex);
|
||||||
onLocationOpen(locations[splitIndex]);
|
onLocationOpen(locations[splitIndex]);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (iter.hasNext()) {
|
if (iter.hasNext()) {
|
||||||
|
@ -120,7 +128,7 @@ public class LineRecordReader extends BaseRecordReader {
|
||||||
iter = getIterator(splitIndex);
|
iter = getIterator(splitIndex);
|
||||||
onLocationOpen(locations[splitIndex]);
|
onLocationOpen(locations[splitIndex]);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
|
|
||||||
return iter.hasNext();
|
return iter.hasNext();
|
||||||
|
@ -205,11 +213,6 @@ public class LineRecordReader extends BaseRecordReader {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
final ArrayList<URI> uris = new ArrayList<>();
|
|
||||||
final Iterator<URI> uriIterator = inputSplit.locationsIterator();
|
|
||||||
while(uriIterator.hasNext()) uris.add(uriIterator.next());
|
|
||||||
|
|
||||||
this.locations = uris.toArray(new URI[uris.size()]);
|
|
||||||
if (locations.length > 0) {
|
if (locations.length > 0) {
|
||||||
InputStream inputStream = streamCreatorFn.apply(locations[location]);
|
InputStream inputStream = streamCreatorFn.apply(locations[location]);
|
||||||
try {
|
try {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.datavec.api.split;
|
package org.datavec.api.split;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.datavec.api.util.files.UriFromPathIterator;
|
import org.datavec.api.util.files.UriFromPathIterator;
|
||||||
import org.datavec.api.writable.WritableType;
|
import org.datavec.api.writable.WritableType;
|
||||||
|
|
||||||
|
@ -34,6 +35,7 @@ import java.util.regex.Pattern;
|
||||||
* NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent
|
* NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent
|
||||||
* the integer index.
|
* the integer index.
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class NumberedFileInputSplit implements InputSplit {
|
public class NumberedFileInputSplit implements InputSplit {
|
||||||
private final String baseString;
|
private final String baseString;
|
||||||
private final int minIdx;
|
private final int minIdx;
|
||||||
|
@ -93,7 +95,7 @@ public class NumberedFileInputSplit implements InputSplit {
|
||||||
try {
|
try {
|
||||||
writeFile.createNewFile();
|
writeFile.createNewFile();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ import java.net.URI;
|
||||||
import java.net.URISyntaxException;
|
import java.net.URISyntaxException;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.NoSuchElementException;
|
import java.util.NoSuchElementException;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple utility method to convert a {@code Iterator<String>} to an {@code Iterator<URI>}, where each
|
* A simple utility method to convert a {@code Iterator<String>} to an {@code Iterator<URI>}, where each
|
||||||
|
@ -32,6 +33,7 @@ import java.util.NoSuchElementException;
|
||||||
*/
|
*/
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class UriFromPathIterator implements Iterator<URI> {
|
public class UriFromPathIterator implements Iterator<URI> {
|
||||||
|
final Pattern schemaPattern = Pattern.compile("^.*?:/.*");
|
||||||
|
|
||||||
private final Iterator<String> paths;
|
private final Iterator<String> paths;
|
||||||
|
|
||||||
|
@ -42,16 +44,17 @@ public class UriFromPathIterator implements Iterator<URI> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public URI next() {
|
public URI next() {
|
||||||
|
|
||||||
if (!hasNext()) {
|
if (!hasNext()) {
|
||||||
throw new NoSuchElementException("No next element");
|
throw new NoSuchElementException("No next element");
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
String s = paths.next();
|
String s = paths.next();
|
||||||
if(!s.matches(".*:/.*")){
|
if(schemaPattern.matcher(s).matches()){
|
||||||
|
return new URI(s);
|
||||||
|
} else {
|
||||||
//No scheme - assume file for backward compatibility
|
//No scheme - assume file for backward compatibility
|
||||||
return new File(s).toURI();
|
return new File(s).toURI();
|
||||||
} else {
|
|
||||||
return new URI(s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} catch (URISyntaxException e) {
|
} catch (URISyntaxException e) {
|
||||||
|
|
|
@ -162,7 +162,6 @@ public class Text extends BinaryComparable implements WritableComparable<BinaryC
|
||||||
return -1; // not found
|
return -1; // not found
|
||||||
} catch (CharacterCodingException e) {
|
} catch (CharacterCodingException e) {
|
||||||
// can't get here
|
// can't get here
|
||||||
e.printStackTrace();
|
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.datavec.arrow.recordreader;
|
package org.datavec.arrow.recordreader;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.datavec.api.conf.Configuration;
|
import org.datavec.api.conf.Configuration;
|
||||||
import org.datavec.api.records.Record;
|
import org.datavec.api.records.Record;
|
||||||
|
@ -50,6 +51,7 @@ import static org.datavec.arrow.ArrowConverter.readFromBytes;
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class ArrowRecordReader implements RecordReader {
|
public class ArrowRecordReader implements RecordReader {
|
||||||
|
|
||||||
private InputSplit split;
|
private InputSplit split;
|
||||||
|
@ -132,7 +134,7 @@ public class ArrowRecordReader implements RecordReader {
|
||||||
currIdx++;
|
currIdx++;
|
||||||
this.currentPath = url;
|
this.currentPath = url;
|
||||||
}catch(Exception e) {
|
}catch(Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -242,7 +244,7 @@ public class ArrowRecordReader implements RecordReader {
|
||||||
try {
|
try {
|
||||||
currentBatch.close();
|
currentBatch.close();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.datavec.arrow;
|
package org.datavec.arrow;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.arrow.memory.BufferAllocator;
|
import org.apache.arrow.memory.BufferAllocator;
|
||||||
import org.apache.arrow.memory.RootAllocator;
|
import org.apache.arrow.memory.RootAllocator;
|
||||||
|
@ -56,7 +57,7 @@ import static junit.framework.TestCase.assertTrue;
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.Assert.assertFalse;
|
||||||
|
@Slf4j
|
||||||
public class ArrowConverterTest extends BaseND4JTest {
|
public class ArrowConverterTest extends BaseND4JTest {
|
||||||
|
|
||||||
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
||||||
|
@ -343,7 +344,7 @@ public class ArrowConverterTest extends BaseND4JTest {
|
||||||
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
|
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
|
||||||
arrowFileWriter.writeBatch();
|
arrowFileWriter.writeBatch();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
|
|
||||||
byte[] arr = byteArrayOutputStream.toByteArray();
|
byte[] arr = byteArrayOutputStream.toByteArray();
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class Wave implements Serializable {
|
||||||
initWaveWithInputStream(inputStream);
|
initWaveWithInputStream(inputStream);
|
||||||
inputStream.close();
|
inputStream.close();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
System.out.println(e.toString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ public class Wave implements Serializable {
|
||||||
data = new byte[inputStream.available()];
|
data = new byte[inputStream.available()];
|
||||||
inputStream.read(data);
|
inputStream.read(data);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
System.err.println(e.toString());
|
||||||
}
|
}
|
||||||
// end load data
|
// end load data
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -16,10 +16,12 @@
|
||||||
|
|
||||||
package org.datavec.audio;
|
package org.datavec.audio;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.io.FileOutputStream;
|
import java.io.FileOutputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
public class WaveFileManager {
|
public class WaveFileManager {
|
||||||
|
|
||||||
private Wave wave;
|
private Wave wave;
|
||||||
|
@ -78,7 +80,7 @@ public class WaveFileManager {
|
||||||
fos.write(wave.getBytes());
|
fos.write(wave.getBytes());
|
||||||
fos.close();
|
fos.close();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package org.datavec.audio;
|
package org.datavec.audio;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
@ -25,6 +27,7 @@ import java.io.InputStream;
|
||||||
*
|
*
|
||||||
* @author Jacquet Wong
|
* @author Jacquet Wong
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class WaveHeader {
|
public class WaveHeader {
|
||||||
|
|
||||||
public static final String RIFF_HEADER = "RIFF";
|
public static final String RIFF_HEADER = "RIFF";
|
||||||
|
@ -109,7 +112,7 @@ public class WaveHeader {
|
||||||
// dis.close();
|
// dis.close();
|
||||||
|
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.datavec.audio.fingerprint;
|
package org.datavec.audio.fingerprint;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.datavec.audio.Wave;
|
import org.datavec.audio.Wave;
|
||||||
import org.datavec.audio.WaveHeader;
|
import org.datavec.audio.WaveHeader;
|
||||||
import org.datavec.audio.dsp.Resampler;
|
import org.datavec.audio.dsp.Resampler;
|
||||||
|
@ -38,6 +39,7 @@ import java.util.List;
|
||||||
* @author jacquet
|
* @author jacquet
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class FingerprintManager {
|
public class FingerprintManager {
|
||||||
|
|
||||||
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
|
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
|
||||||
|
@ -153,7 +155,7 @@ public class FingerprintManager {
|
||||||
fingerprint = getFingerprintFromInputStream(fis);
|
fingerprint = getFingerprintFromInputStream(fis);
|
||||||
fis.close();
|
fis.close();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
return fingerprint;
|
return fingerprint;
|
||||||
}
|
}
|
||||||
|
@ -170,7 +172,7 @@ public class FingerprintManager {
|
||||||
fingerprint = new byte[inputStream.available()];
|
fingerprint = new byte[inputStream.available()];
|
||||||
inputStream.read(fingerprint);
|
inputStream.read(fingerprint);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
return fingerprint;
|
return fingerprint;
|
||||||
}
|
}
|
||||||
|
@ -190,7 +192,7 @@ public class FingerprintManager {
|
||||||
fileOutputStream.write(fingerprint);
|
fileOutputStream.write(fingerprint);
|
||||||
fileOutputStream.close();
|
fileOutputStream.close();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -81,7 +81,7 @@ public abstract class BaseImageLoader implements Serializable {
|
||||||
String fileName = file.toString();
|
String fileName = file.toString();
|
||||||
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|
||||||
|| fileName.endsWith(".zip"))
|
|| fileName.endsWith(".zip"))
|
||||||
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath());
|
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath(), false);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new IllegalStateException("Unable to fetch images", e);
|
throw new IllegalStateException("Unable to fetch images", e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -186,7 +186,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
|
||||||
labels.add(line);
|
labels.add(line);
|
||||||
}
|
}
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -300,7 +300,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
|
||||||
try {
|
try {
|
||||||
FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd);
|
FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
meanStdStored = true;
|
meanStdStored = true;
|
||||||
} else if (uMean == 0 && meanStdStored) {
|
} else if (uMean == 0 && meanStdStored) {
|
||||||
|
@ -312,7 +312,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
|
||||||
vStd = Double.parseDouble(values[3]);
|
vStd = Double.parseDouble(values[3]);
|
||||||
|
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int i = 0; i < result.numExamples(); i++) {
|
for (int i = 0; i < result.numExamples(); i++) {
|
||||||
|
@ -356,12 +356,12 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
|
||||||
dataSets.add(new DataSet(asMatrix(matConversion.getSecond()), matConversion.getFirst()));
|
dataSets.add(new DataSet(asMatrix(matConversion.getSecond()), matConversion.getFirst()));
|
||||||
batchNumCount++;
|
batchNumCount++;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(dataSets.size() == 0){
|
if(dataSets.size() == 0){
|
||||||
|
|
|
@ -235,7 +235,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
|
||||||
InputSplit data = train ? inputSplit[0] : inputSplit[1];
|
InputSplit data = train ? inputSplit[0] : inputSplit[1];
|
||||||
recordReader.initialize(data);
|
recordReader.initialize(data);
|
||||||
} catch (IOException | InterruptedException e) {
|
} catch (IOException | InterruptedException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
return recordReader;
|
return recordReader;
|
||||||
}
|
}
|
||||||
|
@ -250,7 +250,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
|
||||||
InputSplit data = train ? inputSplit[0] : inputSplit[1];
|
InputSplit data = train ? inputSplit[0] : inputSplit[1];
|
||||||
recordReader.initialize(data);
|
recordReader.initialize(data);
|
||||||
} catch (IOException | InterruptedException e) {
|
} catch (IOException | InterruptedException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
return recordReader;
|
return recordReader;
|
||||||
}
|
}
|
||||||
|
|
|
@ -225,7 +225,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
finishedInputStreamSplit = true;
|
finishedInputStreamSplit = true;
|
||||||
return Arrays.<Writable>asList(ndArrayWritable);
|
return Arrays.<Writable>asList(ndArrayWritable);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (iter != null) {
|
if (iter != null) {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.bytedeco.javacpp.Loader;
|
import org.bytedeco.javacpp.Loader;
|
||||||
|
@ -53,6 +54,7 @@ import static org.junit.Assert.fail;
|
||||||
*
|
*
|
||||||
* @author saudet
|
* @author saudet
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class TestNativeImageLoader {
|
public class TestNativeImageLoader {
|
||||||
static final long seed = 10;
|
static final long seed = 10;
|
||||||
static final Random rng = new Random(seed);
|
static final Random rng = new Random(seed);
|
||||||
|
@ -123,7 +125,7 @@ public class TestNativeImageLoader {
|
||||||
try {
|
try {
|
||||||
array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
|
array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
assertEquals(5, array6.rank());
|
assertEquals(5, array6.rank());
|
||||||
|
@ -156,7 +158,7 @@ public class TestNativeImageLoader {
|
||||||
try {
|
try {
|
||||||
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
|
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
assertEquals(5, array8.rank());
|
assertEquals(5, array8.rank());
|
||||||
assertEquals(pages2, array8.size(0));
|
assertEquals(pages2, array8.size(0));
|
||||||
|
@ -172,7 +174,7 @@ public class TestNativeImageLoader {
|
||||||
try {
|
try {
|
||||||
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath());
|
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath());
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
assertEquals(5, array9.rank());
|
assertEquals(5, array9.rank());
|
||||||
|
|
|
@ -66,7 +66,7 @@ public class UimaTokenizer implements Tokenizer {
|
||||||
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.datavec.local.transforms.functions;
|
package org.datavec.local.transforms.functions;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.datavec.api.records.reader.SequenceRecordReader;
|
import org.datavec.api.records.reader.SequenceRecordReader;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.nd4j.linalg.function.Function;
|
import org.nd4j.linalg.function.Function;
|
||||||
|
@ -32,6 +33,7 @@ import java.util.List;
|
||||||
* sequence data into a {@code List<List<Writable>>}
|
* sequence data into a {@code List<List<Writable>>}
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class SequenceRecordReaderFunction
|
public class SequenceRecordReaderFunction
|
||||||
implements Function<Pair<String, InputStream>, List<List<Writable>>> {
|
implements Function<Pair<String, InputStream>, List<List<Writable>>> {
|
||||||
protected SequenceRecordReader sequenceRecordReader;
|
protected SequenceRecordReader sequenceRecordReader;
|
||||||
|
@ -46,7 +48,7 @@ public class SequenceRecordReaderFunction
|
||||||
try (DataInputStream dis = (DataInputStream) value.getRight()) {
|
try (DataInputStream dis = (DataInputStream) value.getRight()) {
|
||||||
return sequenceRecordReader.sequenceRecord(uri, dis);
|
return sequenceRecordReader.sequenceRecord(uri, dis);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new IllegalStateException("Something went wrong");
|
throw new IllegalStateException("Something went wrong");
|
||||||
|
|
|
@ -20,6 +20,7 @@ package org.datavec.python;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
|
import org.bytedeco.cpython.global.python;
|
||||||
import org.bytedeco.numpy.global.numpy;
|
import org.bytedeco.numpy.global.numpy;
|
||||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -343,6 +344,19 @@ public class PythonExecutioner {
|
||||||
if (path == null) {
|
if (path == null) {
|
||||||
log.info("Setting python default path");
|
log.info("Setting python default path");
|
||||||
File[] packages = numpy.cachePackages();
|
File[] packages = numpy.cachePackages();
|
||||||
|
|
||||||
|
//// TODO: fix in javacpp
|
||||||
|
File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
|
||||||
|
File[] packages2 = new File[packages.length + 1];
|
||||||
|
for (int i = 0;i < packages.length; i++){
|
||||||
|
//System.out.println(packages[i].getAbsolutePath());
|
||||||
|
packages2[i] = packages[i];
|
||||||
|
}
|
||||||
|
packages2[packages.length] = sitePackagesWindows;
|
||||||
|
//System.out.println(sitePackagesWindows.getAbsolutePath());
|
||||||
|
packages = packages2;
|
||||||
|
//////////
|
||||||
|
|
||||||
Py_SetPath(packages);
|
Py_SetPath(packages);
|
||||||
} else {
|
} else {
|
||||||
log.info("Setting python path " + path);
|
log.info("Setting python path " + path);
|
||||||
|
|
|
@ -0,0 +1,132 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.io.IOUtils;
|
||||||
|
import org.bytedeco.javacpp.Loader;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class PythonProcess {
|
||||||
|
private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
|
||||||
|
public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
|
||||||
|
String[] allArgs = new String[arguments.length + 1];
|
||||||
|
for (int i = 0; i < arguments.length; i++){
|
||||||
|
allArgs[i + 1] = arguments[i];
|
||||||
|
}
|
||||||
|
allArgs[0] = pythonExecutable;
|
||||||
|
log.info("Executing command: " + Arrays.toString(allArgs));
|
||||||
|
ProcessBuilder pb = new ProcessBuilder(allArgs);
|
||||||
|
Process process = pb.start();
|
||||||
|
String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
|
||||||
|
process.waitFor();
|
||||||
|
return out;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void run(String... arguments)throws IOException, InterruptedException{
|
||||||
|
String[] allArgs = new String[arguments.length + 1];
|
||||||
|
for (int i = 0; i < arguments.length; i++){
|
||||||
|
allArgs[i + 1] = arguments[i];
|
||||||
|
}
|
||||||
|
allArgs[0] = pythonExecutable;
|
||||||
|
log.info("Executing command: " + Arrays.toString(allArgs));
|
||||||
|
ProcessBuilder pb = new ProcessBuilder(allArgs);
|
||||||
|
pb.inheritIO().start().waitFor();
|
||||||
|
}
|
||||||
|
public static void pipInstall(String packageName) throws PythonException{
|
||||||
|
try{
|
||||||
|
run("-m", "pip", "install", packageName);
|
||||||
|
}catch(Exception e){
|
||||||
|
throw new PythonException("Error installing package " + packageName, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void pipInstall(String packageName, String version) throws PythonException{
|
||||||
|
pipInstall(packageName + "==" + version);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void pipUninstall(String packageName) throws PythonException{
|
||||||
|
try{
|
||||||
|
run("-m", "pip", "uninstall", packageName);
|
||||||
|
}catch(Exception e){
|
||||||
|
throw new PythonException("Error uninstalling package " + packageName, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{
|
||||||
|
if (!gitRepoUrl.contains("://")){
|
||||||
|
gitRepoUrl = "git://" + gitRepoUrl;
|
||||||
|
}
|
||||||
|
try{
|
||||||
|
run("-m", "pip", "install", "git+", gitRepoUrl);
|
||||||
|
}catch(Exception e){
|
||||||
|
throw new PythonException("Error installing package from " + gitRepoUrl, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static String getPackageVersion(String packageName) throws PythonException{
|
||||||
|
String out;
|
||||||
|
try{
|
||||||
|
out = runAndReturn("-m", "pip", "show", packageName);
|
||||||
|
} catch (Exception e){
|
||||||
|
throw new PythonException("Error finding version for package " + packageName, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!out.contains("Version: ")){
|
||||||
|
throw new PythonException("Can't find package " + packageName);
|
||||||
|
}
|
||||||
|
String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
|
||||||
|
return pkgVersion;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean isPackageInstalled(String packageName)throws PythonException{
|
||||||
|
try{
|
||||||
|
String out = runAndReturn("-m", "pip", "show", packageName);
|
||||||
|
return !out.isEmpty();
|
||||||
|
}catch (Exception e){
|
||||||
|
throw new PythonException("Error checking if package is installed: " +packageName, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void pipInstallFromRequirementsTxt(String path) throws PythonException{
|
||||||
|
try{
|
||||||
|
run("-m", "pip", "install","-r", path);
|
||||||
|
}catch (Exception e){
|
||||||
|
throw new PythonException("Error installing packages from " + path, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException{
|
||||||
|
|
||||||
|
try{
|
||||||
|
run(path, inplace?"develop":"install");
|
||||||
|
}catch (Exception e){
|
||||||
|
throw new PythonException("Error installing package from " + path, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,144 @@
|
||||||
|
package org.datavec.python.keras;
|
||||||
|
|
||||||
|
import org.datavec.python.Python;
|
||||||
|
import org.datavec.python.PythonException;
|
||||||
|
import org.datavec.python.PythonObject;
|
||||||
|
import org.datavec.python.PythonProcess;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
public class Model {
|
||||||
|
|
||||||
|
private PythonObject pyModel;
|
||||||
|
|
||||||
|
|
||||||
|
private static PythonObject installAndImportTF() throws PythonException{
|
||||||
|
if (!PythonProcess.isPackageInstalled("tensorflow")){
|
||||||
|
PythonProcess.pipInstall("tensorflow");
|
||||||
|
}
|
||||||
|
return Python.importModule("tensorflow");
|
||||||
|
}
|
||||||
|
private static PythonObject getKerasModule() throws PythonException{
|
||||||
|
PythonObject tf = installAndImportTF();
|
||||||
|
PythonObject keras = tf.attr("keras");
|
||||||
|
tf.del();
|
||||||
|
return keras;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static PythonObject loadModel(String s) throws PythonException{
|
||||||
|
PythonObject models = getKerasModule().attr("models");
|
||||||
|
PythonObject loadModelF = models.attr("load_model");
|
||||||
|
PythonObject model = loadModelF.call(s);
|
||||||
|
models.del();
|
||||||
|
loadModelF.del();
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Model(String path) throws PythonException{
|
||||||
|
pyModel = loadModel(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
public INDArray[] predict(INDArray... inputs) throws PythonException{
|
||||||
|
PythonObject predictF = pyModel.attr("predict");
|
||||||
|
PythonObject inputList = new PythonObject(inputs);
|
||||||
|
PythonObject pyOut = predictF.call(inputList);
|
||||||
|
INDArray[] out;
|
||||||
|
if (Python.isinstance(pyOut, Python.listType())){
|
||||||
|
out = new INDArray[Python.len(pyOut).toInt()];
|
||||||
|
for(int i = 0; i < out.length; i++){
|
||||||
|
out[i] = pyOut.get(i).toNumpy().getNd4jArray();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
out = new INDArray[]{
|
||||||
|
pyOut.toNumpy().getNd4jArray()};
|
||||||
|
}
|
||||||
|
|
||||||
|
predictF.del();
|
||||||
|
inputList.del();
|
||||||
|
pyOut.del();
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int numInputs(){
|
||||||
|
PythonObject inputs = pyModel.attr("inputs");
|
||||||
|
PythonObject pyNumInputs = Python.len(inputs);
|
||||||
|
int ret = pyNumInputs.toInt();
|
||||||
|
inputs.del();
|
||||||
|
pyNumInputs.del();
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
public int numOutputs(){
|
||||||
|
PythonObject outputs = pyModel.attr("outputs");
|
||||||
|
PythonObject pyNumOutputs = Python.len(outputs);
|
||||||
|
int ret = pyNumOutputs.toInt();
|
||||||
|
outputs.del();
|
||||||
|
pyNumOutputs.del();
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long[][] inputShapes(){
|
||||||
|
long[][] ret = new long[numInputs()][];
|
||||||
|
for (int i = 0; i < ret.length; i++){
|
||||||
|
ret[i] = inputShapeAt(i);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long[][] outputShapes(){
|
||||||
|
long[][] ret = new long[numOutputs()][];
|
||||||
|
for (int i = 0; i < ret.length; i++){
|
||||||
|
ret[i] = outputShapeAt(i);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long[] inputShapeAt(int input){
|
||||||
|
PythonObject inputs = pyModel.attr("inputs");
|
||||||
|
PythonObject tensor = inputs.get(input);
|
||||||
|
PythonObject tensorShape = tensor.attr("shape");
|
||||||
|
PythonObject shapeList = Python.list(tensorShape);
|
||||||
|
PythonObject pyNdim = Python.len(shapeList);
|
||||||
|
int ndim = pyNdim.toInt();
|
||||||
|
long[] shape = new long[ndim];
|
||||||
|
for(int i = 0; i < shape.length; i++){
|
||||||
|
PythonObject pyDim = shapeList.get(i);
|
||||||
|
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
|
||||||
|
shape[i] = -1;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
shape[i] = pyDim.toLong();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pyNdim.del();
|
||||||
|
shapeList.del();
|
||||||
|
tensorShape.del();
|
||||||
|
tensor.del();
|
||||||
|
inputs.del();
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long[] outputShapeAt(int output){
|
||||||
|
PythonObject inputs = pyModel.attr("outputs");
|
||||||
|
PythonObject tensor = inputs.get(output);
|
||||||
|
PythonObject tensorShape = tensor.attr("shape");
|
||||||
|
PythonObject shapeList = Python.list(tensorShape);
|
||||||
|
PythonObject pyNdim = Python.len(shapeList);
|
||||||
|
int ndim = pyNdim.toInt();
|
||||||
|
long[] shape = new long[ndim];
|
||||||
|
for(int i = 0; i < shape.length; i++){
|
||||||
|
PythonObject pyDim = shapeList.get(i);
|
||||||
|
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
|
||||||
|
shape[i] = -1;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
shape[i] = pyDim.toLong();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pyNdim.del();
|
||||||
|
shapeList.del();
|
||||||
|
tensorShape.del();
|
||||||
|
tensor.del();
|
||||||
|
inputs.del();
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
}
|
|
@ -74,7 +74,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
|
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in setCSVTransformProcess()", e);
|
log.error("Error in setCSVTransformProcess()", e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,7 +93,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
return TransformProcess.fromJson(s);
|
return TransformProcess.fromJson(s);
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in getCSVTransformProcess()",e);
|
log.error("Error in getCSVTransformProcess()",e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
@ -119,7 +117,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
return singleCsvRecord;
|
return singleCsvRecord;
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in transformIncremental(SingleCSVRecord)",e);
|
log.error("Error in transformIncremental(SingleCSVRecord)",e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -140,8 +137,7 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
.getBody();
|
.getBody();
|
||||||
return batchCSVRecord1;
|
return batchCSVRecord1;
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in transform(BatchCSVRecord)", e);
|
log.error("",e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
@ -162,7 +158,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
return batchCSVRecord1;
|
return batchCSVRecord1;
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in transform(BatchCSVRecord)", e);
|
log.error("Error in transform(BatchCSVRecord)", e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
@ -181,7 +176,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
return batchArray1;
|
return batchArray1;
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in transformArray(BatchCSVRecord)",e);
|
log.error("Error in transformArray(BatchCSVRecord)",e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
@ -200,7 +194,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
return array;
|
return array;
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in transformArrayIncremental(SingleCSVRecord)",e);
|
log.error("Error in transformArrayIncremental(SingleCSVRecord)",e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
@ -231,7 +224,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
return array;
|
return array;
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in transformSequenceArrayIncremental",e);
|
log.error("Error in transformSequenceArrayIncremental",e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
@ -252,7 +244,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
return batchArray1;
|
return batchArray1;
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in transformSequenceArray",e);
|
log.error("Error in transformSequenceArray",e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
@ -274,7 +265,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
return batchCSVRecord1;
|
return batchCSVRecord1;
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in transformSequence");
|
log.error("Error in transformSequence");
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
@ -295,7 +285,6 @@ public class DataVecTransformClient implements DataVecTransformService {
|
||||||
return singleCsvRecord;
|
return singleCsvRecord;
|
||||||
} catch (UnirestException e) {
|
} catch (UnirestException e) {
|
||||||
log.error("Error in transformSequenceIncremental");
|
log.error("Error in transformSequenceIncremental");
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.datavec.spark.transform;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.arrow.memory.BufferAllocator;
|
import org.apache.arrow.memory.BufferAllocator;
|
||||||
import org.apache.arrow.memory.RootAllocator;
|
import org.apache.arrow.memory.RootAllocator;
|
||||||
|
@ -53,7 +54,7 @@ import static org.datavec.local.transforms.LocalTransformExecutor.executeToSeque
|
||||||
* @author Adan Gibson
|
* @author Adan Gibson
|
||||||
*/
|
*/
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
|
@Slf4j
|
||||||
public class CSVSparkTransform {
|
public class CSVSparkTransform {
|
||||||
@Getter
|
@Getter
|
||||||
private TransformProcess transformProcess;
|
private TransformProcess transformProcess;
|
||||||
|
@ -252,7 +253,7 @@ public class CSVSparkTransform {
|
||||||
try {
|
try {
|
||||||
return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
|
return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
|
|
@ -88,7 +88,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
|
||||||
log.info("Transform process initialized");
|
log.info("Transform process initialized");
|
||||||
return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
|
return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
return internalServerError();
|
return internalServerError();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -100,7 +100,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
|
||||||
log.info("Transform process initialized");
|
log.info("Transform process initialized");
|
||||||
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
|
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
return internalServerError();
|
return internalServerError();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -112,7 +112,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
|
||||||
return badRequest();
|
return badRequest();
|
||||||
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
|
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
return internalServerError();
|
return internalServerError();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -130,7 +130,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
|
||||||
|
|
||||||
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
|
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
return internalServerError();
|
return internalServerError();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -142,7 +142,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
|
||||||
return badRequest();
|
return badRequest();
|
||||||
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
|
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
return internalServerError();
|
return internalServerError();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -169,7 +169,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
|
||||||
|
|
||||||
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
|
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
return internalServerError();
|
return internalServerError();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.datavec.spark;
|
package org.datavec.spark;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -23,6 +24,7 @@ import org.junit.Before;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
public abstract class BaseSparkTest implements Serializable {
|
public abstract class BaseSparkTest implements Serializable {
|
||||||
protected static JavaSparkContext sc;
|
protected static JavaSparkContext sc;
|
||||||
|
|
||||||
|
@ -40,7 +42,7 @@ public abstract class BaseSparkTest implements Serializable {
|
||||||
try {
|
try {
|
||||||
Thread.sleep(100L);
|
Thread.sleep(100L);
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -68,7 +68,7 @@ public abstract class BaseDL4JTest {
|
||||||
* Override this method to set the default timeout for methods in the test class
|
* Override this method to set the default timeout for methods in the test class
|
||||||
*/
|
*/
|
||||||
public long getTimeoutMilliseconds(){
|
public long getTimeoutMilliseconds(){
|
||||||
return 30000;
|
return 60_000;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -43,7 +43,8 @@ import java.util.concurrent.atomic.AtomicLong;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable {
|
public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable {
|
||||||
|
private static final String ROUTE_IS_DOWN = "Info posted to RemoteUIStatsStorageRouter but router is shut down.";
|
||||||
|
private static final String MAX_WARNINGS_REACHED = "RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.";
|
||||||
/**
|
/**
|
||||||
* Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar
|
* Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar
|
||||||
*/
|
*/
|
||||||
|
@ -163,10 +164,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
|
||||||
if (shutdown.get()) {
|
if (shutdown.get()) {
|
||||||
long count = shutdownWarnCount.getAndIncrement();
|
long count = shutdownWarnCount.getAndIncrement();
|
||||||
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
|
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
|
||||||
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
|
log.warn(ROUTE_IS_DOWN);
|
||||||
}
|
}
|
||||||
if (count == MAX_SHUTDOWN_WARN_COUNT) {
|
if (count == MAX_SHUTDOWN_WARN_COUNT) {
|
||||||
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
|
log.warn(MAX_WARNINGS_REACHED);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (StorageMetaData m : storageMetaData) {
|
for (StorageMetaData m : storageMetaData) {
|
||||||
|
@ -186,10 +187,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
|
||||||
if (shutdown.get()) {
|
if (shutdown.get()) {
|
||||||
long count = shutdownWarnCount.getAndIncrement();
|
long count = shutdownWarnCount.getAndIncrement();
|
||||||
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
|
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
|
||||||
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
|
log.warn(ROUTE_IS_DOWN);
|
||||||
}
|
}
|
||||||
if (count == MAX_SHUTDOWN_WARN_COUNT) {
|
if (count == MAX_SHUTDOWN_WARN_COUNT) {
|
||||||
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
|
log.warn(MAX_WARNINGS_REACHED);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (Persistable p : staticInfo) {
|
for (Persistable p : staticInfo) {
|
||||||
|
@ -209,10 +210,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
|
||||||
if (shutdown.get()) {
|
if (shutdown.get()) {
|
||||||
long count = shutdownWarnCount.getAndIncrement();
|
long count = shutdownWarnCount.getAndIncrement();
|
||||||
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
|
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
|
||||||
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
|
log.warn(ROUTE_IS_DOWN);
|
||||||
}
|
}
|
||||||
if (count == MAX_SHUTDOWN_WARN_COUNT) {
|
if (count == MAX_SHUTDOWN_WARN_COUNT) {
|
||||||
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
|
log.warn(MAX_WARNINGS_REACHED);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (Persistable p : updates) {
|
for (Persistable p : updates) {
|
||||||
|
|
|
@ -67,10 +67,7 @@ public class AsyncIterator<T extends Object> implements Iterator<T> {
|
||||||
nextElement = buffer.take();
|
nextElement = buffer.take();
|
||||||
|
|
||||||
// same on this run
|
// same on this run
|
||||||
if (nextElement == terminator)
|
return (nextElement != terminator);
|
||||||
return false;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("Premature end of loop!");
|
log.error("Premature end of loop!");
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -43,7 +43,7 @@ public class SystemInfoPrintListener implements TrainingListener {
|
||||||
private boolean printOnBackwardPass;
|
private boolean printOnBackwardPass;
|
||||||
private boolean printOnGradientCalculation;
|
private boolean printOnGradientCalculation;
|
||||||
|
|
||||||
|
private static final String SYSTEM_INFO = "System info on epoch end: ";
|
||||||
@Override
|
@Override
|
||||||
public void iterationDone(Model model, int iteration, int epoch) {
|
public void iterationDone(Model model, int iteration, int epoch) {
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ public class SystemInfoPrintListener implements TrainingListener {
|
||||||
return;
|
return;
|
||||||
|
|
||||||
SystemInfo systemInfo = new SystemInfo();
|
SystemInfo systemInfo = new SystemInfo();
|
||||||
log.info("System info on epoch end: ");
|
log.info(SYSTEM_INFO);
|
||||||
log.info(systemInfo.toPrettyJSON());
|
log.info(systemInfo.toPrettyJSON());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ public class SystemInfoPrintListener implements TrainingListener {
|
||||||
return;
|
return;
|
||||||
|
|
||||||
SystemInfo systemInfo = new SystemInfo();
|
SystemInfo systemInfo = new SystemInfo();
|
||||||
log.info("System info on epoch end: ");
|
log.info(SYSTEM_INFO);
|
||||||
log.info(systemInfo.toPrettyJSON());
|
log.info(systemInfo.toPrettyJSON());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ public class SystemInfoPrintListener implements TrainingListener {
|
||||||
return;
|
return;
|
||||||
|
|
||||||
SystemInfo systemInfo = new SystemInfo();
|
SystemInfo systemInfo = new SystemInfo();
|
||||||
log.info("System info on epoch end: ");
|
log.info(SYSTEM_INFO);
|
||||||
log.info(systemInfo.toPrettyJSON());
|
log.info(systemInfo.toPrettyJSON());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ public class SystemInfoPrintListener implements TrainingListener {
|
||||||
return;
|
return;
|
||||||
|
|
||||||
SystemInfo systemInfo = new SystemInfo();
|
SystemInfo systemInfo = new SystemInfo();
|
||||||
log.info("System info on epoch end: ");
|
log.info(SYSTEM_INFO);
|
||||||
log.info(systemInfo.toPrettyJSON());
|
log.info(systemInfo.toPrettyJSON());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ public class SystemInfoPrintListener implements TrainingListener {
|
||||||
if(!printOnBackwardPass)
|
if(!printOnBackwardPass)
|
||||||
return;
|
return;
|
||||||
SystemInfo systemInfo = new SystemInfo();
|
SystemInfo systemInfo = new SystemInfo();
|
||||||
log.info("System info on epoch end: ");
|
log.info(SYSTEM_INFO);
|
||||||
log.info(systemInfo.toPrettyJSON());
|
log.info(systemInfo.toPrettyJSON());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.perf.listener;
|
package org.deeplearning4j.perf.listener;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||||
import oshi.json.SystemInfo;
|
import oshi.json.SystemInfo;
|
||||||
|
@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit;
|
||||||
*
|
*
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class SystemPolling {
|
public class SystemPolling {
|
||||||
|
|
||||||
private ScheduledExecutorService scheduledExecutorService;
|
private ScheduledExecutorService scheduledExecutorService;
|
||||||
|
@ -66,7 +68,7 @@ public class SystemPolling {
|
||||||
try {
|
try {
|
||||||
objectMapper.writeValue(hardwareFile,hardwareMetric);
|
objectMapper.writeValue(hardwareFile,hardwareMetric);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},0,pollEveryMillis, TimeUnit.MILLISECONDS);
|
},0,pollEveryMillis, TimeUnit.MILLISECONDS);
|
||||||
|
|
|
@ -27,7 +27,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
|
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
|
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
|
@ -153,11 +154,22 @@ public class TestUtils {
|
||||||
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
|
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
|
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) {
|
||||||
INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
|
return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){
|
||||||
|
boolean ncw = format == RNNFormat.NCW;
|
||||||
|
long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize};
|
||||||
|
char order = ncw ? 'f' : 'c';
|
||||||
|
INDArray out = Nd4j.create(DataType.FLOAT, shape, order);
|
||||||
for( int i=0; i<minibatch; i++ ){
|
for( int i=0; i<minibatch; i++ ){
|
||||||
for( int j=0; j<tsLength; j++ ){
|
for( int j=0; j<tsLength; j++ ){
|
||||||
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
|
if(ncw){
|
||||||
|
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
|
||||||
|
} else {
|
||||||
|
out.putScalar(i, j, rng.nextInt(outSize), 1.0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
|
|
|
@ -493,7 +493,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||||
out.println();
|
out.println();
|
||||||
}
|
}
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Pair<>(dArr,temp);
|
return new Pair<>(dArr,temp);
|
||||||
|
|
|
@ -78,10 +78,10 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
|
||||||
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
|
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
|
||||||
|
|
||||||
earlyEndIter.next(10);
|
earlyEndIter.next(10);
|
||||||
assertEquals(earlyEndIter.hasNext(), false);
|
assertEquals(false, earlyEndIter.hasNext());
|
||||||
|
|
||||||
earlyEndIter.reset();
|
earlyEndIter.reset();
|
||||||
assertEquals(earlyEndIter.hasNext(), true);
|
assertEquals(true, earlyEndIter.hasNext());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -98,7 +98,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
|
||||||
while (multiIter.hasNext()) {
|
while (multiIter.hasNext()) {
|
||||||
DataSet path = multiIter.next(10);
|
DataSet path = multiIter.next(10);
|
||||||
assertNotNull(path);
|
assertNotNull(path);
|
||||||
assertEquals(path.numExamples(), 10, 0.0);
|
assertEquals(10, path.numExamples(), 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(epochs, multiIter.epochs);
|
assertEquals(epochs, multiIter.epochs);
|
||||||
|
|
|
@ -33,7 +33,7 @@ public class SamplingTest extends BaseDL4JTest {
|
||||||
DataSetIterator iter = new MnistDataSetIterator(10, 10);
|
DataSetIterator iter = new MnistDataSetIterator(10, 10);
|
||||||
//batch size and total
|
//batch size and total
|
||||||
DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
|
DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
|
||||||
assertEquals(sampling.next().numExamples(), 10);
|
assertEquals(10, sampling.next().numExamples());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.exceptions;
|
package org.deeplearning4j.exceptions;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.exception.DL4JException;
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
|
@ -34,6 +35,7 @@ import static org.junit.Assert.fail;
|
||||||
/**
|
/**
|
||||||
* A set of tests to ensure that useful exceptions are thrown on invalid network configurations
|
* A set of tests to ensure that useful exceptions are thrown on invalid network configurations
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class TestInvalidConfigurations extends BaseDL4JTest {
|
public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
|
|
||||||
public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) {
|
public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) {
|
||||||
|
@ -78,7 +80,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testDenseNin0(): " + e.getMessage());
|
System.out.println("testDenseNin0(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -96,7 +98,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testDenseNout0(): " + e.getMessage());
|
System.out.println("testDenseNout0(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -109,7 +111,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testOutputLayerNin0(): " + e.getMessage());
|
System.out.println("testOutputLayerNin0(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -122,7 +124,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testRnnOutputLayerNin0(): " + e.getMessage());
|
System.out.println("testRnnOutputLayerNin0(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -135,7 +137,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testLSTMNIn0(): " + e.getMessage());
|
System.out.println("testLSTMNIn0(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -153,7 +155,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testLSTMNOut0(): " + e.getMessage());
|
System.out.println("testLSTMNOut0(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -166,7 +168,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testConvolutionalNin0(): " + e.getMessage());
|
System.out.println("testConvolutionalNin0(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -185,7 +187,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testConvolutionalNOut0(): " + e.getMessage());
|
System.out.println("testConvolutionalNOut0(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -216,7 +218,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage());
|
System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -245,7 +247,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage());
|
System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -277,7 +279,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage());
|
System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -318,7 +320,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage());
|
System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -347,7 +349,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage());
|
System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail();
|
fail();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.exceptions;
|
package org.deeplearning4j.exceptions;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.exception.DL4JException;
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
@ -36,6 +37,7 @@ import static org.junit.Assert.*;
|
||||||
/**
|
/**
|
||||||
* A set of tests to ensure that useful exceptions are thrown on invalid input
|
* A set of tests to ensure that useful exceptions are thrown on invalid input
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class TestInvalidInput extends BaseDL4JTest {
|
public class TestInvalidInput extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -53,7 +55,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testInputNinMismatchDense(): " + e.getMessage());
|
System.out.println("testInputNinMismatchDense(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -73,7 +75,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage());
|
System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -94,7 +96,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
//From loss function
|
//From loss function
|
||||||
System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage());
|
System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -115,7 +117,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
//From loss function
|
//From loss function
|
||||||
System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
|
System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -142,7 +144,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
|
System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -169,7 +171,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
|
System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -195,7 +197,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testInputNinRank2Subsampling(): " + e.getMessage());
|
System.out.println("testInputNinRank2Subsampling(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -217,7 +219,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testInputNinMismatchLSTM(): " + e.getMessage());
|
System.out.println("testInputNinMismatchLSTM(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -238,7 +240,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
|
System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,7 +262,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage());
|
System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Expected DL4JException");
|
fail("Expected DL4JException");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -305,7 +307,7 @@ public class TestInvalidInput extends BaseDL4JTest {
|
||||||
net.rnnTimeStep(Nd4j.create(5, 5, 10));
|
net.rnnTimeStep(Nd4j.create(5, 5, 10));
|
||||||
fail("Expected Exception - " + layerType);
|
fail("Expected Exception - " + layerType);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
// e.printStackTrace();
|
log.error("",e);
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
|
assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -34,6 +35,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -51,6 +54,7 @@ import static org.junit.Assert.*;
|
||||||
/**
|
/**
|
||||||
* Created by nyghtowl on 9/1/15.
|
* Created by nyghtowl on 9/1/15.
|
||||||
*/
|
*/
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
public class CNNGradientCheckTest extends BaseDL4JTest {
|
public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
private static final boolean PRINT_RESULTS = true;
|
private static final boolean PRINT_RESULTS = true;
|
||||||
private static final boolean RETURN_ON_FIRST_FAILURE = false;
|
private static final boolean RETURN_ON_FIRST_FAILURE = false;
|
||||||
|
@ -62,6 +66,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private CNN2DFormat format;
|
||||||
|
|
||||||
|
public CNNGradientCheckTest(CNN2DFormat format){
|
||||||
|
this.format = format;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Parameterized.Parameters(name = "{0}")
|
||||||
|
public static Object[] params(){
|
||||||
|
return CNN2DFormat.values();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L;
|
return 90000L;
|
||||||
|
@ -69,6 +84,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradientCNNMLN() {
|
public void testGradientCNNMLN() {
|
||||||
|
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
|
||||||
|
return;
|
||||||
|
|
||||||
//Parameterized test, testing combinations of:
|
//Parameterized test, testing combinations of:
|
||||||
// (a) activation function
|
// (a) activation function
|
||||||
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
|
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
|
||||||
|
@ -144,6 +162,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradientCNNL1L2MLN() {
|
public void testGradientCNNL1L2MLN() {
|
||||||
|
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
|
||||||
|
return;
|
||||||
|
|
||||||
//Parameterized test, testing combinations of:
|
//Parameterized test, testing combinations of:
|
||||||
// (a) activation function
|
// (a) activation function
|
||||||
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
|
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
|
||||||
|
@ -311,10 +332,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
|
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
|
||||||
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
|
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
for (String afn : activations) {
|
for (String afn : activations) {
|
||||||
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
for (int minibatchSize : minibatchSizes) {
|
for (int minibatchSize : minibatchSizes) {
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut);
|
INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut);
|
||||||
for (int i = 0; i < 4 * minibatchSize; i++) {
|
for (int i = 0; i < 4 * minibatchSize; i++) {
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||||
|
@ -330,13 +353,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX)
|
.activation(Activation.SOFTMAX)
|
||||||
.nOut(nOut).build())
|
.nOut(nOut).build())
|
||||||
.setInputType(InputType.convolutionalFlat(height, width, inputDepth))
|
.setInputType(InputType.convolutional(height, width, inputDepth, format))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
|
String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
|
||||||
+ afn;
|
+ afn;
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
|
@ -377,8 +400,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
int[] padding = {0, 0};
|
int[] padding = {0, 0};
|
||||||
int size = 2;
|
int size = 2;
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for (int minibatchSize : minibatchSizes) {
|
for (int minibatchSize : minibatchSizes) {
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
||||||
|
|
||||||
MultiLayerConfiguration conf =
|
MultiLayerConfiguration conf =
|
||||||
|
@ -393,8 +419,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nIn(8 * 8 * 3)
|
.activation(Activation.SOFTMAX).nIn(8 * 8 * 3)
|
||||||
.nOut(4).build())
|
.nOut(4).build())
|
||||||
.setInputType(InputType.convolutionalFlat(height, width,
|
.setInputType(InputType.convolutional(height, width, inputDepth, format))
|
||||||
inputDepth))
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -438,10 +463,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
|
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
|
||||||
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
|
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for (Activation afn : activations) {
|
for (Activation afn : activations) {
|
||||||
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
for (int minibatchSize : minibatchSizes) {
|
for (int minibatchSize : minibatchSizes) {
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||||
for (int i = 0; i < minibatchSize; i++) {
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||||
|
@ -461,14 +489,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nIn(3 * 3 * 3)
|
.activation(Activation.SOFTMAX).nIn(3 * 3 * 3)
|
||||||
.nOut(4).build())
|
.nOut(4).build())
|
||||||
.setInputType(InputType.convolutionalFlat(height, width,
|
.setInputType(InputType.convolutional(height, width, inputDepth, format))
|
||||||
inputDepth))
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
|
String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
|
||||||
+ afn;
|
+ afn;
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
|
@ -508,10 +535,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
|
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
|
||||||
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
|
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for (Activation afn : activations) {
|
for (Activation afn : activations) {
|
||||||
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
for (int minibatchSize : minibatchSizes) {
|
for (int minibatchSize : minibatchSizes) {
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||||
for (int i = 0; i < minibatchSize; i++) {
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||||
|
@ -533,8 +563,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2)
|
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2)
|
||||||
.nOut(4).build())
|
.nOut(4).build())
|
||||||
.setInputType(InputType.convolutionalFlat(height, width,
|
.setInputType(InputType.convolutional(height, width, inputDepth, format))
|
||||||
inputDepth))
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -558,8 +587,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testCnnLocallyConnected2D() {
|
public void testCnnLocallyConnected2D() {
|
||||||
int nOut = 3;
|
int nOut = 3;
|
||||||
|
|
||||||
int[] minibatchSizes = {2};
|
|
||||||
int width = 5;
|
int width = 5;
|
||||||
int height = 5;
|
int height = 5;
|
||||||
|
|
||||||
|
@ -569,11 +596,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS};
|
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS};
|
||||||
int[] minibatch = {2, 1, 3};
|
int[] minibatch = {2, 1, 3};
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for( int i=0; i<inputDepths.length; i++ ){
|
for( int i=0; i<inputDepths.length; i++ ){
|
||||||
int inputDepth = inputDepths[i];
|
int inputDepth = inputDepths[i];
|
||||||
Activation afn = activations[i];
|
Activation afn = activations[i];
|
||||||
int minibatchSize = minibatch[i];
|
int minibatchSize = minibatch[i];
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
|
||||||
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
||||||
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
|
||||||
|
@ -590,7 +621,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
|
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
|
||||||
.build())
|
.build())
|
||||||
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
|
.setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
|
||||||
|
|
||||||
assertEquals(ConvolutionMode.Truncate,
|
assertEquals(ConvolutionMode.Truncate,
|
||||||
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
|
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
|
||||||
|
@ -626,11 +657,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for (int inputDepth : inputDepths) {
|
for (int inputDepth : inputDepths) {
|
||||||
for (Activation afn : activations) {
|
for (Activation afn : activations) {
|
||||||
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
|
||||||
for (int minibatchSize : minibatchSizes) {
|
for (int minibatchSize : minibatchSizes) {
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
|
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||||
for (int i = 0; i < minibatchSize; i++) {
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||||
|
@ -649,7 +684,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
|
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
|
||||||
.build())
|
.build())
|
||||||
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
|
.setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
|
||||||
|
|
||||||
assertEquals(ConvolutionMode.Truncate,
|
assertEquals(ConvolutionMode.Truncate,
|
||||||
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
|
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
|
||||||
|
@ -691,13 +726,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for( int i=0; i<minibatchSizes.length; i++ ){
|
for( int i=0; i<minibatchSizes.length; i++ ){
|
||||||
int inputDepth = inputDepths[i];
|
int inputDepth = inputDepths[i];
|
||||||
int minibatchSize = minibatchSizes[i];
|
int minibatchSize = minibatchSizes[i];
|
||||||
int height = heights[i];
|
int height = heights[i];
|
||||||
int k = kernelSizes[i];
|
int k = kernelSizes[i];
|
||||||
|
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
|
|
||||||
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
||||||
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
|
@ -713,7 +752,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.stride(1, 1).padding(0, 0).build())
|
.stride(1, 1).padding(0, 0).build())
|
||||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||||
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
|
.setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
@ -748,13 +787,16 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for (int inputDepth : inputDepths) {
|
for (int inputDepth : inputDepths) {
|
||||||
for (int minibatchSize : minibatchSizes) {
|
for (int minibatchSize : minibatchSizes) {
|
||||||
for (int stride : strides) {
|
for (int stride : strides) {
|
||||||
for (int k : kernelSizes) {
|
for (int k : kernelSizes) {
|
||||||
for (boolean convFirst : new boolean[]{true, false}) {
|
for (boolean convFirst : new boolean[]{true, false}) {
|
||||||
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
|
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||||
for (int i = 0; i < minibatchSize; i++) {
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||||
|
@ -775,7 +817,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(1, convFirst ? poolLayer : convLayer)
|
.layer(1, convFirst ? poolLayer : convLayer)
|
||||||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||||
.setInputType(InputType.convolutionalFlat(height, width, inputDepth))
|
.setInputType(InputType.convolutional(height, width, inputDepth, format))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -822,11 +864,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
int[] inputDepths = {1, 3, 2};
|
int[] inputDepths = {1, 3, 2};
|
||||||
int[][] zeroPadLayer = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}};
|
int[][] zeroPadLayer = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}};
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for( int i=0; i<minibatchSizes.length; i++ ){
|
for( int i=0; i<minibatchSizes.length; i++ ){
|
||||||
int minibatchSize = minibatchSizes[i];
|
int minibatchSize = minibatchSizes[i];
|
||||||
int inputDepth = inputDepths[i];
|
int inputDepth = inputDepths[i];
|
||||||
int[] zeroPad = zeroPadLayer[i];
|
int[] zeroPad = zeroPadLayer[i];
|
||||||
INDArray input = Nd4j.rand(DataType.DOUBLE, new int[]{minibatchSize, inputDepth, height, width});
|
|
||||||
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
||||||
|
|
||||||
MultiLayerConfiguration conf =
|
MultiLayerConfiguration conf =
|
||||||
|
@ -840,7 +886,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
padding).nIn(3).nOut(3).build())//output: (6-2+0)/1+1 = 5
|
padding).nIn(3).nOut(3).build())//output: (6-2+0)/1+1 = 5
|
||||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nOut(4).build())
|
.activation(Activation.SOFTMAX).nOut(4).build())
|
||||||
.setInputType(InputType.convolutional(height, width, inputDepth))
|
.setInputType(InputType.convolutional(height, width, inputDepth, format))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -849,8 +895,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
//Check zero padding activation shape
|
//Check zero padding activation shape
|
||||||
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer zpl =
|
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer zpl =
|
||||||
(org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer) net.getLayer(1);
|
(org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer) net.getLayer(1);
|
||||||
val expShape = new long[]{minibatchSize, inputDepth, height + zeroPad[0] + zeroPad[1],
|
long[] expShape;
|
||||||
width + zeroPad[2] + zeroPad[3]};
|
if(nchw){
|
||||||
|
expShape = new long[]{minibatchSize, inputDepth, height + zeroPad[0] + zeroPad[1],
|
||||||
|
width + zeroPad[2] + zeroPad[3]};
|
||||||
|
} else {
|
||||||
|
expShape = new long[]{minibatchSize, height + zeroPad[0] + zeroPad[1],
|
||||||
|
width + zeroPad[2] + zeroPad[3], inputDepth};
|
||||||
|
}
|
||||||
INDArray out = zpl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
|
INDArray out = zpl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertArrayEquals(expShape, out.shape());
|
assertArrayEquals(expShape, out.shape());
|
||||||
|
|
||||||
|
@ -888,6 +940,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for (int i = 0; i < minibatchSizes.length; i++) {
|
for (int i = 0; i < minibatchSizes.length; i++) {
|
||||||
int minibatchSize = minibatchSizes[i];
|
int minibatchSize = minibatchSizes[i];
|
||||||
int k = kernelSizes[i];
|
int k = kernelSizes[i];
|
||||||
|
@ -900,7 +954,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
int w = d * width;
|
int w = d * width;
|
||||||
int h = d * height;
|
int h = d * height;
|
||||||
|
|
||||||
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth);
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
|
|
||||||
|
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||||
for (int j = 0; j < minibatchSize; j++) {
|
for (int j = 0; j < minibatchSize; j++) {
|
||||||
labels.putScalar(new int[]{j, j % nOut}, 1.0);
|
labels.putScalar(new int[]{j, j % nOut}, 1.0);
|
||||||
|
@ -920,7 +977,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||||
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build();
|
.setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
@ -945,8 +1002,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testSeparableConv2D() {
|
public void testSeparableConv2D() {
|
||||||
int nOut = 2;
|
int nOut = 2;
|
||||||
|
|
||||||
int[] minibatchSizes = new int[]{1, 3};
|
|
||||||
int width = 6;
|
int width = 6;
|
||||||
int height = 6;
|
int height = 6;
|
||||||
int inputDepth = 3;
|
int inputDepth = 3;
|
||||||
|
@ -959,6 +1014,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
ConvolutionMode[] cms = new ConvolutionMode[]{Truncate, Truncate, Truncate, Truncate, Truncate};
|
ConvolutionMode[] cms = new ConvolutionMode[]{Truncate, Truncate, Truncate, Truncate, Truncate};
|
||||||
int[] mb = new int[]{1, 1, 1, 3, 3};
|
int[] mb = new int[]{1, 1, 1, 3, 3};
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for (int t = 0; t < ks.length; t++) {
|
for (int t = 0; t < ks.length; t++) {
|
||||||
|
|
||||||
int k = ks[t];
|
int k = ks[t];
|
||||||
|
@ -971,7 +1028,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
int w = d * width;
|
int w = d * width;
|
||||||
int h = d * height;
|
int h = d * height;
|
||||||
|
|
||||||
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth);
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||||
for (int i = 0; i < minibatchSize; i++) {
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||||
|
@ -992,7 +1050,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||||
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build();
|
.setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
@ -1017,7 +1075,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testCnnDilated() {
|
public void testCnnDilated() {
|
||||||
int nOut = 2;
|
int nOut = 2;
|
||||||
|
|
||||||
int minibatchSize = 2;
|
int minibatchSize = 2;
|
||||||
int width = 8;
|
int width = 8;
|
||||||
int height = 8;
|
int height = 8;
|
||||||
|
@ -1031,9 +1088,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
int[] ds = new int[]{2, 2, 3, 3, 2};
|
int[] ds = new int[]{2, 2, 3, 3, 2};
|
||||||
ConvolutionMode[] cms = new ConvolutionMode[]{Same, Truncate, Truncate, Same, Truncate};
|
ConvolutionMode[] cms = new ConvolutionMode[]{Same, Truncate, Truncate, Same, Truncate};
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for (int t = 0; t < sub.length; t++) {
|
for (int t = 0; t < sub.length; t++) {
|
||||||
|
|
||||||
boolean subsampling = sub[t];
|
boolean subsampling = sub[t];
|
||||||
int s = stride[t];
|
int s = stride[t];
|
||||||
int k = kernel[t];
|
int k = kernel[t];
|
||||||
|
@ -1044,7 +1101,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
int w = d * width;
|
int w = d * width;
|
||||||
int h = d * height;
|
int h = d * height;
|
||||||
|
|
||||||
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth);
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||||
for (int i = 0; i < minibatchSize; i++) {
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||||
|
@ -1076,7 +1134,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||||
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build();
|
.setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
@ -1114,11 +1172,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
int[] inputDepths = {1, 2, 3, 2};
|
int[] inputDepths = {1, 2, 3, 2};
|
||||||
int[] minibatchSizes = {2, 1, 3, 2};
|
int[] minibatchSizes = {2, 1, 3, 2};
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for (int i = 0; i < cropTestCases.length; i++) {
|
for (int i = 0; i < cropTestCases.length; i++) {
|
||||||
int inputDepth = inputDepths[i];
|
int inputDepth = inputDepths[i];
|
||||||
int minibatchSize = minibatchSizes[i];
|
int minibatchSize = minibatchSizes[i];
|
||||||
int[] crop = cropTestCases[i];
|
int[] crop = cropTestCases[i];
|
||||||
INDArray input = Nd4j.rand(new int[]{minibatchSize, inputDepth, height, width});
|
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
||||||
|
|
||||||
MultiLayerConfiguration conf =
|
MultiLayerConfiguration conf =
|
||||||
|
@ -1134,7 +1195,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build())
|
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build())
|
||||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||||
.setInputType(InputType.convolutional(height, width, inputDepth))
|
.setInputType(InputType.convolutional(height, width, inputDepth, format))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -1143,12 +1204,18 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
//Check cropping activation shape
|
//Check cropping activation shape
|
||||||
org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl =
|
org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl =
|
||||||
(org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1);
|
(org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1);
|
||||||
val expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
|
long[] expShape;
|
||||||
width - crop[2] - crop[3]};
|
if(nchw){
|
||||||
|
expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
|
||||||
|
width - crop[2] - crop[3]};
|
||||||
|
} else {
|
||||||
|
expShape = new long[]{minibatchSize, height - crop[0] - crop[1],
|
||||||
|
width - crop[2] - crop[3], inputDepth};
|
||||||
|
}
|
||||||
INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
|
INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertArrayEquals(expShape, out.shape());
|
assertArrayEquals(expShape, out.shape());
|
||||||
|
|
||||||
String msg = "minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
|
String msg = format + " - minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
|
||||||
+ Arrays.toString(crop);
|
+ Arrays.toString(crop);
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
|
@ -1181,6 +1248,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
Truncate, Truncate, Truncate, Truncate, Truncate};
|
Truncate, Truncate, Truncate, Truncate, Truncate};
|
||||||
int[] mb = new int[]{1,1,1,3,3};
|
int[] mb = new int[]{1,1,1,3,3};
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
for( int t=0; t<ks.length; t++ ){
|
for( int t=0; t<ks.length; t++ ){
|
||||||
|
|
||||||
int k = ks[t];
|
int k = ks[t];
|
||||||
|
@ -1188,8 +1257,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
ConvolutionMode cm = cms[t];
|
ConvolutionMode cm = cms[t];
|
||||||
int minibatchSize = mb[t];
|
int minibatchSize = mb[t];
|
||||||
|
|
||||||
|
long[] inShape = nchw ? new long[]{minibatchSize, nIn, height, width} : new long[]{minibatchSize, height, width, nIn};
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * nIn);
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||||
for (int i = 0; i < minibatchSize; i++) {
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||||
|
@ -1211,7 +1280,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||||
.setInputType(InputType.convolutionalFlat(height, width, nIn)).build();
|
.setInputType(InputType.convolutional(height, width, nIn, format)).build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.gradientcheck;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -115,55 +116,57 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
|
||||||
//Basic test of global pooling w/ CNN
|
//Basic test of global pooling w/ CNN
|
||||||
Nd4j.getRandom().setSeed(12345L);
|
Nd4j.getRandom().setSeed(12345L);
|
||||||
|
|
||||||
int inputDepth = 3;
|
for(boolean nchw : new boolean[]{true, false}) {
|
||||||
int inputH = 5;
|
|
||||||
int inputW = 4;
|
|
||||||
int layerDepth = 4;
|
|
||||||
int nOut = 2;
|
|
||||||
|
|
||||||
int[] minibatchSizes = new int[] {1, 3};
|
int inputDepth = 3;
|
||||||
PoolingType[] poolingTypes =
|
int inputH = 5;
|
||||||
new PoolingType[] {PoolingType.AVG, PoolingType.SUM, PoolingType.MAX, PoolingType.PNORM};
|
int inputW = 4;
|
||||||
|
int layerDepth = 4;
|
||||||
|
int nOut = 2;
|
||||||
|
|
||||||
for (int miniBatchSize : minibatchSizes) {
|
int[] minibatchSizes = new int[]{1, 3};
|
||||||
for (PoolingType pt : poolingTypes) {
|
PoolingType[] poolingTypes =
|
||||||
|
new PoolingType[]{PoolingType.AVG, PoolingType.SUM, PoolingType.MAX, PoolingType.PNORM};
|
||||||
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
for (int miniBatchSize : minibatchSizes) {
|
||||||
.dataType(DataType.DOUBLE)
|
for (PoolingType pt : poolingTypes) {
|
||||||
.updater(new NoOp())
|
|
||||||
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
|
|
||||||
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(layerDepth)
|
|
||||||
.build())
|
|
||||||
.layer(1, new GlobalPoolingLayer.Builder().poolingType(pt).build())
|
|
||||||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
|
||||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
|
||||||
|
|
||||||
.setInputType(InputType.convolutional(inputH, inputW, inputDepth)).build();
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.updater(new NoOp())
|
||||||
|
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
|
||||||
|
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(layerDepth)
|
||||||
|
.build())
|
||||||
|
.layer(1, new GlobalPoolingLayer.Builder().poolingType(pt).build())
|
||||||
|
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||||
|
.setInputType(InputType.convolutional(inputH, inputW, inputDepth, nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)).build();
|
||||||
|
|
||||||
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
|
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
|
||||||
mln.init();
|
mln.init();
|
||||||
|
|
||||||
Random r = new Random(12345L);
|
Random r = new Random(12345L);
|
||||||
INDArray input = Nd4j.rand(new int[] {miniBatchSize, inputDepth, inputH, inputW}).subi(0.5);
|
long[] inShape = nchw ? new long[]{miniBatchSize, inputDepth, inputH, inputW} : new long[]{miniBatchSize, inputH, inputW, inputDepth};
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape).subi(0.5);
|
||||||
|
|
||||||
INDArray labels = Nd4j.zeros(miniBatchSize, nOut);
|
INDArray labels = Nd4j.zeros(miniBatchSize, nOut);
|
||||||
for (int i = 0; i < miniBatchSize; i++) {
|
for (int i = 0; i < miniBatchSize; i++) {
|
||||||
int idx = r.nextInt(nOut);
|
int idx = r.nextInt(nOut);
|
||||||
labels.putScalar(i, idx, 1.0);
|
labels.putScalar(i, idx, 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(
|
System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize + " - " + (nchw ? "NCHW" : "NHWC"));
|
||||||
"testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize);
|
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||||
|
|
||||||
|
assertTrue(gradOK);
|
||||||
|
TestUtils.testModelSerialization(mln);
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
|
||||||
|
|
||||||
assertTrue(gradOK);
|
|
||||||
TestUtils.testModelSerialization(mln);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||||
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
@ -560,75 +561,81 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
public void testEmbeddingSequenceLayer(){
|
public void testEmbeddingSequenceLayer(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
for(boolean maskArray : new boolean[]{false, true}){
|
for(RNNFormat seqOutputFormat : RNNFormat.values()) {
|
||||||
for(int inputRank : new int[]{2,3}) {
|
for (boolean maskArray : new boolean[]{false, true}) {
|
||||||
|
for (int inputRank : new int[]{2, 3}) {
|
||||||
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.weightInit(new NormalDistribution(0, 1))
|
.weightInit(new NormalDistribution(0, 1))
|
||||||
.list()
|
.list()
|
||||||
.layer(new EmbeddingSequenceLayer.Builder()
|
.layer(new EmbeddingSequenceLayer.Builder()
|
||||||
.nIn(8)
|
.nIn(8)
|
||||||
.nOut(4)
|
.nOut(4)
|
||||||
.build())
|
.outputDataFormat(seqOutputFormat)
|
||||||
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
|
.build())
|
||||||
.lossFunction(LossFunction.MSE).build())
|
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
|
||||||
.build();
|
.dataFormat(seqOutputFormat)
|
||||||
|
.lossFunction(LossFunction.MSE).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
|
boolean ncw = seqOutputFormat == RNNFormat.NCW;
|
||||||
INDArray label = Nd4j.rand(new int[]{3, 3, 6});
|
|
||||||
|
|
||||||
if(inputRank == 3){
|
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
|
||||||
//Reshape from [3,6] to [3,1,6]
|
INDArray label = Nd4j.rand(DataType.FLOAT, ncw ? new int[]{3, 3, 6} : new int[]{3,6,3});
|
||||||
in = in.reshape('c', 3, 1, 6);
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray fMask = null;
|
if (inputRank == 3) {
|
||||||
if (maskArray) {
|
//Reshape from [3,6] to [3,1,6]
|
||||||
fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
|
in = in.reshape('c', 3, 1, 6);
|
||||||
{1, 1, 0, 0, 0, 0},
|
|
||||||
{1, 0, 0, 0, 0, 0}});
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
|
|
||||||
.labels(label).inputMask(fMask));
|
|
||||||
assertTrue(msg, gradOK);
|
|
||||||
TestUtils.testModelSerialization(net);
|
|
||||||
|
|
||||||
|
|
||||||
//Also: if mask is present, double check that the masked steps don't impact score
|
|
||||||
if (maskArray) {
|
|
||||||
DataSet ds = new DataSet(in, label, fMask, null);
|
|
||||||
double score = net.score(ds);
|
|
||||||
if(inputRank == 2){
|
|
||||||
in.putScalar(1, 2, 0);
|
|
||||||
in.putScalar(2, 1, 0);
|
|
||||||
in.putScalar(2, 2, 0);
|
|
||||||
} else {
|
|
||||||
in.putScalar(1, 0, 2, 0);
|
|
||||||
in.putScalar(2, 0, 1, 0);
|
|
||||||
in.putScalar(2, 0, 2, 0);
|
|
||||||
}
|
}
|
||||||
double score2 = net.score(ds);
|
|
||||||
assertEquals(score, score2, 1e-6);
|
INDArray fMask = null;
|
||||||
if(inputRank == 2){
|
if (maskArray) {
|
||||||
in.putScalar(1, 2, 1);
|
fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
|
||||||
in.putScalar(2, 1, 1);
|
{1, 1, 0, 0, 0, 0},
|
||||||
in.putScalar(2, 2, 1);
|
{1, 0, 0, 0, 0, 0}});
|
||||||
} else {
|
|
||||||
in.putScalar(1, 0, 2, 1);
|
}
|
||||||
in.putScalar(2, 0, 1, 1);
|
|
||||||
in.putScalar(2, 0, 2, 1);
|
String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
|
||||||
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
|
||||||
|
.labels(label).inputMask(fMask));
|
||||||
|
assertTrue(msg, gradOK);
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
|
|
||||||
|
|
||||||
|
//Also: if mask is present, double check that the masked steps don't impact score
|
||||||
|
if (maskArray) {
|
||||||
|
DataSet ds = new DataSet(in, label, fMask, null);
|
||||||
|
double score = net.score(ds);
|
||||||
|
if (inputRank == 2) {
|
||||||
|
in.putScalar(1, 2, 0);
|
||||||
|
in.putScalar(2, 1, 0);
|
||||||
|
in.putScalar(2, 2, 0);
|
||||||
|
} else {
|
||||||
|
in.putScalar(1, 0, 2, 0);
|
||||||
|
in.putScalar(2, 0, 1, 0);
|
||||||
|
in.putScalar(2, 0, 2, 0);
|
||||||
|
}
|
||||||
|
double score2 = net.score(ds);
|
||||||
|
assertEquals(score, score2, 1e-6);
|
||||||
|
if (inputRank == 2) {
|
||||||
|
in.putScalar(1, 2, 1);
|
||||||
|
in.putScalar(2, 1, 1);
|
||||||
|
in.putScalar(2, 2, 1);
|
||||||
|
} else {
|
||||||
|
in.putScalar(1, 0, 2, 1);
|
||||||
|
in.putScalar(2, 0, 1, 1);
|
||||||
|
in.putScalar(2, 0, 2, 1);
|
||||||
|
}
|
||||||
|
double score3 = net.score(ds);
|
||||||
|
assertEquals(score, score3, 1e-6);
|
||||||
}
|
}
|
||||||
double score3 = net.score(ds);
|
|
||||||
assertEquals(score, score3, 1e-6);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,9 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.*;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
|
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
|
||||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||||
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||||
|
@ -341,104 +339,112 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testCnnDepthMerge() {
|
public void testCnnDepthMerge() {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
for(CNN2DFormat format : CNN2DFormat.values()) {
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
|
||||||
.dataType(DataType.DOUBLE)
|
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
|
||||||
.dist(new NormalDistribution(0, 0.1))
|
|
||||||
.updater(new NoOp()).graphBuilder().addInputs("input")
|
|
||||||
.addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
|
||||||
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
|
||||||
.addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
|
||||||
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
|
||||||
.addVertex("merge", new MergeVertex(), "l1", "l2")
|
|
||||||
.addLayer("outputLayer",
|
|
||||||
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
|
||||||
.activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
|
|
||||||
.build(),
|
|
||||||
"merge")
|
|
||||||
.setOutputs("outputLayer")
|
|
||||||
.inputPreProcessor("outputLayer", new CnnToFeedForwardPreProcessor(5, 5, 4))
|
|
||||||
.build();
|
|
||||||
|
|
||||||
ComputationGraph graph = new ComputationGraph(conf);
|
String msg = "testCnnDepthMerge - " + format;
|
||||||
graph.init();
|
|
||||||
|
|
||||||
Random r = new Random(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray input = Nd4j.rand(new int[] {5, 2, 6, 6}); //Order: examples, channels, height, width
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
INDArray labels = Nd4j.zeros(5, 3);
|
.dataType(DataType.DOUBLE)
|
||||||
for (int i = 0; i < 5; i++)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
labels.putScalar(new int[] {i, r.nextInt(3)}, 1.0);
|
.dist(new NormalDistribution(0, 0.1))
|
||||||
|
.updater(new NoOp()).graphBuilder().addInputs("input")
|
||||||
|
.addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
||||||
|
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
||||||
|
.addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
||||||
|
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
||||||
|
.addVertex("merge", new MergeVertex(), "l1", "l2")
|
||||||
|
.addLayer("outputLayer",
|
||||||
|
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
|
||||||
|
.build(),
|
||||||
|
"merge")
|
||||||
|
.setOutputs("outputLayer")
|
||||||
|
.setInputTypes(InputType.convolutional(6, 6, 2, format))
|
||||||
|
.build();
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
ComputationGraph graph = new ComputationGraph(conf);
|
||||||
System.out.println("testCnnDepthMerge()");
|
graph.init();
|
||||||
|
|
||||||
|
Random r = new Random(12345);
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, format == CNN2DFormat.NCHW ? new long[]{5,2,6,6} : new long[]{5,6,6,2});
|
||||||
|
INDArray labels = Nd4j.zeros(5, 3);
|
||||||
|
for (int i = 0; i < 5; i++)
|
||||||
|
labels.putScalar(new int[]{i, r.nextInt(3)}, 1.0);
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < graph.getNumLayers(); j++)
|
// for (int j = 0; j < graph.getNumLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
|
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
|
||||||
|
.labels(new INDArray[]{labels}));
|
||||||
|
|
||||||
|
assertTrue(msg, gradOK);
|
||||||
|
TestUtils.testModelSerialization(graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
|
|
||||||
.labels(new INDArray[]{labels}));
|
|
||||||
|
|
||||||
String msg = "testCnnDepthMerge()";
|
|
||||||
assertTrue(msg, gradOK);
|
|
||||||
TestUtils.testModelSerialization(graph);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRNNWithMerging() {
|
public void testRNNWithMerging() {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
for(RNNFormat format : RNNFormat.values()) {
|
||||||
ComputationGraphConfiguration conf =
|
|
||||||
new NeuralNetConfiguration.Builder().seed(12345)
|
|
||||||
.dataType(DataType.DOUBLE)
|
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
|
||||||
.dist(new UniformDistribution(0.2, 0.6))
|
|
||||||
.updater(new NoOp()).graphBuilder().addInputs("input")
|
|
||||||
.setOutputs("out")
|
|
||||||
.addLayer("lstm1",
|
|
||||||
new SimpleRnn.Builder().nIn(3).nOut(3)
|
|
||||||
.activation(Activation.TANH).build(),
|
|
||||||
"input")
|
|
||||||
.addLayer("lstm2",
|
|
||||||
new SimpleRnn.Builder().nIn(3).nOut(3)
|
|
||||||
.activation(Activation.TANH).build(),
|
|
||||||
"lstm1")
|
|
||||||
.addLayer("dense1",
|
|
||||||
new DenseLayer.Builder().nIn(3).nOut(3)
|
|
||||||
.activation(Activation.SIGMOID).build(),
|
|
||||||
"lstm1")
|
|
||||||
.addLayer("lstm3",
|
|
||||||
new SimpleRnn.Builder().nIn(3).nOut(3)
|
|
||||||
.activation(Activation.TANH).build(),
|
|
||||||
"dense1")
|
|
||||||
.addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
|
|
||||||
.addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
|
|
||||||
.activation(Activation.SOFTMAX)
|
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
|
|
||||||
"merge")
|
|
||||||
.inputPreProcessor("dense1", new RnnToFeedForwardPreProcessor())
|
|
||||||
.inputPreProcessor("lstm3", new FeedForwardToRnnPreProcessor())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
ComputationGraph graph = new ComputationGraph(conf);
|
String msg = "testLSTMWithMerging - " + format;
|
||||||
graph.init();
|
|
||||||
|
|
||||||
Random r = new Random(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray input = Nd4j.rand(new int[] {2, 3, 4});
|
ComputationGraphConfiguration conf =
|
||||||
INDArray labels = TestUtils.randomOneHotTimeSeries(2, 3, 4);
|
new NeuralNetConfiguration.Builder().seed(12345)
|
||||||
|
.dataType(DataType.DOUBLE)
|
||||||
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
|
.dist(new UniformDistribution(0.2, 0.6))
|
||||||
|
.updater(new NoOp()).graphBuilder().addInputs("input")
|
||||||
|
.setOutputs("out")
|
||||||
|
.addLayer("lstm1",
|
||||||
|
new SimpleRnn.Builder().nIn(3).nOut(3)
|
||||||
|
.activation(Activation.TANH).build(),
|
||||||
|
"input")
|
||||||
|
.addLayer("lstm2",
|
||||||
|
new SimpleRnn.Builder().nIn(3).nOut(3)
|
||||||
|
.activation(Activation.TANH).build(),
|
||||||
|
"lstm1")
|
||||||
|
.addLayer("dense1",
|
||||||
|
new DenseLayer.Builder().nIn(3).nOut(3)
|
||||||
|
.activation(Activation.SIGMOID).build(),
|
||||||
|
"lstm1")
|
||||||
|
.addLayer("lstm3",
|
||||||
|
new SimpleRnn.Builder().nIn(3).nOut(3)
|
||||||
|
.activation(Activation.TANH).build(),
|
||||||
|
"dense1")
|
||||||
|
.addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
|
||||||
|
.addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
|
||||||
|
"merge")
|
||||||
|
.setInputTypes(InputType.recurrent(4, format))
|
||||||
|
.build();
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
ComputationGraph graph = new ComputationGraph(conf);
|
||||||
System.out.println("testLSTMWithMerging()");
|
graph.init();
|
||||||
|
|
||||||
|
Random r = new Random(12345);
|
||||||
|
INDArray input = Nd4j.rand(DataType.DOUBLE, format == RNNFormat.NCW ? new long[]{2, 3, 4} : new long[]{2,4,3});
|
||||||
|
INDArray labels = TestUtils.randomOneHotTimeSeries(format, 2, 3, 4, new Random(12345));
|
||||||
|
|
||||||
|
if (PRINT_RESULTS) {
|
||||||
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < graph.getNumLayers(); j++)
|
// for (int j = 0; j < graph.getNumLayers(); j++)
|
||||||
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
|
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
|
||||||
|
.labels(new INDArray[]{labels}));
|
||||||
|
|
||||||
|
assertTrue(msg, gradOK);
|
||||||
|
TestUtils.testModelSerialization(graph);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
|
|
||||||
.labels(new INDArray[]{labels}));
|
|
||||||
|
|
||||||
String msg = "testLSTMWithMerging()";
|
|
||||||
assertTrue(msg, gradOK);
|
|
||||||
TestUtils.testModelSerialization(graph);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -216,7 +216,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
||||||
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
failed.add(testName + "\t" + "EXCEPTION");
|
failed.add(testName + "\t" + "EXCEPTION");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -383,7 +383,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
||||||
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
failed.add(testName + "\t" + "EXCEPTION");
|
failed.add(testName + "\t" + "EXCEPTION");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -693,7 +693,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
||||||
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
failed.add(testName + "\t" + "EXCEPTION");
|
failed.add(testName + "\t" + "EXCEPTION");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -338,7 +338,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.list()
|
.list()
|
||||||
.layer(new LSTM.Builder().nOut(layerSize).build())
|
.layer(new LSTM.Builder().nOut(layerSize).build())
|
||||||
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build(), 2))
|
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build()))
|
||||||
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.setInputType(InputType.recurrent(nIn))
|
.setInputType(InputType.recurrent(nIn))
|
||||||
|
|
|
@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
|
@ -47,6 +48,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
public class ComputationGraphConfigurationTest extends BaseDL4JTest {
|
public class ComputationGraphConfigurationTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -150,7 +152,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
//OK - exception is good
|
//OK - exception is good
|
||||||
//e.printStackTrace();
|
log.info(e.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use appendLayer on first layer
|
// Use appendLayer on first layer
|
||||||
|
@ -162,7 +164,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
//OK - exception is good
|
//OK - exception is good
|
||||||
//e.printStackTrace();
|
log.info(e.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
//Test no network inputs
|
//Test no network inputs
|
||||||
|
@ -174,7 +176,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
//OK - exception is good
|
//OK - exception is good
|
||||||
//e.printStackTrace();
|
log.info(e.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
//Test no network outputs
|
//Test no network outputs
|
||||||
|
@ -185,7 +187,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
//OK - exception is good
|
//OK - exception is good
|
||||||
//e.printStackTrace();
|
log.info(e.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
//Test: invalid input
|
//Test: invalid input
|
||||||
|
@ -197,7 +199,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
//OK - exception is good
|
//OK - exception is good
|
||||||
//e.printStackTrace();
|
log.info(e.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
//Test: graph with cycles
|
//Test: graph with cycles
|
||||||
|
@ -215,7 +217,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
//OK - exception is good
|
//OK - exception is good
|
||||||
//e.printStackTrace();
|
log.info(e.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
//Test: input != inputType count mismatch
|
//Test: input != inputType count mismatch
|
||||||
|
@ -241,7 +243,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalArgumentException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
//OK - exception is good
|
//OK - exception is good
|
||||||
//e.printStackTrace();
|
log.info(e.toString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
@ -46,6 +47,7 @@ import static org.junit.Assert.*;
|
||||||
/**
|
/**
|
||||||
* Created by agibsonccc on 11/27/14.
|
* Created by agibsonccc on 11/27/14.
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
|
public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
|
@ -272,9 +274,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
//OK
|
//OK
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Unexpected exception thrown for invalid config");
|
fail("Unexpected exception thrown for invalid config");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -288,9 +290,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
//OK
|
//OK
|
||||||
e.printStackTrace();
|
log.info(e.toString());
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Unexpected exception thrown for invalid config");
|
fail("Unexpected exception thrown for invalid config");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -304,9 +306,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
|
||||||
fail("No exception thrown for invalid configuration");
|
fail("No exception thrown for invalid configuration");
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
//OK
|
//OK
|
||||||
e.printStackTrace();
|
log.info(e.toString());
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Unexpected exception thrown for invalid config");
|
fail("Unexpected exception thrown for invalid config");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,8 +96,8 @@ public class TestConstraints extends BaseDL4JTest {
|
||||||
} else if (lc instanceof NonNegativeConstraint) {
|
} else if (lc instanceof NonNegativeConstraint) {
|
||||||
assertTrue(RW0.minNumber().doubleValue() >= 0.0);
|
assertTrue(RW0.minNumber().doubleValue() >= 0.0);
|
||||||
} else if (lc instanceof UnitNormConstraint) {
|
} else if (lc instanceof UnitNormConstraint) {
|
||||||
assertEquals(RW0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, RW0.norm2(1).minNumber().doubleValue(), 1e-6);
|
||||||
assertEquals(RW0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, RW0.norm2(1).maxNumber().doubleValue(), 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
|
@ -149,8 +149,8 @@ public class TestConstraints extends BaseDL4JTest {
|
||||||
} else if (lc instanceof NonNegativeConstraint) {
|
} else if (lc instanceof NonNegativeConstraint) {
|
||||||
assertTrue(b0.minNumber().doubleValue() >= 0.0);
|
assertTrue(b0.minNumber().doubleValue() >= 0.0);
|
||||||
} else if (lc instanceof UnitNormConstraint) {
|
} else if (lc instanceof UnitNormConstraint) {
|
||||||
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
|
||||||
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
|
@ -201,8 +201,8 @@ public class TestConstraints extends BaseDL4JTest {
|
||||||
} else if (lc instanceof NonNegativeConstraint) {
|
} else if (lc instanceof NonNegativeConstraint) {
|
||||||
assertTrue(w0.minNumber().doubleValue() >= 0.0);
|
assertTrue(w0.minNumber().doubleValue() >= 0.0);
|
||||||
} else if (lc instanceof UnitNormConstraint) {
|
} else if (lc instanceof UnitNormConstraint) {
|
||||||
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
|
||||||
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
|
@ -259,10 +259,10 @@ public class TestConstraints extends BaseDL4JTest {
|
||||||
assertTrue(w0.minNumber().doubleValue() >= 0.0);
|
assertTrue(w0.minNumber().doubleValue() >= 0.0);
|
||||||
assertTrue(b0.minNumber().doubleValue() >= 0.0);
|
assertTrue(b0.minNumber().doubleValue() >= 0.0);
|
||||||
} else if (lc instanceof UnitNormConstraint) {
|
} else if (lc instanceof UnitNormConstraint) {
|
||||||
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
|
||||||
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
|
||||||
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
|
||||||
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
|
@ -320,10 +320,10 @@ public class TestConstraints extends BaseDL4JTest {
|
||||||
assertTrue(w0.minNumber().doubleValue() >= 0.0);
|
assertTrue(w0.minNumber().doubleValue() >= 0.0);
|
||||||
assertTrue(b0.minNumber().doubleValue() >= 0.0);
|
assertTrue(b0.minNumber().doubleValue() >= 0.0);
|
||||||
} else if (lc instanceof UnitNormConstraint) {
|
} else if (lc instanceof UnitNormConstraint) {
|
||||||
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
|
||||||
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
|
||||||
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
|
||||||
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
|
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
|
@ -378,10 +378,10 @@ public class TestConstraints extends BaseDL4JTest {
|
||||||
} else if(lc instanceof NonNegativeConstraint ){
|
} else if(lc instanceof NonNegativeConstraint ){
|
||||||
assertTrue(w0.minNumber().doubleValue() >= 0.0 );
|
assertTrue(w0.minNumber().doubleValue() >= 0.0 );
|
||||||
} else if(lc instanceof UnitNormConstraint ){
|
} else if(lc instanceof UnitNormConstraint ){
|
||||||
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 );
|
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6 );
|
||||||
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 );
|
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6 );
|
||||||
assertEquals(w1.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 );
|
assertEquals(1.0, w1.norm2(1).minNumber().doubleValue(), 1e-6 );
|
||||||
assertEquals(w1.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 );
|
assertEquals(1.0, w1.norm2(1).maxNumber().doubleValue(), 1e-6 );
|
||||||
}
|
}
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
|
|
|
@ -156,7 +156,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
|
||||||
|
|
||||||
checkSerialization(glstm);
|
checkSerialization(glstm);
|
||||||
|
|
||||||
assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0);
|
assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0);
|
||||||
assertEquals(glstm.nIn, numIn);
|
assertEquals(glstm.nIn, numIn);
|
||||||
assertEquals(glstm.nOut, numOut);
|
assertEquals(glstm.nOut, numOut);
|
||||||
assertTrue(glstm.getActivationFn() instanceof ActivationTanH);
|
assertTrue(glstm.getActivationFn() instanceof ActivationTanH);
|
||||||
|
|
|
@ -17,7 +17,9 @@
|
||||||
package org.deeplearning4j.nn.dtypes;
|
package org.deeplearning4j.nn.dtypes;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
||||||
|
import org.deeplearning4j.nn.conf.preprocessor.*;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
|
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
|
||||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -51,16 +53,11 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||||
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
||||||
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
|
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
|
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
|
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
|
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.util.IdentityLayer;
|
import org.deeplearning4j.nn.layers.util.IdentityLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
|
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
||||||
|
@ -97,7 +94,8 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
Pooling2D.class, //Alias for SubsamplingLayer
|
Pooling2D.class, //Alias for SubsamplingLayer
|
||||||
Convolution2D.class, //Alias for ConvolutionLayer
|
Convolution2D.class, //Alias for ConvolutionLayer
|
||||||
Pooling1D.class, //Alias for Subsampling1D
|
Pooling1D.class, //Alias for Subsampling1D
|
||||||
Convolution1D.class //Alias for Convolution1DLayer
|
Convolution1D.class, //Alias for Convolution1DLayer
|
||||||
|
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
||||||
));
|
));
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -819,7 +817,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
.layer(new DenseLayer.Builder().nOut(5).build())
|
.layer(new DenseLayer.Builder().nOut(5).build())
|
||||||
.layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
.layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
||||||
.layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()))
|
.layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()))
|
||||||
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2))
|
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
|
||||||
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build())
|
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build())
|
||||||
.layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
|
.layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
|
||||||
.layer(secondLast)
|
.layer(secondLast)
|
||||||
|
@ -1078,7 +1076,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
.addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in")
|
.addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in")
|
||||||
.addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l")
|
.addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l")
|
||||||
.addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc")
|
.addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc")
|
||||||
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16})), "preproc2")
|
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16}, false)), "preproc2")
|
||||||
.addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3")
|
.addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3")
|
||||||
.setInputTypes(InputType.feedForward(5))
|
.setInputTypes(InputType.feedForward(5))
|
||||||
.setOutputs("out");
|
.setOutputs("out");
|
||||||
|
@ -1150,7 +1148,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
case 7:
|
case 7:
|
||||||
b.addInputs("in")
|
b.addInputs("in")
|
||||||
.addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
|
.addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
|
||||||
.addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5)), "1")
|
.addVertex("2", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 5)), "1")
|
||||||
.addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
|
.addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
|
||||||
.setOutputs("out")
|
.setOutputs("out")
|
||||||
.setInputTypes(InputType.convolutional(28, 28, 1));
|
.setInputTypes(InputType.convolutional(28, 28, 1));
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.graph;
|
package org.deeplearning4j.nn.graph;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
|
||||||
|
@ -54,6 +55,7 @@ import java.util.Map;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
public class ComputationGraphTestRNN extends BaseDL4JTest {
|
public class ComputationGraphTestRNN extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -618,7 +620,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
|
||||||
.build();
|
.build();
|
||||||
fail("Exception expected");
|
fail("Exception expected");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
// e.printStackTrace();
|
log.error("",e);
|
||||||
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
|
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1394,7 +1394,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
//e.printStackTrace();
|
log.error("",e);
|
||||||
if(allowDisconnected){
|
if(allowDisconnected){
|
||||||
fail("No exception expected");
|
fail("No exception expected");
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -416,8 +416,8 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
|
||||||
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
|
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
|
||||||
NDArrayIndex.point(j));
|
NDArrayIndex.point(j));
|
||||||
for (int k = 0; k < nOut; k++) {
|
for (int k = 0; k < nOut; k++) {
|
||||||
assertEquals(outRow.getDouble(k), 0.0, 0.0);
|
assertEquals(0.0, outRow.getDouble(k), 0.0);
|
||||||
assertEquals(outRow2.getDouble(k), 0.0, 0.0);
|
assertEquals(0.0, outRow2.getDouble(k), 0.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,7 +60,7 @@ public class TestGraphNodes extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testMergeNode() {
|
public void testMergeNode() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
|
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
|
||||||
|
|
||||||
INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4);
|
INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4);
|
||||||
INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100);
|
INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100);
|
||||||
|
@ -82,7 +82,7 @@ public class TestGraphNodes extends BaseDL4JTest {
|
||||||
public void testMergeNodeRNN() {
|
public void testMergeNodeRNN() {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
|
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
|
||||||
|
|
||||||
INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5);
|
INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5);
|
||||||
INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100);
|
INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100);
|
||||||
|
@ -103,7 +103,7 @@ public class TestGraphNodes extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testCnnDepthMerge() {
|
public void testCnnDepthMerge() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
|
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
|
||||||
|
|
||||||
INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2);
|
INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2);
|
||||||
INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10);
|
INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10);
|
||||||
|
|
|
@ -0,0 +1,974 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.nn.layers.convolution;
|
||||||
|
|
||||||
|
import lombok.*;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.TestUtils;
|
||||||
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
|
||||||
|
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
|
public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
|
|
||||||
|
private final DataType dataType;
|
||||||
|
|
||||||
|
public ConvDataFormatTests(DataType dataType){
|
||||||
|
this.dataType = dataType;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Parameterized.Parameters(name = "{0}")
|
||||||
|
public static Object[] params(){
|
||||||
|
return new DataType[]{DataType.FLOAT, DataType.DOUBLE};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testConv2d() {
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getConv2dNet(CNN2DFormat.NCHW, true, cm))
|
||||||
|
.net2(getConv2dNet(CNN2DFormat.NCHW, false, cm))
|
||||||
|
.net3(getConv2dNet(CNN2DFormat.NHWC, true, cm))
|
||||||
|
.net4(getConv2dNet(CNN2DFormat.NHWC, false, cm))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSubsampling2d() {
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm))
|
||||||
|
.net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm))
|
||||||
|
.net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm))
|
||||||
|
.net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDepthwiseConv2d() {
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm))
|
||||||
|
.net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm))
|
||||||
|
.net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm))
|
||||||
|
.net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSeparableConv2d() {
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm))
|
||||||
|
.net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm))
|
||||||
|
.net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm))
|
||||||
|
.net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDeconv2d() {
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm))
|
||||||
|
.net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm))
|
||||||
|
.net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm))
|
||||||
|
.net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLRN() {
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getLrnLayer(CNN2DFormat.NCHW, true, cm))
|
||||||
|
.net2(getLrnLayer(CNN2DFormat.NCHW, false, cm))
|
||||||
|
.net3(getLrnLayer(CNN2DFormat.NHWC, true, cm))
|
||||||
|
.net4(getLrnLayer(CNN2DFormat.NHWC, false, cm))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testZeroPaddingLayer(){
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers" : "No helpers";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getZeroPaddingNet(CNN2DFormat.NCHW, true))
|
||||||
|
.net2(getZeroPaddingNet(CNN2DFormat.NCHW, false))
|
||||||
|
.net3(getZeroPaddingNet(CNN2DFormat.NHWC, true))
|
||||||
|
.net4(getZeroPaddingNet(CNN2DFormat.NHWC, false))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCropping2DLayer(){
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers" : "No helpers";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getCropping2dNet(CNN2DFormat.NCHW, true))
|
||||||
|
.net2(getCropping2dNet(CNN2DFormat.NCHW, false))
|
||||||
|
.net3(getCropping2dNet(CNN2DFormat.NHWC, true))
|
||||||
|
.net4(getCropping2dNet(CNN2DFormat.NHWC, false))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUpsampling2d(){
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers" : "No helpers";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getUpsamplingNet(CNN2DFormat.NCHW, true))
|
||||||
|
.net2(getUpsamplingNet(CNN2DFormat.NCHW, false))
|
||||||
|
.net3(getUpsamplingNet(CNN2DFormat.NHWC, true))
|
||||||
|
.net4(getUpsamplingNet(CNN2DFormat.NHWC, false))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBatchNormNet(){
|
||||||
|
try {
|
||||||
|
for(boolean useLogStd : new boolean[]{true, false}) {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std");
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true))
|
||||||
|
.net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false))
|
||||||
|
.net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true))
|
||||||
|
.net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCnnLossLayer() {
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers" : "No helpers";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3);
|
||||||
|
labelsNHWC = labelsNHWC.reshape(2,6,6,3);
|
||||||
|
INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getCnnLossNet(CNN2DFormat.NCHW, true, ConvolutionMode.Same))
|
||||||
|
.net2(getCnnLossNet(CNN2DFormat.NCHW, false, ConvolutionMode.Same))
|
||||||
|
.net3(getCnnLossNet(CNN2DFormat.NHWC, true, ConvolutionMode.Same))
|
||||||
|
.net4(getCnnLossNet(CNN2DFormat.NHWC, false, ConvolutionMode.Same))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labelsNCHW)
|
||||||
|
.labelsNHWC(labelsNHWC)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.nhwcOutput(true)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSpaceToDepthNet(){
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers" : "No helpers";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true))
|
||||||
|
.net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false))
|
||||||
|
.net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true))
|
||||||
|
.net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSpaceToBatchNet(){
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers" : "No helpers";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(8, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true))
|
||||||
|
.net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false))
|
||||||
|
.net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true))
|
||||||
|
.net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLocallyConnected() {
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm))
|
||||||
|
.net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm))
|
||||||
|
.net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm))
|
||||||
|
.net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGlobalPooling() {
|
||||||
|
try {
|
||||||
|
for (boolean helpers : new boolean[]{false, true}) {
|
||||||
|
for (PoolingType pt : PoolingType.values()) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||||
|
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
|
||||||
|
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
|
||||||
|
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
|
||||||
|
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
|
||||||
|
.inNCHW(inNCHW)
|
||||||
|
.labelsNCHW(labels)
|
||||||
|
.labelsNHWC(labels)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
testHelper(tc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new ConvolutionLayer.Builder()
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.dataFormat(format)
|
||||||
|
.nOut(3)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new ConvolutionLayer.Builder()
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.nOut(3)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new SubsamplingLayer.Builder()
|
||||||
|
.kernelSize(2, 2)
|
||||||
|
.stride(1, 1)
|
||||||
|
.dataFormat(format)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new SubsamplingLayer.Builder()
|
||||||
|
.kernelSize(2, 2)
|
||||||
|
.stride(1, 1)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new SeparableConvolution2D.Builder()
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.dataFormat(format)
|
||||||
|
.nOut(3)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new SeparableConvolution2D.Builder()
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.nOut(3)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
|
||||||
|
.depthMultiplier(2)
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.dataFormat(format)
|
||||||
|
.nOut(3)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
|
||||||
|
.depthMultiplier(2)
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.nOut(3)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new LocalResponseNormalization.Builder()
|
||||||
|
.dataFormat(format)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new LocalResponseNormalization.Builder()
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2)
|
||||||
|
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(),
|
||||||
|
format, ConvolutionMode.Same, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new Cropping2D.Builder(2,2)
|
||||||
|
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new Cropping2D.Builder(2,2)
|
||||||
|
.build(), format, ConvolutionMode.Same, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new Upsampling2D.Builder(2)
|
||||||
|
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new Upsampling2D.Builder(2)
|
||||||
|
.build(), format, ConvolutionMode.Same, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.kernelSize(2,2)
|
||||||
|
.stride(2,2)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.kernelSize(2,2)
|
||||||
|
.stride(2,2)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new BatchNormalization.Builder()
|
||||||
|
.useLogStd(logStdev)
|
||||||
|
.dataFormat(format)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.nOut(3).build(), format, ConvolutionMode.Same, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new BatchNormalization.Builder()
|
||||||
|
.useLogStd(logStdev)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.nOut(3).build(), format, ConvolutionMode.Same, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new SpaceToDepthLayer.Builder()
|
||||||
|
.blocks(2)
|
||||||
|
.dataFormat(format)
|
||||||
|
.build(), format, ConvolutionMode.Same, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new SpaceToDepthLayer.Builder()
|
||||||
|
.blocks(2)
|
||||||
|
.build(), format, ConvolutionMode.Same, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new SpaceToBatchLayer.Builder()
|
||||||
|
.blocks(2, 2)
|
||||||
|
.dataFormat(format)
|
||||||
|
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new SpaceToBatchLayer.Builder()
|
||||||
|
.blocks(2, 2)
|
||||||
|
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new LocallyConnected2D.Builder()
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.dataFormat(format)
|
||||||
|
.nOut(3)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new LocallyConnected2D.Builder()
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.nOut(3)
|
||||||
|
.build(), format, cm, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
|
||||||
|
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
|
||||||
|
.dataType(this.dataType)
|
||||||
|
.seed(12345)
|
||||||
|
.convolutionMode(cm)
|
||||||
|
.list()
|
||||||
|
.layer(new ConvolutionLayer.Builder()
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.dataFormat(format)
|
||||||
|
.nOut(3)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build())
|
||||||
|
.layer(layer)
|
||||||
|
.layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
|
||||||
|
.setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));
|
||||||
|
|
||||||
|
if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
|
||||||
|
//Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
|
||||||
|
//DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
|
||||||
|
builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
|
||||||
|
}
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
|
||||||
|
net.init();
|
||||||
|
return net;
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
|
||||||
|
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
|
||||||
|
.build(), format, ConvolutionMode.Same, null);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
|
||||||
|
.build(), format, ConvolutionMode.Same, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){
|
||||||
|
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
|
||||||
|
.seed(12345)
|
||||||
|
.convolutionMode(cm)
|
||||||
|
.list()
|
||||||
|
.layer(new ConvolutionLayer.Builder()
|
||||||
|
.kernelSize(3, 3)
|
||||||
|
.stride(2, 2)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.dataFormat(format)
|
||||||
|
.nOut(3)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build());
|
||||||
|
if(setOnLayerAlso){
|
||||||
|
builder.layer(new CnnLossLayer.Builder().format(format).activation(Activation.SOFTMAX).build());
|
||||||
|
} else {
|
||||||
|
builder.layer(new CnnLossLayer.Builder().activation(Activation.SOFTMAX).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.setInputType(InputType.convolutional(12, 12, 3, format));
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
|
||||||
|
net.init();
|
||||||
|
return net;
|
||||||
|
}
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@Builder
|
||||||
|
private static class TestCase {
|
||||||
|
private String msg;
|
||||||
|
private MultiLayerNetwork net1;
|
||||||
|
private MultiLayerNetwork net2;
|
||||||
|
private MultiLayerNetwork net3;
|
||||||
|
private MultiLayerNetwork net4;
|
||||||
|
private INDArray inNCHW;
|
||||||
|
private INDArray labelsNCHW;
|
||||||
|
private INDArray labelsNHWC;
|
||||||
|
private int testLayerIdx;
|
||||||
|
private boolean nhwcOutput;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void testHelper(TestCase tc) {
|
||||||
|
|
||||||
|
tc.net2.params().assign(tc.net1.params());
|
||||||
|
tc.net3.params().assign(tc.net1.params());
|
||||||
|
tc.net4.params().assign(tc.net1.params());
|
||||||
|
|
||||||
|
//Test forward pass:
|
||||||
|
INDArray inNCHW = tc.inNCHW;
|
||||||
|
INDArray inNHWC = tc.inNCHW.permute(0, 2, 3, 1).dup();
|
||||||
|
|
||||||
|
INDArray l0_1 = tc.net1.feedForward(inNCHW).get(tc.testLayerIdx + 1);
|
||||||
|
INDArray l0_2 = tc.net2.feedForward(inNCHW).get(tc.testLayerIdx + 1);
|
||||||
|
INDArray l0_3 = tc.net3.feedForward(inNHWC).get(tc.testLayerIdx + 1);
|
||||||
|
INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1);
|
||||||
|
|
||||||
|
assertEquals(tc.msg, l0_1, l0_2);
|
||||||
|
if(l0_1.rank() == 4) {
|
||||||
|
assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
|
||||||
|
assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
|
||||||
|
} else {
|
||||||
|
assertEquals(tc.msg, l0_1, l0_3);
|
||||||
|
assertEquals(tc.msg, l0_1, l0_4);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
INDArray out1 = tc.net1.output(inNCHW);
|
||||||
|
INDArray out2 = tc.net2.output(inNCHW);
|
||||||
|
INDArray out3 = tc.net3.output(inNHWC);
|
||||||
|
INDArray out4 = tc.net4.output(inNHWC);
|
||||||
|
|
||||||
|
assertEquals(tc.msg, out1, out2);
|
||||||
|
if(!tc.nhwcOutput) {
|
||||||
|
assertEquals(tc.msg, out1, out3);
|
||||||
|
assertEquals(tc.msg, out1, out4);
|
||||||
|
} else {
|
||||||
|
assertEquals(tc.msg, out1, out3.permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
assertEquals(tc.msg, out1, out4.permute(0,3,1,2));
|
||||||
|
}
|
||||||
|
|
||||||
|
//Test backprop
|
||||||
|
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCHW, tc.labelsNCHW, null, null);
|
||||||
|
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCHW, tc.labelsNCHW, null, null);
|
||||||
|
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNHWC, tc.labelsNHWC, null, null);
|
||||||
|
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNHWC, tc.labelsNHWC, null, null);
|
||||||
|
|
||||||
|
//Inpput gradients
|
||||||
|
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
|
||||||
|
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0,3,1,2)); //Input gradients for NHWC input are also in NHWC format
|
||||||
|
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0,3,1,2));
|
||||||
|
|
||||||
|
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
|
||||||
|
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
|
||||||
|
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
|
||||||
|
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
|
||||||
|
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
|
||||||
|
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
|
||||||
|
|
||||||
|
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
|
||||||
|
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
|
||||||
|
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
|
||||||
|
|
||||||
|
tc.net1.fit(inNCHW, tc.labelsNCHW);
|
||||||
|
tc.net2.fit(inNCHW, tc.labelsNCHW);
|
||||||
|
tc.net3.fit(inNHWC, tc.labelsNHWC);
|
||||||
|
tc.net4.fit(inNHWC, tc.labelsNHWC);
|
||||||
|
|
||||||
|
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
|
||||||
|
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
|
||||||
|
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
|
||||||
|
|
||||||
|
//Test serialization
|
||||||
|
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
|
||||||
|
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
|
||||||
|
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
|
||||||
|
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
|
||||||
|
|
||||||
|
out1 = tc.net1.output(inNCHW);
|
||||||
|
assertEquals(tc.msg, out1, net1a.output(inNCHW));
|
||||||
|
assertEquals(tc.msg, out1, net2a.output(inNCHW));
|
||||||
|
if(!tc.nhwcOutput) {
|
||||||
|
assertEquals(tc.msg, out1, net3a.output(inNHWC));
|
||||||
|
assertEquals(tc.msg, out1, net4a.output(inNHWC));
|
||||||
|
} else {
|
||||||
|
assertEquals(tc.msg, out1, net3a.output(inNHWC).permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<String> differentGrads(Gradient g1, Gradient g2){
|
||||||
|
List<String> differs = new ArrayList<>();
|
||||||
|
Map<String,INDArray> m1 = g1.gradientForVariable();
|
||||||
|
Map<String,INDArray> m2 = g2.gradientForVariable();
|
||||||
|
for(String s : m1.keySet()){
|
||||||
|
INDArray a1 = m1.get(s);
|
||||||
|
INDArray a2 = m2.get(s);
|
||||||
|
if(!a1.equals(a2)){
|
||||||
|
differs.add(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return differs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//Converts NHWC to NCHW activations
|
||||||
|
@EqualsAndHashCode
|
||||||
|
private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public InputPreProcessor clone() {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public InputType getOutputType(InputType inputType) {
|
||||||
|
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||||
|
return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.layers.convolution;
|
package org.deeplearning4j.nn.layers.convolution;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.exception.DL4JException;
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
|
@ -45,6 +46,7 @@ import static org.junit.Assert.*;
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 15/11/2016.
|
* Created by Alex on 15/11/2016.
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class TestConvolutionModes extends BaseDL4JTest {
|
public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -106,12 +108,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
if (inSize == 9 || cm != ConvolutionMode.Strict) {
|
if (inSize == 9 || cm != ConvolutionMode.Strict) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Unexpected exception");
|
fail("Unexpected exception");
|
||||||
}
|
}
|
||||||
continue; //Expected exception
|
continue; //Expected exception
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Unexpected exception");
|
fail("Unexpected exception");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -184,12 +186,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
if (inSize == 9 || cm != ConvolutionMode.Strict) {
|
if (inSize == 9 || cm != ConvolutionMode.Strict) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Unexpected exception");
|
fail("Unexpected exception");
|
||||||
}
|
}
|
||||||
continue; //Expected exception
|
continue; //Expected exception
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
fail("Unexpected exception");
|
fail("Unexpected exception");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,10 +24,7 @@ import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
|
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
|
||||||
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
|
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
|
||||||
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
|
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.*;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
|
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
|
||||||
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||||
|
@ -45,6 +42,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.deeplearning4j.util.TimeSeriesUtils;
|
import org.deeplearning4j.util.TimeSeriesUtils;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -61,12 +60,22 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.io.ByteArrayOutputStream;
|
||||||
|
|
||||||
|
import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
public class BidirectionalTest extends BaseDL4JTest {
|
public class BidirectionalTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
private RNNFormat rnnDataFormat;
|
||||||
|
|
||||||
|
public BidirectionalTest(RNNFormat rnnDataFormat){
|
||||||
|
this.rnnDataFormat = rnnDataFormat;
|
||||||
|
}
|
||||||
|
@Parameterized.Parameters
|
||||||
|
public static Object[] params(){
|
||||||
|
return RNNFormat.values();
|
||||||
|
}
|
||||||
@Test
|
@Test
|
||||||
public void compareImplementations(){
|
public void compareImplementations(){
|
||||||
for(WorkspaceMode wsm : WorkspaceMode.values()) {
|
for(WorkspaceMode wsm : WorkspaceMode.values()) {
|
||||||
|
@ -82,9 +91,9 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
.inferenceWorkspaceMode(wsm)
|
.inferenceWorkspaceMode(wsm)
|
||||||
.updater(new Adam())
|
.updater(new Adam())
|
||||||
.list()
|
.list()
|
||||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
|
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
|
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||||
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
|
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
|
||||||
.nIn(10).nOut(10).build())
|
.nIn(10).nOut(10).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -95,9 +104,9 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
.inferenceWorkspaceMode(wsm)
|
.inferenceWorkspaceMode(wsm)
|
||||||
.updater(new Adam())
|
.updater(new Adam())
|
||||||
.list()
|
.list()
|
||||||
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
|
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||||
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
|
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||||
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
|
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
|
||||||
.nIn(10).nOut(10).build())
|
.nIn(10).nOut(10).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -116,15 +125,24 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
|
|
||||||
net2.setParams(net1.params()); //Assuming exact same layout here...
|
net2.setParams(net1.params()); //Assuming exact same layout here...
|
||||||
|
|
||||||
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
|
INDArray in;
|
||||||
|
if (rnnDataFormat == NCW){
|
||||||
|
in = Nd4j.rand(new int[]{3, 10, 5});
|
||||||
|
}else{
|
||||||
|
in = Nd4j.rand(new int[]{3, 5, 10});
|
||||||
|
}
|
||||||
|
|
||||||
INDArray out1 = net1.output(in);
|
INDArray out1 = net1.output(in);
|
||||||
INDArray out2 = net2.output(in);
|
INDArray out2 = net2.output(in);
|
||||||
|
|
||||||
assertEquals(out1, out2);
|
assertEquals(out1, out2);
|
||||||
|
|
||||||
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
|
INDArray labels;
|
||||||
|
if (rnnDataFormat == NCW){
|
||||||
|
labels = Nd4j.rand(new int[]{3, 10, 5});
|
||||||
|
}else{
|
||||||
|
labels = Nd4j.rand(new int[]{3, 5, 10});
|
||||||
|
}
|
||||||
net1.setInput(in);
|
net1.setInput(in);
|
||||||
net1.setLabels(labels);
|
net1.setLabels(labels);
|
||||||
|
|
||||||
|
@ -276,17 +294,22 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
.inferenceWorkspaceMode(wsm)
|
.inferenceWorkspaceMode(wsm)
|
||||||
.updater(new Adam())
|
.updater(new Adam())
|
||||||
.list()
|
.list()
|
||||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
|
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||||
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
|
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||||
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
|
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
|
||||||
.nIn(10).nOut(10).build())
|
.nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
|
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
|
||||||
net1.init();
|
net1.init();
|
||||||
|
|
||||||
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
|
INDArray in;
|
||||||
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
|
INDArray labels;
|
||||||
|
|
||||||
|
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 5} : new long[]{3, 5, 10};
|
||||||
|
|
||||||
|
in = Nd4j.rand(inshape);
|
||||||
|
labels = Nd4j.rand(inshape);
|
||||||
|
|
||||||
net1.fit(in, labels);
|
net1.fit(in, labels);
|
||||||
|
|
||||||
|
@ -300,8 +323,8 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true);
|
MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true);
|
||||||
|
|
||||||
|
|
||||||
in = Nd4j.rand(new int[]{3, 10, 5});
|
in = Nd4j.rand(inshape);
|
||||||
labels = Nd4j.rand(new int[]{3, 10, 5});
|
labels = Nd4j.rand(inshape);
|
||||||
|
|
||||||
INDArray out1 = net1.output(in);
|
INDArray out1 = net1.output(in);
|
||||||
INDArray out2 = net2.output(in);
|
INDArray out2 = net2.output(in);
|
||||||
|
@ -338,18 +361,18 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
.updater(new Adam())
|
.updater(new Adam())
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in")
|
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
|
||||||
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0")
|
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0")
|
||||||
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
|
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
|
||||||
.nIn(10).nOut(10).build(), "1")
|
.nIn(10).nOut(10).build(), "1")
|
||||||
.setOutputs("2")
|
.setOutputs("2")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
ComputationGraph net1 = new ComputationGraph(conf1);
|
ComputationGraph net1 = new ComputationGraph(conf1);
|
||||||
net1.init();
|
net1.init();
|
||||||
|
long[] inshape = (rnnDataFormat == NCW)? new long[]{3, 10, 5}: new long[]{3, 5, 10};
|
||||||
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
|
INDArray in = Nd4j.rand(inshape);
|
||||||
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
|
INDArray labels = Nd4j.rand(inshape);
|
||||||
|
|
||||||
net1.fit(new DataSet(in, labels));
|
net1.fit(new DataSet(in, labels));
|
||||||
|
|
||||||
|
@ -363,8 +386,8 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true);
|
ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true);
|
||||||
|
|
||||||
|
|
||||||
in = Nd4j.rand(new int[]{3, 10, 5});
|
in = Nd4j.rand(inshape);
|
||||||
labels = Nd4j.rand(new int[]{3, 10, 5});
|
labels = Nd4j.rand(inshape);
|
||||||
|
|
||||||
INDArray out1 = net1.outputSingle(in);
|
INDArray out1 = net1.outputSingle(in);
|
||||||
INDArray out2 = net2.outputSingle(in);
|
INDArray out2 = net2.outputSingle(in);
|
||||||
|
@ -394,8 +417,8 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD,
|
Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD,
|
||||||
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
|
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
|
||||||
|
|
||||||
|
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
|
||||||
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
|
INDArray in = Nd4j.rand(inshape);
|
||||||
|
|
||||||
for (Bidirectional.Mode m : modes) {
|
for (Bidirectional.Mode m : modes) {
|
||||||
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
|
||||||
|
@ -406,7 +429,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
.inferenceWorkspaceMode(wsm)
|
.inferenceWorkspaceMode(wsm)
|
||||||
.updater(new Adam())
|
.updater(new Adam())
|
||||||
.list()
|
.list()
|
||||||
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()))
|
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
|
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
|
||||||
|
@ -418,7 +441,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.updater(new Adam())
|
.updater(new Adam())
|
||||||
.list()
|
.list()
|
||||||
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build())
|
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone());
|
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone());
|
||||||
|
@ -434,11 +457,10 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
net3.setParam("0_RW", net1.getParam("0_bRW"));
|
net3.setParam("0_RW", net1.getParam("0_bRW"));
|
||||||
net3.setParam("0_b", net1.getParam("0_bb"));
|
net3.setParam("0_b", net1.getParam("0_bb"));
|
||||||
|
|
||||||
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
|
||||||
|
|
||||||
INDArray out1 = net1.output(in);
|
INDArray out1 = net1.output(in);
|
||||||
INDArray out2 = net2.output(in);
|
INDArray out2 = net2.output(in);
|
||||||
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
switch (m) {
|
switch (m) {
|
||||||
|
@ -452,7 +474,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
outExp = out2.add(out3).muli(0.5);
|
outExp = out2.add(out3).muli(0.5);
|
||||||
break;
|
break;
|
||||||
case CONCAT:
|
case CONCAT:
|
||||||
outExp = Nd4j.concat(1, out2, out3);
|
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
|
@ -464,25 +486,25 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
//Check gradients:
|
//Check gradients:
|
||||||
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
|
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
|
||||||
|
|
||||||
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
|
INDArray eps = Nd4j.rand(inshape);
|
||||||
|
|
||||||
INDArray eps1;
|
INDArray eps1;
|
||||||
if (m == Bidirectional.Mode.CONCAT) {
|
if (m == Bidirectional.Mode.CONCAT) {
|
||||||
eps1 = Nd4j.concat(1, eps, eps);
|
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
|
||||||
} else {
|
} else {
|
||||||
eps1 = eps;
|
eps1 = eps;
|
||||||
}
|
}
|
||||||
|
|
||||||
net1.setInput(in);
|
net1.setInput(in);
|
||||||
net2.setInput(in);
|
net2.setInput(in);
|
||||||
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
|
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat));
|
||||||
net1.feedForward(true, false);
|
net1.feedForward(true, false);
|
||||||
net2.feedForward(true, false);
|
net2.feedForward(true, false);
|
||||||
net3.feedForward(true, false);
|
net3.feedForward(true, false);
|
||||||
|
|
||||||
Pair<Gradient, INDArray> p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces());
|
Pair<Gradient, INDArray> p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces());
|
||||||
Pair<Gradient, INDArray> p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
Pair<Gradient, INDArray> p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
||||||
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT), LayerWorkspaceMgr.noWorkspaces());
|
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat), LayerWorkspaceMgr.noWorkspaces());
|
||||||
Gradient g1 = p1.getFirst();
|
Gradient g1 = p1.getFirst();
|
||||||
Gradient g2 = p2.getFirst();
|
Gradient g2 = p2.getFirst();
|
||||||
Gradient g3 = p3.getFirst();
|
Gradient g3 = p3.getFirst();
|
||||||
|
@ -520,7 +542,9 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
|
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
|
||||||
|
|
||||||
|
|
||||||
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
|
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
|
||||||
|
INDArray in = Nd4j.rand(inshape);
|
||||||
|
|
||||||
|
|
||||||
for (Bidirectional.Mode m : modes) {
|
for (Bidirectional.Mode m : modes) {
|
||||||
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
|
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
|
||||||
|
@ -532,7 +556,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
.updater(new Adam())
|
.updater(new Adam())
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()), "in")
|
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
|
||||||
.setOutputs("0")
|
.setOutputs("0")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -546,7 +570,7 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
.updater(new Adam())
|
.updater(new Adam())
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in")
|
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in")
|
||||||
.setOutputs("0")
|
.setOutputs("0")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -566,9 +590,20 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
|
|
||||||
INDArray out1 = net1.outputSingle(in);
|
INDArray out1 = net1.outputSingle(in);
|
||||||
INDArray out2 = net2.outputSingle(in);
|
INDArray out2 = net2.outputSingle(in);
|
||||||
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.outputSingle(
|
INDArray out3;
|
||||||
TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)),
|
INDArray inReverse;
|
||||||
LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
if (rnnDataFormat == RNNFormat.NWC){
|
||||||
|
inReverse = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
|
||||||
|
out3 = net3.outputSingle(inReverse);
|
||||||
|
out3 = TimeSeriesUtils.reverseTimeSeries(out3.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
|
||||||
|
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
||||||
|
out3 = net3.outputSingle(inReverse);
|
||||||
|
out3 = TimeSeriesUtils.reverseTimeSeries(out3, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
INDArray outExp;
|
INDArray outExp;
|
||||||
switch (m) {
|
switch (m) {
|
||||||
|
@ -582,7 +617,9 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
outExp = out2.add(out3).muli(0.5);
|
outExp = out2.add(out3).muli(0.5);
|
||||||
break;
|
break;
|
||||||
case CONCAT:
|
case CONCAT:
|
||||||
outExp = Nd4j.concat(1, out2, out3);
|
System.out.println(out2.shapeInfoToString());
|
||||||
|
System.out.println(out3.shapeInfoToString());
|
||||||
|
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
|
@ -594,22 +631,26 @@ public class BidirectionalTest extends BaseDL4JTest {
|
||||||
//Check gradients:
|
//Check gradients:
|
||||||
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
|
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
|
||||||
|
|
||||||
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
|
INDArray eps = Nd4j.rand(inshape);
|
||||||
|
|
||||||
INDArray eps1;
|
INDArray eps1;
|
||||||
if (m == Bidirectional.Mode.CONCAT) {
|
if (m == Bidirectional.Mode.CONCAT) {
|
||||||
eps1 = Nd4j.concat(1, eps, eps);
|
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
|
||||||
} else {
|
} else {
|
||||||
eps1 = eps;
|
eps1 = eps;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INDArray epsReversed = (rnnDataFormat == NCW)?
|
||||||
|
TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT):
|
||||||
|
TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)
|
||||||
|
.permute(0, 2, 1);
|
||||||
net1.outputSingle(true, false, in);
|
net1.outputSingle(true, false, in);
|
||||||
net2.outputSingle(true, false, in);
|
net2.outputSingle(true, false, in);
|
||||||
net3.outputSingle(true, false, TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
|
net3.outputSingle(true, false, inReverse);
|
||||||
|
|
||||||
Gradient g1 = net1.backpropGradient(eps1);
|
Gradient g1 = net1.backpropGradient(eps1);
|
||||||
Gradient g2 = net2.backpropGradient(eps);
|
Gradient g2 = net2.backpropGradient(eps);
|
||||||
Gradient g3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
|
Gradient g3 = net3.backpropGradient(epsReversed);
|
||||||
|
|
||||||
for (boolean updates : new boolean[]{false, true}) {
|
for (boolean updates : new boolean[]{false, true}) {
|
||||||
if (updates) {
|
if (updates) {
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.CacheMode;
|
import org.deeplearning4j.nn.conf.CacheMode;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
@ -31,6 +32,8 @@ import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -42,10 +45,18 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
private double score = 0.0;
|
private double score = 0.0;
|
||||||
|
private RNNFormat rnnDataFormat;
|
||||||
|
|
||||||
|
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){
|
||||||
|
this.rnnDataFormat = rnnDataFormat;
|
||||||
|
}
|
||||||
|
@Parameterized.Parameters
|
||||||
|
public static Object[] params(){
|
||||||
|
return RNNFormat.values();
|
||||||
|
}
|
||||||
@Test
|
@Test
|
||||||
public void testBidirectionalLSTMGravesForwardBasic() {
|
public void testBidirectionalLSTMGravesForwardBasic() {
|
||||||
//Very basic test of forward prop. of LSTM layer with a time series.
|
//Very basic test of forward prop. of LSTM layer with a time series.
|
||||||
|
@ -55,7 +66,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
|
|
||||||
final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
|
final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
|
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
|
||||||
.nOut(nHiddenUnits).activation(Activation.TANH).build())
|
.nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
val numParams = conf.getLayer().initializer().numParams(conf);
|
val numParams = conf.getLayer().initializer().numParams(conf);
|
||||||
|
@ -65,22 +76,41 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
|
|
||||||
//Data: has shape [miniBatchSize,nIn,timeSeriesLength];
|
//Data: has shape [miniBatchSize,nIn,timeSeriesLength];
|
||||||
//Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength];
|
//Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength];
|
||||||
|
if (rnnDataFormat == RNNFormat.NCW){
|
||||||
|
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1);
|
||||||
|
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1});
|
||||||
|
|
||||||
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1);
|
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1);
|
||||||
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1});
|
assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1});
|
||||||
|
|
||||||
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1);
|
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12);
|
||||||
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1});
|
assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12});
|
||||||
|
|
||||||
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12);
|
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15);
|
||||||
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
|
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12});
|
assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, 1, nIn);
|
||||||
|
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
assertArrayEquals(activations1.shape(), new long[] {1, 1, nHiddenUnits});
|
||||||
|
|
||||||
|
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, 1, nIn);
|
||||||
|
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
assertArrayEquals(activations2.shape(), new long[] {10, 1, nHiddenUnits});
|
||||||
|
|
||||||
|
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, 12, nIn);
|
||||||
|
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
assertArrayEquals(activations3.shape(), new long[] {1, 12, nHiddenUnits});
|
||||||
|
|
||||||
|
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, 15, nIn);
|
||||||
|
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
assertArrayEquals(activations4.shape(), new long[] {10, 15, nHiddenUnits});
|
||||||
|
}
|
||||||
|
|
||||||
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15);
|
|
||||||
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
|
|
||||||
assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -94,14 +124,15 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
|
testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
|
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
|
||||||
int timeSeriesLength) {
|
int timeSeriesLength) {
|
||||||
|
|
||||||
INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength);
|
INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength):
|
||||||
|
Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
|
||||||
|
|
||||||
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
|
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
|
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
|
||||||
.nOut(lstmNHiddenUnits)
|
.nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat)
|
||||||
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
|
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -114,7 +145,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces());
|
lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertNotNull(lstm.input());
|
assertNotNull(lstm.input());
|
||||||
|
|
||||||
INDArray epsilon = Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength);
|
INDArray epsilon =(rnnDataFormat == RNNFormat.NCW)? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength):
|
||||||
|
Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits);
|
||||||
|
|
||||||
Pair<Gradient, INDArray> out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
|
Pair<Gradient, INDArray> out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
|
||||||
Gradient outGradient = out.getFirst();
|
Gradient outGradient = out.getFirst();
|
||||||
|
@ -147,7 +179,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3});
|
assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3});
|
||||||
|
|
||||||
assertNotNull(nextEpsilon);
|
assertNotNull(nextEpsilon);
|
||||||
assertArrayEquals(nextEpsilon.shape(), new long[] {miniBatchSize, nIn, timeSeriesLength});
|
if (rnnDataFormat == RNNFormat.NCW) {
|
||||||
|
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, nIn, timeSeriesLength});
|
||||||
|
}else{
|
||||||
|
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, timeSeriesLength, nIn });
|
||||||
|
}
|
||||||
|
|
||||||
//Check update:
|
//Check update:
|
||||||
for (String s : outGradient.gradientForVariable().keySet()) {
|
for (String s : outGradient.gradientForVariable().keySet()) {
|
||||||
|
@ -226,7 +262,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
|
|
||||||
final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder()
|
final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder()
|
||||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
|
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
|
||||||
.nOut(layerSize)
|
.nOut(layerSize).dataFormat(rnnDataFormat)
|
||||||
.dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build())
|
.dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -237,7 +273,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
.instantiate(confBidirectional, null, 0, params, true, params.dataType());
|
.instantiate(confBidirectional, null, 0, params, true, params.dataType());
|
||||||
|
|
||||||
|
|
||||||
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
|
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
|
||||||
|
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
|
||||||
|
|
||||||
final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());
|
final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
|
||||||
|
@ -265,13 +302,13 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
final NeuralNetConfiguration confBidirectional =
|
final NeuralNetConfiguration confBidirectional =
|
||||||
new NeuralNetConfiguration.Builder()
|
new NeuralNetConfiguration.Builder()
|
||||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
|
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
|
||||||
.nIn(nIn).nOut(layerSize)
|
.nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
|
||||||
.dist(new UniformDistribution(-0.1, 0.1))
|
.dist(new UniformDistribution(-0.1, 0.1))
|
||||||
.activation(Activation.TANH).updater(new NoOp()).build())
|
.activation(Activation.TANH).updater(new NoOp()).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder()
|
final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder()
|
||||||
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
|
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
|
||||||
.weightInit(WeightInit.ZERO).activation(Activation.TANH).build())
|
.weightInit(WeightInit.ZERO).activation(Activation.TANH).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -290,9 +327,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards)));
|
Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards)));
|
||||||
|
|
||||||
|
|
||||||
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
|
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
|
||||||
|
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
|
||||||
final INDArray sigb = sig.dup();
|
final INDArray sigb = sig.dup();
|
||||||
reverseColumnsInPlace(sigb.slice(0));
|
|
||||||
|
if (rnnDataFormat == RNNFormat.NCW) {
|
||||||
|
reverseColumnsInPlace(sigb.slice(0));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
reverseColumnsInPlace(sigb.slice(0).permute(1, 0));
|
||||||
|
}
|
||||||
|
|
||||||
final INDArray recurrentWeightsF = bidirectionalLSTM
|
final INDArray recurrentWeightsF = bidirectionalLSTM
|
||||||
.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS);
|
.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS);
|
||||||
|
@ -345,10 +389,14 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
|
|
||||||
assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f);
|
assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f);
|
||||||
|
|
||||||
final INDArray randSig = Nd4j.rand(new int[] {1, layerSize, timeSeriesLength});
|
final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}):
|
||||||
final INDArray randSigBackwards = randSig.dup();
|
Nd4j.rand(new int[] {1, timeSeriesLength, layerSize});
|
||||||
reverseColumnsInPlace(randSigBackwards.slice(0));
|
INDArray randSigBackwards = randSig.dup();
|
||||||
|
if (rnnDataFormat == RNNFormat.NCW){
|
||||||
|
reverseColumnsInPlace(randSigBackwards.slice(0));
|
||||||
|
}else{
|
||||||
|
reverseColumnsInPlace(randSigBackwards.slice(0).permute(1, 0));
|
||||||
|
}
|
||||||
|
|
||||||
final Pair<Gradient, INDArray> backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
|
final Pair<Gradient, INDArray> backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
|
||||||
final Pair<Gradient, INDArray> backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
|
final Pair<Gradient, INDArray> backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
@ -399,10 +447,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0);
|
final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0);
|
||||||
|
|
||||||
final INDArray activation3Reverse = activation3.dup();
|
final INDArray activation3Reverse = activation3.dup();
|
||||||
reverseColumnsInPlace(activation3Reverse);
|
if (rnnDataFormat == RNNFormat.NCW){
|
||||||
|
reverseColumnsInPlace(activation3Reverse);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
reverseColumnsInPlace(activation3Reverse.permute(1, 0));
|
||||||
|
}
|
||||||
|
|
||||||
assertEquals(activation3Reverse, activation1);
|
|
||||||
assertArrayEquals(activation3Reverse.shape(), activation1.shape());
|
assertArrayEquals(activation3Reverse.shape(), activation1.shape());
|
||||||
|
assertEquals(activation3Reverse, activation1);
|
||||||
|
|
||||||
|
|
||||||
//test backprop now
|
//test backprop now
|
||||||
final INDArray refBackGradientReccurrent =
|
final INDArray refBackGradientReccurrent =
|
||||||
|
@ -434,7 +488,12 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
final INDArray refEpsilon = backprop1.getSecond().dup();
|
final INDArray refEpsilon = backprop1.getSecond().dup();
|
||||||
final INDArray backEpsilon = backprop3.getSecond().dup();
|
final INDArray backEpsilon = backprop3.getSecond().dup();
|
||||||
|
|
||||||
reverseColumnsInPlace(refEpsilon.slice(0));
|
if (rnnDataFormat == RNNFormat.NCW) {
|
||||||
|
reverseColumnsInPlace(refEpsilon.slice(0));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
reverseColumnsInPlace(refEpsilon.slice(0).permute(1, 0));
|
||||||
|
}
|
||||||
assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6);
|
assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -477,10 +536,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.seed(12345).list()
|
.seed(12345).list()
|
||||||
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
|
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
|
||||||
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2)
|
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat)
|
||||||
.build())
|
.build())
|
||||||
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder()
|
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder()
|
||||||
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2)
|
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat)
|
||||||
.activation(Activation.TANH).build())
|
.activation(Activation.TANH).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -492,7 +551,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||||
|
|
||||||
INDArray in = Nd4j.rand(new int[] {3, 2, 5});
|
INDArray in = Nd4j.rand(new int[] {3, 2, 5});
|
||||||
INDArray labels = Nd4j.rand(new int[] {3, 2, 5});
|
INDArray labels = Nd4j.rand(new int[] {3, 2, 5});
|
||||||
|
if (rnnDataFormat == RNNFormat.NWC){
|
||||||
|
in = in.permute(0, 2, 1);
|
||||||
|
labels = labels.permute(0, 2, 1);
|
||||||
|
}
|
||||||
net.fit(in, labels);
|
net.fit(in, labels);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,11 +21,14 @@ import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -36,9 +39,17 @@ import java.util.Collections;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
public class MaskZeroLayerTest extends BaseDL4JTest {
|
public class MaskZeroLayerTest extends BaseDL4JTest {
|
||||||
|
private RNNFormat rnnDataFormat;
|
||||||
|
|
||||||
|
public MaskZeroLayerTest(RNNFormat rnnDataFormat){
|
||||||
|
this.rnnDataFormat = rnnDataFormat;
|
||||||
|
}
|
||||||
|
@Parameterized.Parameters
|
||||||
|
public static Object[] params(){
|
||||||
|
return RNNFormat.values();
|
||||||
|
}
|
||||||
@Test
|
@Test
|
||||||
public void activate() {
|
public void activate() {
|
||||||
|
|
||||||
|
@ -57,7 +68,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
|
||||||
.activation(Activation.IDENTITY)
|
.activation(Activation.IDENTITY)
|
||||||
.gateActivationFunction(Activation.IDENTITY)
|
.gateActivationFunction(Activation.IDENTITY)
|
||||||
.nIn(2)
|
.nIn(2)
|
||||||
.nOut(1)
|
.nOut(1).dataFormat(rnnDataFormat)
|
||||||
.build();
|
.build();
|
||||||
NeuralNetConfiguration conf = new NeuralNetConfiguration();
|
NeuralNetConfiguration conf = new NeuralNetConfiguration();
|
||||||
conf.setLayer(underlying);
|
conf.setLayer(underlying);
|
||||||
|
@ -72,20 +83,25 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue);
|
MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue);
|
||||||
INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3});
|
INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3});
|
||||||
|
if (rnnDataFormat == RNNFormat.NWC){
|
||||||
|
input = input.permute(0, 2, 1);
|
||||||
|
}
|
||||||
//WHEN
|
//WHEN
|
||||||
INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces());
|
INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
if (rnnDataFormat == RNNFormat.NWC){
|
||||||
|
out = out.permute(0, 2,1);
|
||||||
|
}
|
||||||
//THEN output should only be incremented for the non-zero timesteps
|
//THEN output should only be incremented for the non-zero timesteps
|
||||||
INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
|
INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
|
||||||
INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
|
INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
|
||||||
|
|
||||||
assertEquals(firstExampleOutput.getDouble(0), 0.0, 1e-6);
|
assertEquals(0.0, firstExampleOutput.getDouble(0), 1e-6);
|
||||||
assertEquals(firstExampleOutput.getDouble(1), 1.0, 1e-6);
|
assertEquals(1.0, firstExampleOutput.getDouble(1), 1e-6);
|
||||||
assertEquals(firstExampleOutput.getDouble(2), 2.0, 1e-6);
|
assertEquals(2.0, firstExampleOutput.getDouble(2), 1e-6);
|
||||||
|
|
||||||
assertEquals(secondExampleOutput.getDouble(0), 0.0, 1e-6);
|
assertEquals(0.0, secondExampleOutput.getDouble(0), 1e-6);
|
||||||
assertEquals(secondExampleOutput.getDouble(1), 0.0, 1e-6);
|
assertEquals(0.0, secondExampleOutput.getDouble(1), 1e-6);
|
||||||
assertEquals(secondExampleOutput.getDouble(2), 1.0, 1e-6);
|
assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,7 +110,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.list()
|
.list()
|
||||||
.layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder()
|
.layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder()
|
||||||
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).build()).build())
|
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build())
|
||||||
.build();
|
.build();
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
|
@ -0,0 +1,394 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.nn.layers.recurrent;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.TestUtils;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class RnnDataFormatTests extends BaseDL4JTest {
|
||||||
|
|
||||||
|
private boolean helpers;
|
||||||
|
private boolean lastTimeStep;
|
||||||
|
private boolean maskZeros;
|
||||||
|
|
||||||
|
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}")
|
||||||
|
public static List params(){
|
||||||
|
List<Object[]> ret = new ArrayList<>();
|
||||||
|
for (boolean helpers: new boolean[]{true, false})
|
||||||
|
for (boolean lastTimeStep: new boolean[]{true, false})
|
||||||
|
for (boolean maskZero: new boolean[]{true, false})
|
||||||
|
ret.add(new Object[]{helpers, lastTimeStep, maskZero});
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSimpleRnn() {
|
||||||
|
try {
|
||||||
|
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
|
||||||
|
|
||||||
|
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getSimpleRnnNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
|
||||||
|
.net2(getSimpleRnnNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
|
||||||
|
.net3(getSimpleRnnNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
|
||||||
|
.net4(getSimpleRnnNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
|
||||||
|
.inNCW(inNCW)
|
||||||
|
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
|
||||||
|
.labelsNWC(labelsNWC)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
TestCase.testHelper(tc);
|
||||||
|
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLSTM() {
|
||||||
|
try {
|
||||||
|
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
|
||||||
|
|
||||||
|
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
|
||||||
|
.net2(getLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
|
||||||
|
.net3(getLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
|
||||||
|
.net4(getLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
|
||||||
|
.inNCW(inNCW)
|
||||||
|
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
|
||||||
|
.labelsNWC(labelsNWC)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
TestCase.testHelper(tc);
|
||||||
|
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGraveLSTM() {
|
||||||
|
try {
|
||||||
|
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
|
||||||
|
|
||||||
|
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getGravesLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
|
||||||
|
.net2(getGravesLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
|
||||||
|
.net3(getGravesLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
|
||||||
|
.net4(getGravesLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
|
||||||
|
.inNCW(inNCW)
|
||||||
|
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
|
||||||
|
.labelsNWC(labelsNWC)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
TestCase.testHelper(tc);
|
||||||
|
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGraveBiLSTM() {
|
||||||
|
try {
|
||||||
|
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
Nd4j.getEnvironment().allowHelpers(helpers);
|
||||||
|
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
|
||||||
|
System.out.println(" --- " + msg + " ---");
|
||||||
|
|
||||||
|
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
|
||||||
|
|
||||||
|
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
|
||||||
|
|
||||||
|
TestCase tc = TestCase.builder()
|
||||||
|
.msg(msg)
|
||||||
|
.net1(getGravesBidirectionalLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
|
||||||
|
.net2(getGravesBidirectionalLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
|
||||||
|
.net3(getGravesBidirectionalLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
|
||||||
|
.net4(getGravesBidirectionalLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
|
||||||
|
.inNCW(inNCW)
|
||||||
|
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
|
||||||
|
.labelsNWC(labelsNWC)
|
||||||
|
.testLayerIdx(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
TestCase.testHelper(tc);
|
||||||
|
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
Nd4j.getEnvironment().allowHelpers(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private MultiLayerNetwork getGravesBidirectionalLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3)
|
||||||
|
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new GravesLSTM.Builder().nOut(3)
|
||||||
|
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new GravesLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new LSTM.Builder().nOut(3)
|
||||||
|
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new LSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private MultiLayerNetwork getSimpleRnnNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||||
|
if (setOnLayerAlso) {
|
||||||
|
return getNetWithLayer(new SimpleRnn.Builder().nOut(3)
|
||||||
|
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||||
|
} else {
|
||||||
|
return getNetWithLayer(new SimpleRnn.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
private MultiLayerNetwork getNetWithLayer(Layer layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
|
||||||
|
if (maskZeros){
|
||||||
|
layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build();
|
||||||
|
}
|
||||||
|
if(lastTimeStep){
|
||||||
|
layer = new LastTimeStep(layer);
|
||||||
|
}
|
||||||
|
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
|
||||||
|
.seed(12345)
|
||||||
|
.list()
|
||||||
|
.layer(new LSTM.Builder()
|
||||||
|
.nIn(3)
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.dataFormat(format)
|
||||||
|
.nOut(3)
|
||||||
|
.helperAllowFallback(false)
|
||||||
|
.build())
|
||||||
|
.layer(layer)
|
||||||
|
.layer(
|
||||||
|
(lastTimeStep)?new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build():
|
||||||
|
new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).dataFormat(format).build()
|
||||||
|
)
|
||||||
|
.setInputType(InputType.recurrent(3, 12, format));
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
|
||||||
|
net.init();
|
||||||
|
return net;
|
||||||
|
}
|
||||||
|
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@Builder
|
||||||
|
private static class TestCase {
|
||||||
|
private String msg;
|
||||||
|
private MultiLayerNetwork net1;
|
||||||
|
private MultiLayerNetwork net2;
|
||||||
|
private MultiLayerNetwork net3;
|
||||||
|
private MultiLayerNetwork net4;
|
||||||
|
private INDArray inNCW;
|
||||||
|
private INDArray labelsNCW;
|
||||||
|
private INDArray labelsNWC;
|
||||||
|
private int testLayerIdx;
|
||||||
|
private boolean nwcOutput;
|
||||||
|
|
||||||
|
public static void testHelper(TestCase tc) {
|
||||||
|
|
||||||
|
tc.net2.params().assign(tc.net1.params());
|
||||||
|
tc.net3.params().assign(tc.net1.params());
|
||||||
|
tc.net4.params().assign(tc.net1.params());
|
||||||
|
|
||||||
|
INDArray inNCW = tc.inNCW;
|
||||||
|
INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup();
|
||||||
|
|
||||||
|
INDArray l0_1 = tc.net1.feedForward(inNCW).get(tc.testLayerIdx + 1);
|
||||||
|
INDArray l0_2 = tc.net2.feedForward(inNCW).get(tc.testLayerIdx + 1);
|
||||||
|
INDArray l0_3 = tc.net3.feedForward(inNWC).get(tc.testLayerIdx + 1);
|
||||||
|
INDArray l0_4 = tc.net4.feedForward(inNWC).get(tc.testLayerIdx + 1);
|
||||||
|
|
||||||
|
boolean rank3Out = tc.labelsNCW.rank() == 3;
|
||||||
|
assertEquals(tc.msg, l0_1, l0_2);
|
||||||
|
if (rank3Out){
|
||||||
|
assertEquals(tc.msg, l0_1, l0_3.permute(0, 2, 1));
|
||||||
|
assertEquals(tc.msg, l0_1, l0_4.permute(0, 2, 1));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
assertEquals(tc.msg, l0_1, l0_3);
|
||||||
|
assertEquals(tc.msg, l0_1, l0_4);
|
||||||
|
}
|
||||||
|
INDArray out1 = tc.net1.output(inNCW);
|
||||||
|
INDArray out2 = tc.net2.output(inNCW);
|
||||||
|
INDArray out3 = tc.net3.output(inNWC);
|
||||||
|
INDArray out4 = tc.net4.output(inNWC);
|
||||||
|
|
||||||
|
assertEquals(tc.msg, out1, out2);
|
||||||
|
if (rank3Out){
|
||||||
|
assertEquals(tc.msg, out1, out3.permute(0, 2, 1)); //NWC to NCW
|
||||||
|
assertEquals(tc.msg, out1, out4.permute(0, 2, 1));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
assertEquals(tc.msg, out1, out3); //NWC to NCW
|
||||||
|
assertEquals(tc.msg, out1, out4);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//Test backprop
|
||||||
|
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCW, tc.labelsNCW, null, null);
|
||||||
|
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCW, tc.labelsNCW, null, null);
|
||||||
|
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNWC, tc.labelsNWC, null, null);
|
||||||
|
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNWC, tc.labelsNWC, null, null);
|
||||||
|
|
||||||
|
//Inpput gradients
|
||||||
|
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
|
||||||
|
|
||||||
|
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0, 2, 1)); //Input gradients for NWC input are also in NWC format
|
||||||
|
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0, 2, 1));
|
||||||
|
|
||||||
|
|
||||||
|
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
|
||||||
|
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
|
||||||
|
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
|
||||||
|
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
|
||||||
|
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
|
||||||
|
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
|
||||||
|
|
||||||
|
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
|
||||||
|
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
|
||||||
|
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
|
||||||
|
|
||||||
|
tc.net1.fit(inNCW, tc.labelsNCW);
|
||||||
|
tc.net2.fit(inNCW, tc.labelsNCW);
|
||||||
|
tc.net3.fit(inNWC, tc.labelsNWC);
|
||||||
|
tc.net4.fit(inNWC, tc.labelsNWC);
|
||||||
|
|
||||||
|
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
|
||||||
|
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
|
||||||
|
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
|
||||||
|
|
||||||
|
//Test serialization
|
||||||
|
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
|
||||||
|
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
|
||||||
|
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
|
||||||
|
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
|
||||||
|
|
||||||
|
out1 = tc.net1.output(inNCW);
|
||||||
|
assertEquals(tc.msg, out1, net1a.output(inNCW));
|
||||||
|
assertEquals(tc.msg, out1, net2a.output(inNCW));
|
||||||
|
|
||||||
|
if (rank3Out) {
|
||||||
|
assertEquals(tc.msg, out1, net3a.output(inNWC).permute(0, 2, 1)); //NWC to NCW
|
||||||
|
assertEquals(tc.msg, out1, net4a.output(inNWC).permute(0, 2, 1));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW
|
||||||
|
assertEquals(tc.msg, out1, net4a.output(inNWC));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
private static List<String> differentGrads(Gradient g1, Gradient g2){
|
||||||
|
List<String> differs = new ArrayList<>();
|
||||||
|
Map<String,INDArray> m1 = g1.gradientForVariable();
|
||||||
|
Map<String,INDArray> m2 = g2.gradientForVariable();
|
||||||
|
for(String s : m1.keySet()){
|
||||||
|
INDArray a1 = m1.get(s);
|
||||||
|
INDArray a2 = m2.get(s);
|
||||||
|
if(!a1.equals(a2)){
|
||||||
|
differs.add(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return differs;
|
||||||
|
}
|
||||||
|
}
|
|
@ -21,6 +21,7 @@ import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||||
|
@ -29,6 +30,8 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
@ -42,14 +45,25 @@ import static org.nd4j.linalg.activations.Activation.IDENTITY;
|
||||||
import static org.nd4j.linalg.activations.Activation.TANH;
|
import static org.nd4j.linalg.activations.Activation.TANH;
|
||||||
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
|
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
|
||||||
|
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
public class TestLastTimeStepLayer extends BaseDL4JTest {
|
public class TestLastTimeStepLayer extends BaseDL4JTest {
|
||||||
|
private RNNFormat rnnDataFormat;
|
||||||
|
|
||||||
|
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){
|
||||||
|
this.rnnDataFormat = rnnDataFormat;
|
||||||
|
}
|
||||||
|
@Parameterized.Parameters(name="{0}")
|
||||||
|
public static Object[] params(){
|
||||||
|
return RNNFormat.values();
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLastTimeStepVertex() {
|
public void testLastTimeStepVertex() {
|
||||||
|
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
|
||||||
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
|
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
|
||||||
.nIn(5).nOut(6).build()), "in")
|
.nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in")
|
||||||
.setOutputs("lastTS")
|
.setOutputs("lastTS")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -59,9 +73,22 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
||||||
//First: test without input mask array
|
//First: test without input mask array
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
Layer l = graph.getLayer("lastTS");
|
Layer l = graph.getLayer("lastTS");
|
||||||
INDArray in = Nd4j.rand(new int[]{3, 5, 6});
|
INDArray in;
|
||||||
|
if (rnnDataFormat == RNNFormat.NCW){
|
||||||
|
in = Nd4j.rand(3, 5, 6);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
in = Nd4j.rand(3, 6, 5);
|
||||||
|
}
|
||||||
INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces());
|
INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
INDArray expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
|
INDArray expOut;
|
||||||
|
if (rnnDataFormat == RNNFormat.NCW){
|
||||||
|
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.point(5), NDArrayIndex.all());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//Forward pass:
|
//Forward pass:
|
||||||
|
@ -76,9 +103,17 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
||||||
graph.setLayerMaskArrays(new INDArray[]{inMask}, null);
|
graph.setLayerMaskArrays(new INDArray[]{inMask}, null);
|
||||||
|
|
||||||
expOut = Nd4j.zeros(3, 6);
|
expOut = Nd4j.zeros(3, 6);
|
||||||
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
|
if (rnnDataFormat == RNNFormat.NCW){
|
||||||
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
|
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
|
||||||
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));
|
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
|
||||||
|
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all()));
|
||||||
|
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.point(3), NDArrayIndex.all()));
|
||||||
|
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.point(4), NDArrayIndex.all()));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces());
|
outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertEquals(expOut, outFwd);
|
assertEquals(expOut, outFwd);
|
||||||
|
@ -97,9 +132,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
||||||
.seed(1234)
|
.seed(1234)
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.setInputTypes(InputType.recurrent(1))
|
.setInputTypes(InputType.recurrent(1, rnnDataFormat))
|
||||||
.addLayer("RNN", new LastTimeStep(new LSTM.Builder()
|
.addLayer("RNN", new LastTimeStep(new LSTM.Builder()
|
||||||
.nOut(10)
|
.nOut(10).dataFormat(rnnDataFormat)
|
||||||
.build()), "in")
|
.build()), "in")
|
||||||
.addLayer("dense", new DenseLayer.Builder()
|
.addLayer("dense", new DenseLayer.Builder()
|
||||||
.nOut(10)
|
.nOut(10)
|
||||||
|
@ -120,7 +155,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
||||||
INDArray fm2 = Nd4j.zeros(1,24);
|
INDArray fm2 = Nd4j.zeros(1,24);
|
||||||
INDArray fm3 = Nd4j.zeros(1,24);
|
INDArray fm3 = Nd4j.zeros(1,24);
|
||||||
fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1);
|
fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1);
|
||||||
|
if (rnnDataFormat == RNNFormat.NWC){
|
||||||
|
f = f.permute(0, 2, 1);
|
||||||
|
}
|
||||||
INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1});
|
INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1});
|
||||||
try {
|
try {
|
||||||
cg.output(false, new INDArray[]{f}, new INDArray[]{fm2});
|
cg.output(false, new INDArray[]{f}, new INDArray[]{fm2});
|
||||||
|
|
|
@ -20,6 +20,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.dropout.TestDropout;
|
import org.deeplearning4j.nn.conf.dropout.TestDropout;
|
||||||
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||||
|
@ -31,6 +32,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -41,13 +44,24 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNotEquals;
|
import static org.junit.Assert.assertNotEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
public class TestRnnLayers extends BaseDL4JTest {
|
public class TestRnnLayers extends BaseDL4JTest {
|
||||||
|
|
||||||
|
private RNNFormat rnnDataFormat;
|
||||||
|
|
||||||
|
public TestRnnLayers(RNNFormat rnnDataFormat){
|
||||||
|
this.rnnDataFormat = rnnDataFormat;
|
||||||
|
}
|
||||||
|
@Parameterized.Parameters
|
||||||
|
public static Object[] params(){
|
||||||
|
return RNNFormat.values();
|
||||||
|
}
|
||||||
@Test
|
@Test
|
||||||
public void testTimeStepIs3Dimensional() {
|
public void testTimeStepIs3Dimensional() {
|
||||||
|
|
||||||
|
@ -58,8 +72,8 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.list()
|
.list()
|
||||||
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).build())
|
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).dataFormat(rnnDataFormat).build())
|
||||||
.layer(new LSTM.Builder().nIn(3).nOut(5).build())
|
.layer(new LSTM.Builder().nIn(3).nOut(5).dataFormat(rnnDataFormat).build())
|
||||||
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build())
|
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -70,9 +84,9 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn =
|
org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn =
|
||||||
(org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0);
|
(org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0);
|
||||||
|
|
||||||
INDArray rnnInput3d = Nd4j.create(10, 12, 1);
|
INDArray rnnInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10,12, 1):Nd4j.create(10, 1, 12);
|
||||||
INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces());
|
INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertTrue(Arrays.equals(simpleOut.shape(), new long[] {10, 3, 1}));
|
assertTrue(Arrays.equals(simpleOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 3, 1}:new long[]{10, 1, 3}));
|
||||||
|
|
||||||
INDArray rnnInput2d = Nd4j.create(10, 12);
|
INDArray rnnInput2d = Nd4j.create(10, 12);
|
||||||
try {
|
try {
|
||||||
|
@ -84,9 +98,9 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
org.deeplearning4j.nn.layers.recurrent.LSTM lstm =
|
org.deeplearning4j.nn.layers.recurrent.LSTM lstm =
|
||||||
(org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1);
|
(org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1);
|
||||||
|
|
||||||
INDArray lstmInput3d = Nd4j.create(10, 3, 1);
|
INDArray lstmInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10, 3, 1):Nd4j.create(10, 1, 3);
|
||||||
INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces());
|
INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertTrue(Arrays.equals(lstmOut.shape(), new long[] {10, 5, 1}));
|
assertTrue(Arrays.equals(lstmOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 5, 1}:new long[]{10, 1, 5}));
|
||||||
|
|
||||||
INDArray lstmInput2d = Nd4j.create(10, 3);
|
INDArray lstmInput2d = Nd4j.create(10, 3);
|
||||||
try {
|
try {
|
||||||
|
@ -112,19 +126,19 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
TestDropout.CustomDropout cd = new TestDropout.CustomDropout();
|
TestDropout.CustomDropout cd = new TestDropout.CustomDropout();
|
||||||
switch (s){
|
switch (s){
|
||||||
case "graves":
|
case "graves":
|
||||||
layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
|
layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
|
||||||
layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
|
layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
|
||||||
layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
|
layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
|
||||||
break;
|
break;
|
||||||
case "lstm":
|
case "lstm":
|
||||||
layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
|
layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
|
||||||
layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
|
layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
|
||||||
layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
|
layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
|
||||||
break;
|
break;
|
||||||
case "simple":
|
case "simple":
|
||||||
layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
|
layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
|
||||||
layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
|
layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
|
||||||
layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
|
layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException(s);
|
throw new RuntimeException(s);
|
||||||
|
@ -134,21 +148,21 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(layer)
|
.layer(layer)
|
||||||
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
|
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder()
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(layerD)
|
.layer(layerD)
|
||||||
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
|
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder()
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(layerD2)
|
.layer(layerD2)
|
||||||
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
|
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -178,7 +192,6 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
assertNotEquals(s, out2, out2D);
|
assertNotEquals(s, out2, out2D);
|
||||||
|
|
||||||
INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);
|
INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);
|
||||||
|
|
||||||
net.fit(f.dup(), l);
|
net.fit(f.dup(), l);
|
||||||
netD.fit(f.dup(), l);
|
netD.fit(f.dup(), l);
|
||||||
assertNotEquals(s, net.params(), netD.params());
|
assertNotEquals(s, net.params(), netD.params());
|
||||||
|
@ -205,14 +218,14 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
|
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
|
||||||
|
|
||||||
.list()
|
.list()
|
||||||
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
|
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
|
||||||
|
|
||||||
switch (i){
|
switch (i){
|
||||||
case 0:
|
case 0:
|
||||||
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build());
|
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build());
|
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).dataFormat(rnnDataFormat).build());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
|
@ -223,14 +236,14 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
|
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
|
||||||
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
|
INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new Random(12345));
|
||||||
|
|
||||||
try{
|
try{
|
||||||
net.fit(in,l);
|
net.fit(in,l);
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
String msg = t.getMessage();
|
String msg = t.getMessage();
|
||||||
if(msg == null)
|
if(msg == null)
|
||||||
t.printStackTrace();
|
t.printStackTrace();
|
||||||
|
System.out.println(i);
|
||||||
assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
|
assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,10 +20,13 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -36,8 +39,18 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all;
|
||||||
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
||||||
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
|
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
public class TestSimpleRnn extends BaseDL4JTest {
|
public class TestSimpleRnn extends BaseDL4JTest {
|
||||||
|
|
||||||
|
private RNNFormat rnnDataFormat;
|
||||||
|
|
||||||
|
public TestSimpleRnn(RNNFormat rnnDataFormat){
|
||||||
|
this.rnnDataFormat = rnnDataFormat;
|
||||||
|
}
|
||||||
|
@Parameterized.Parameters
|
||||||
|
public static Object[] params(){
|
||||||
|
return RNNFormat.values();
|
||||||
|
}
|
||||||
@Test
|
@Test
|
||||||
public void testSimpleRnn(){
|
public void testSimpleRnn(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -46,15 +59,21 @@ public class TestSimpleRnn extends BaseDL4JTest {
|
||||||
int nIn = 5;
|
int nIn = 5;
|
||||||
int layerSize = 6;
|
int layerSize = 6;
|
||||||
int tsLength = 7;
|
int tsLength = 7;
|
||||||
INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
|
INDArray in;
|
||||||
// in.get(all(), all(), interval(1,tsLength)).assign(0);
|
if (rnnDataFormat == RNNFormat.NCW){
|
||||||
|
in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
.list()
|
.list()
|
||||||
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).build())
|
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -68,7 +87,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
|
||||||
|
|
||||||
INDArray outLast = null;
|
INDArray outLast = null;
|
||||||
for( int i=0; i<tsLength; i++ ){
|
for( int i=0; i<tsLength; i++ ){
|
||||||
INDArray inCurrent = in.get(all(), all(), point(i));
|
INDArray inCurrent;
|
||||||
|
if (rnnDataFormat == RNNFormat.NCW){
|
||||||
|
inCurrent = in.get(all(), all(), point(i));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
inCurrent = in.get(all(), point(i), all());
|
||||||
|
}
|
||||||
|
|
||||||
INDArray outExpCurrent = inCurrent.mmul(w);
|
INDArray outExpCurrent = inCurrent.mmul(w);
|
||||||
if(outLast != null){
|
if(outLast != null){
|
||||||
|
@ -79,7 +104,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
|
||||||
|
|
||||||
Transforms.tanh(outExpCurrent, false);
|
Transforms.tanh(outExpCurrent, false);
|
||||||
|
|
||||||
INDArray outActCurrent = out.get(all(), all(), point(i));
|
INDArray outActCurrent;
|
||||||
|
if (rnnDataFormat == RNNFormat.NCW){
|
||||||
|
outActCurrent = out.get(all(), all(), point(i));
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
outActCurrent = out.get(all(), point(i), all());
|
||||||
|
}
|
||||||
assertEquals(String.valueOf(i), outExpCurrent, outActCurrent);
|
assertEquals(String.valueOf(i), outExpCurrent, outActCurrent);
|
||||||
|
|
||||||
outLast = outExpCurrent;
|
outLast = outExpCurrent;
|
||||||
|
@ -100,7 +131,7 @@ public class TestSimpleRnn extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
.list()
|
.list()
|
||||||
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize)
|
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
|
||||||
.biasInit(100)
|
.biasInit(100)
|
||||||
.build())
|
.build())
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -4,14 +4,21 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -22,8 +29,18 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
public class TestTimeDistributed extends BaseDL4JTest {
|
public class TestTimeDistributed extends BaseDL4JTest {
|
||||||
|
|
||||||
|
private RNNFormat rnnDataFormat;
|
||||||
|
|
||||||
|
public TestTimeDistributed(RNNFormat rnnDataFormat){
|
||||||
|
this.rnnDataFormat = rnnDataFormat;
|
||||||
|
}
|
||||||
|
@Parameterized.Parameters
|
||||||
|
public static Object[] params(){
|
||||||
|
return RNNFormat.values();
|
||||||
|
}
|
||||||
@Test
|
@Test
|
||||||
public void testTimeDistributed(){
|
public void testTimeDistributed(){
|
||||||
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
|
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
|
||||||
|
@ -34,11 +51,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.updater(new Adam(0.1))
|
.updater(new Adam(0.1))
|
||||||
.list()
|
.list()
|
||||||
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
|
.layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
|
||||||
.layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build())
|
.layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build())
|
||||||
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
|
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.setInputType(InputType.recurrent(3))
|
.setInputType(InputType.recurrent(3, rnnDataFormat))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
|
||||||
|
@ -47,11 +64,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.updater(new Adam(0.1))
|
.updater(new Adam(0.1))
|
||||||
.list()
|
.list()
|
||||||
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
|
.layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
|
||||||
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), 2))
|
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), rnnDataFormat))
|
||||||
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
|
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.setInputType(InputType.recurrent(3))
|
.setInputType(InputType.recurrent(3, rnnDataFormat))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
|
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
|
||||||
|
@ -62,13 +79,21 @@ public class TestTimeDistributed extends BaseDL4JTest {
|
||||||
for( int mb : new int[]{1, 5}) {
|
for( int mb : new int[]{1, 5}) {
|
||||||
for(char inLabelOrder : new char[]{'c', 'f'}) {
|
for(char inLabelOrder : new char[]{'c', 'f'}) {
|
||||||
INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder);
|
INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder);
|
||||||
|
if (rnnDataFormat == RNNFormat.NWC){
|
||||||
|
in = in.permute(0, 2, 1);
|
||||||
|
}
|
||||||
INDArray out1 = net1.output(in);
|
INDArray out1 = net1.output(in);
|
||||||
INDArray out2 = net2.output(in);
|
INDArray out2 = net2.output(in);
|
||||||
|
|
||||||
assertEquals(out1, out2);
|
assertEquals(out1, out2);
|
||||||
|
|
||||||
INDArray labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
|
INDArray labels ;
|
||||||
|
if (rnnDataFormat == RNNFormat.NCW) {
|
||||||
|
labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
|
||||||
|
}else{
|
||||||
|
labels = TestUtils.randomOneHotTimeSeries(mb, 5, 3).dup(inLabelOrder);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DataSet ds = new DataSet(in, labels);
|
DataSet ds = new DataSet(in, labels);
|
||||||
net1.fit(ds);
|
net1.fit(ds);
|
||||||
|
@ -85,4 +110,73 @@ public class TestTimeDistributed extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTimeDistributedDense(){
|
||||||
|
|
||||||
|
for( int rnnType=0; rnnType<3; rnnType++ ) {
|
||||||
|
for( int ffType=0; ffType<3; ffType++ ) {
|
||||||
|
|
||||||
|
Layer l0, l2;
|
||||||
|
switch (rnnType) {
|
||||||
|
case 0:
|
||||||
|
l0 = new LSTM.Builder().nOut(5).build();
|
||||||
|
l2 = new LSTM.Builder().nOut(5).build();
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
l0 = new SimpleRnn.Builder().nOut(5).build();
|
||||||
|
l2 = new SimpleRnn.Builder().nOut(5).build();
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
l0 = new Bidirectional(new LSTM.Builder().nOut(5).build());
|
||||||
|
l2 = new Bidirectional(new LSTM.Builder().nOut(5).build());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException("Not implemented: " + rnnType);
|
||||||
|
}
|
||||||
|
|
||||||
|
Layer l1;
|
||||||
|
switch (ffType){
|
||||||
|
case 0:
|
||||||
|
l1 = new DenseLayer.Builder().nOut(5).build();
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
l1 = new VariationalAutoencoder.Builder().nOut(5).encoderLayerSizes(5).decoderLayerSizes(5).build();
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
l1 = new AutoEncoder.Builder().nOut(5).build();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException("Not implemented: " + ffType);
|
||||||
|
}
|
||||||
|
|
||||||
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.list()
|
||||||
|
.layer(l0)
|
||||||
|
.layer(l1)
|
||||||
|
.layer(l2)
|
||||||
|
.setInputType(InputType.recurrent(5, 9, rnnDataFormat))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
BaseRecurrentLayer l0a;
|
||||||
|
BaseRecurrentLayer l2a;
|
||||||
|
if (rnnType < 2) {
|
||||||
|
l0a = (BaseRecurrentLayer) l0;
|
||||||
|
l2a = (BaseRecurrentLayer) l2;
|
||||||
|
} else {
|
||||||
|
l0a = (BaseRecurrentLayer) ((Bidirectional) l0).getFwd();
|
||||||
|
l2a = (BaseRecurrentLayer) ((Bidirectional) l2).getFwd();
|
||||||
|
}
|
||||||
|
assertEquals(rnnDataFormat, l0a.getRnnDataFormat());
|
||||||
|
assertEquals(rnnDataFormat, l2a.getRnnDataFormat());
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
INDArray in = Nd4j.rand(DataType.FLOAT, rnnDataFormat == RNNFormat.NCW ? new long[]{2, 5, 9} : new long[]{2, 9, 5} );
|
||||||
|
net.output(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -771,7 +771,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
|
||||||
.build();
|
.build();
|
||||||
fail("Exception expected");
|
fail("Exception expected");
|
||||||
} catch (IllegalStateException e){
|
} catch (IllegalStateException e){
|
||||||
// e.printStackTrace();
|
log.info(e.toString());
|
||||||
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
|
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -408,8 +408,8 @@ public class TestVariableLengthTS extends BaseDL4JTest {
|
||||||
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
|
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
|
||||||
NDArrayIndex.point(j));
|
NDArrayIndex.point(j));
|
||||||
for (int k = 0; k < nOut; k++) {
|
for (int k = 0; k < nOut; k++) {
|
||||||
assertEquals(outRow.getDouble(k), 0.0, 0.0);
|
assertEquals(0.0, outRow.getDouble(k), 0.0);
|
||||||
assertEquals(outRow2.getDouble(k), 0.0, 0.0);
|
assertEquals(0.0, outRow2.getDouble(k), 0.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -336,7 +336,7 @@ public class TestUpdaters extends BaseDL4JTest {
|
||||||
actualM[i] = Math.round(actualM[i] * 1e2) / 1e2;
|
actualM[i] = Math.round(actualM[i] * 1e2) / 1e2;
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(actualM, expectedM), true);
|
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(expectedM, actualM), true);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -159,14 +159,14 @@ public class RegressionTest050 extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
||||||
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
|
assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
|
||||||
|
|
||||||
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
|
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
|
||||||
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l1.getStride());
|
assertArrayEquals(new int[] {1, 1}, l1.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
|
||||||
assertEquals(PoolingType.MAX, l1.getPoolingType());
|
assertEquals(PoolingType.MAX, l1.getPoolingType());
|
||||||
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
|
assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
|
||||||
|
|
||||||
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
|
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
|
||||||
assertEquals("sigmoid", l2.getActivationFn().toString());
|
assertEquals("sigmoid", l2.getActivationFn().toString());
|
||||||
|
|
|
@ -162,14 +162,14 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
||||||
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
|
assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
|
||||||
|
|
||||||
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
|
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
|
||||||
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l1.getStride());
|
assertArrayEquals(new int[] {1, 1}, l1.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
|
||||||
assertEquals(PoolingType.MAX, l1.getPoolingType());
|
assertEquals(PoolingType.MAX, l1.getPoolingType());
|
||||||
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
|
assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
|
||||||
|
|
||||||
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
|
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
|
||||||
assertEquals("sigmoid", l2.getActivationFn().toString());
|
assertEquals("sigmoid", l2.getActivationFn().toString());
|
||||||
|
|
|
@ -162,7 +162,7 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
||||||
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same);
|
assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
|
||||||
|
|
||||||
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
|
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
|
||||||
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
|
||||||
|
|
|
@ -176,14 +176,14 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
assertArrayEquals(new int[] {1, 1}, l0.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
|
||||||
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same);
|
assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
|
||||||
|
|
||||||
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
|
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
|
||||||
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
|
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
|
||||||
assertArrayEquals(new int[] {1, 1}, l1.getStride());
|
assertArrayEquals(new int[] {1, 1}, l1.getStride());
|
||||||
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
|
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
|
||||||
assertEquals(PoolingType.MAX, l1.getPoolingType());
|
assertEquals(PoolingType.MAX, l1.getPoolingType());
|
||||||
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Same);
|
assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
|
||||||
|
|
||||||
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
|
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
|
||||||
assertTrue(l2.getActivationFn() instanceof ActivationSigmoid);
|
assertTrue(l2.getActivationFn() instanceof ActivationSigmoid);
|
||||||
|
|
|
@ -178,8 +178,6 @@ public abstract class BaseCudnnHelper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected static final int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
|
|
||||||
|
|
||||||
protected final DataType nd4jDataType;
|
protected final DataType nd4jDataType;
|
||||||
protected final int dataType;
|
protected final int dataType;
|
||||||
protected final int dataTypeSize;
|
protected final int dataTypeSize;
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -22,6 +23,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import com.jakewharton.byteunits.BinaryByteUnit;
|
import com.jakewharton.byteunits.BinaryByteUnit;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo;
|
||||||
|
@ -86,7 +88,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
}
|
}
|
||||||
|
|
||||||
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
|
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
|
||||||
biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct();
|
biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct();
|
||||||
private cudnnFilterStruct filterDesc = new cudnnFilterStruct();
|
private cudnnFilterStruct filterDesc = new cudnnFilterStruct();
|
||||||
private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct();
|
private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct();
|
||||||
private cudnnActivationStruct activationDesc = new cudnnActivationStruct();
|
private cudnnActivationStruct activationDesc = new cudnnActivationStruct();
|
||||||
|
@ -138,7 +140,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel,
|
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel,
|
||||||
int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
|
int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
|
||||||
AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo,
|
AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo,
|
||||||
ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
|
ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
|
||||||
|
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
|
||||||
|
// correctly on NHWC data, even after updating all descriptors, tensor format, etc.
|
||||||
|
//Therefore: all computation here is done in NCHW format only
|
||||||
|
//As of a future (next?) release we'll likely switch to C++ for cuDNN support
|
||||||
|
boolean origNHWC = false;
|
||||||
|
if(format == CNN2DFormat.NHWC){
|
||||||
|
input = input.permute(0,3,1,2); //NHWC to NCHW
|
||||||
|
delta = delta.permute(0,3,1,2);
|
||||||
|
origNHWC = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
|
||||||
|
|
||||||
int code;
|
int code;
|
||||||
|
|
||||||
val miniBatch = input.size(0);
|
val miniBatch = input.size(0);
|
||||||
|
@ -147,7 +163,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
val kH = weights.size(2);
|
val kH = weights.size(2);
|
||||||
val kW = weights.size(3);
|
val kW = weights.size(3);
|
||||||
|
|
||||||
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null);
|
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
|
||||||
input = args.getInput();
|
input = args.getInput();
|
||||||
val inH = input.size(2);
|
val inH = input.size(2);
|
||||||
val inW = input.size(3);
|
val inW = input.size(3);
|
||||||
|
@ -176,7 +192,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]);
|
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]);
|
||||||
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
|
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
|
||||||
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
|
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
|
||||||
checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
|
code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
|
||||||
checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
|
@ -238,16 +254,16 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
|
code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
|
||||||
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
|
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
|
||||||
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
|
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
|
||||||
: CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
|
: CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
|
||||||
0, algo1);
|
0, algo1);
|
||||||
checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
|
code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
|
||||||
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
|
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
|
||||||
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
|
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
|
||||||
: CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
|
: CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
|
||||||
0, algo2);
|
0, algo2);
|
||||||
checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,7 +279,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
|
|
||||||
Allocator allocator = AtomicAllocator.getInstance();
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView,
|
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView,
|
||||||
biasGradView, delta, epsNext);
|
biasGradView, delta, epsNext);
|
||||||
Pointer srcData = allocator.getPointer(input, context);
|
Pointer srcData = allocator.getPointer(input, context);
|
||||||
Pointer filterData = allocator.getPointer(weights, context);
|
Pointer filterData = allocator.getPointer(weights, context);
|
||||||
Pointer filterGradData = allocator.getPointer(weightGradView, context);
|
Pointer filterGradData = allocator.getPointer(weightGradView, context);
|
||||||
|
@ -279,14 +295,14 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
|
|
||||||
code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
|
code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
|
||||||
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0],
|
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0],
|
||||||
sizeInBytes);
|
sizeInBytes);
|
||||||
checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
|
|
||||||
long sizeInBytes1 = sizeInBytes.get(0);
|
long sizeInBytes1 = sizeInBytes.get(0);
|
||||||
code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc,
|
code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc,
|
||||||
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0],
|
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0],
|
||||||
sizeInBytes);
|
sizeInBytes);
|
||||||
checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
|
|
||||||
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
|
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
|
||||||
|
@ -313,21 +329,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
|
|
||||||
code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta,
|
code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta,
|
||||||
cudnnContext.biasTensorDesc, biasGradData);
|
cudnnContext.biasTensorDesc, biasGradData);
|
||||||
checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
|
|
||||||
code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
|
code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
|
||||||
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace,
|
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace,
|
||||||
workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData);
|
workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData);
|
||||||
checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
|
|
||||||
code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData,
|
code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData,
|
||||||
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace,
|
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace,
|
||||||
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
|
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
|
||||||
checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
|
|
||||||
allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView,
|
allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView,
|
||||||
delta, epsNext);
|
delta, epsNext);
|
||||||
|
|
||||||
Gradient retGradient = new DefaultGradient();
|
Gradient retGradient = new DefaultGradient();
|
||||||
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView);
|
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView);
|
||||||
|
@ -344,12 +360,30 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0)));
|
interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(origNHWC){
|
||||||
|
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
}
|
||||||
|
|
||||||
return new Pair<>(retGradient, epsNext);
|
return new Pair<>(retGradient, epsNext);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
|
public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
|
||||||
AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
|
AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format,
|
||||||
|
LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
|
||||||
|
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
|
||||||
|
// correctly on NHWC data, even after updating all descriptors, tensor format, etc.
|
||||||
|
//Therefore: all computation here is done in NCHW format only
|
||||||
|
//As of a future (next?) release we'll likely switch to C++ for cuDNN support
|
||||||
|
boolean origNHWC = false;
|
||||||
|
if(format == CNN2DFormat.NHWC){
|
||||||
|
input = input.permute(0,3,1,2); //NHWC to NCHW
|
||||||
|
origNHWC = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
|
||||||
|
|
||||||
int code;
|
int code;
|
||||||
|
|
||||||
val miniBatch = input.size(0);
|
val miniBatch = input.size(0);
|
||||||
|
@ -358,7 +392,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
val kH = weights.size(2);
|
val kH = weights.size(2);
|
||||||
val kW = weights.size(3);
|
val kW = weights.size(3);
|
||||||
|
|
||||||
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null);
|
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
|
||||||
input = args.getInput();
|
input = args.getInput();
|
||||||
val inH = input.size(2);
|
val inH = input.size(2);
|
||||||
val inW = input.size(3);
|
val inW = input.size(3);
|
||||||
|
@ -378,7 +412,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
||||||
|
|
||||||
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
|
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
|
||||||
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
|
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
|
||||||
checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
||||||
|
|
||||||
|
|
||||||
|
@ -460,8 +494,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
||||||
|
|
||||||
code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
|
code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
|
||||||
cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0],
|
cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0],
|
||||||
sizeInBytes);
|
sizeInBytes);
|
||||||
checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
||||||
|
|
||||||
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
|
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
|
||||||
|
@ -482,8 +516,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
|
workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
|
||||||
}
|
}
|
||||||
code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
|
code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
|
||||||
cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace,
|
cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace,
|
||||||
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
|
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
|
||||||
checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
||||||
|
|
||||||
|
|
||||||
|
@ -491,7 +525,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
||||||
|
|
||||||
code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha,
|
code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha,
|
||||||
cudnnContext.dstTensorDesc, dstData);
|
cudnnContext.dstTensorDesc, dstData);
|
||||||
checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
|
||||||
|
|
||||||
allocator.registerAction(context, z, input, weights, bias);
|
allocator.registerAction(context, z, input, weights, bias);
|
||||||
|
@ -499,6 +533,10 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
|
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
|
||||||
context.syncOldStream();
|
context.syncOldStream();
|
||||||
|
|
||||||
|
if(origNHWC){
|
||||||
|
z = z.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
}
|
||||||
|
|
||||||
return z;
|
return z;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -552,29 +590,29 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
break;
|
break;
|
||||||
case "sigmoid":
|
case "sigmoid":
|
||||||
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID,
|
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID,
|
||||||
CUDNN_PROPAGATE_NAN, 0));
|
CUDNN_PROPAGATE_NAN, 0));
|
||||||
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
|
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
|
||||||
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
||||||
break;
|
break;
|
||||||
case "relu":
|
case "relu":
|
||||||
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU,
|
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU,
|
||||||
CUDNN_PROPAGATE_NAN, 0));
|
CUDNN_PROPAGATE_NAN, 0));
|
||||||
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
|
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
|
||||||
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
||||||
break;
|
break;
|
||||||
case "tanh":
|
case "tanh":
|
||||||
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH,
|
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH,
|
||||||
CUDNN_PROPAGATE_NAN, 0));
|
CUDNN_PROPAGATE_NAN, 0));
|
||||||
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
|
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
|
||||||
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
||||||
break;
|
break;
|
||||||
case "softmax":
|
case "softmax":
|
||||||
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
|
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
|
||||||
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
||||||
break;
|
break;
|
||||||
case "logsoftmax":
|
case "logsoftmax":
|
||||||
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
|
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
|
||||||
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
activation = null;
|
activation = null;
|
||||||
|
@ -593,7 +631,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation,
|
public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation,
|
||||||
ConvolutionMode convolutionMode, PoolingType poolingType){
|
ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){
|
||||||
INDArray origInput = input;
|
INDArray origInput = input;
|
||||||
|
|
||||||
//Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have has issues if strides
|
//Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have has issues if strides
|
||||||
|
@ -602,16 +640,19 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
input = input.dup('c');
|
input = input.dup('c');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
int hIdx = nchw ? 2 : 1;
|
||||||
|
int wIdx = nchw ? 3 : 2;
|
||||||
|
|
||||||
val inH = input.size(2);
|
val inH = input.size(hIdx);
|
||||||
val inW = input.size(3);
|
val inW = input.size(wIdx);
|
||||||
|
|
||||||
boolean manualPadBottom = false;
|
boolean manualPadBottom = false;
|
||||||
boolean manualPadRight = false;
|
boolean manualPadRight = false;
|
||||||
|
|
||||||
int[] outSize;
|
int[] outSize;
|
||||||
if (convolutionMode == ConvolutionMode.Same) {
|
if (convolutionMode == ConvolutionMode.Same) {
|
||||||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation
|
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
|
||||||
padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
||||||
int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
||||||
if(!Arrays.equals(padding, padBottomRight)){
|
if(!Arrays.equals(padding, padBottomRight)){
|
||||||
|
@ -626,9 +667,17 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
manualPadRight = (padding[1] != padBottomRight[1]);
|
manualPadRight = (padding[1] != padBottomRight[1]);
|
||||||
|
|
||||||
//NCHW format
|
//NCHW format
|
||||||
val newShape = new long[]{input.size(0), input.size(1),
|
long[] newShape;
|
||||||
input.size(2) + (manualPadBottom ? 1 : 0),
|
if(nchw){
|
||||||
input.size(3) + (manualPadRight ? 1 : 0)};
|
newShape = new long[]{input.size(0), input.size(1),
|
||||||
|
input.size(2) + (manualPadBottom ? 1 : 0),
|
||||||
|
input.size(3) + (manualPadRight ? 1 : 0)};
|
||||||
|
} else {
|
||||||
|
newShape = new long[]{input.size(0),
|
||||||
|
input.size(1) + (manualPadBottom ? 1 : 0),
|
||||||
|
input.size(2) + (manualPadRight ? 1 : 0),
|
||||||
|
input.size(3)};
|
||||||
|
}
|
||||||
INDArray newInput;
|
INDArray newInput;
|
||||||
if(poolingType == null || poolingType != PoolingType.MAX){
|
if(poolingType == null || poolingType != PoolingType.MAX){
|
||||||
newInput = Nd4j.create(input.dataType(), newShape);
|
newInput = Nd4j.create(input.dataType(), newShape);
|
||||||
|
@ -638,15 +687,22 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
// if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value
|
// if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value
|
||||||
newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType());
|
newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType());
|
||||||
}
|
}
|
||||||
newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)),
|
|
||||||
interval(0, input.size(3))}, input);
|
if(nchw){
|
||||||
|
newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)),
|
||||||
|
interval(0, input.size(3))}, input);
|
||||||
|
} else {
|
||||||
|
newInput.put(new INDArrayIndex[]{all(), interval(0,input.size(1)),
|
||||||
|
interval(0, input.size(2)), all()}, input);
|
||||||
|
}
|
||||||
|
|
||||||
input = newInput;
|
input = newInput;
|
||||||
//Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we
|
//Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we
|
||||||
// now have the same amount of padding required for top/bottom, and left/right - which we'll let
|
// now have the same amount of padding required for top/bottom, and left/right - which we'll let
|
||||||
// CuDNN handle
|
// CuDNN handle
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation); //Also performs validation
|
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
|
||||||
}
|
}
|
||||||
|
|
||||||
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
|
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
|
||||||
|
@ -670,4 +726,4 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
return Collections.emptyMap();
|
return Collections.emptyMap();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.subsampling;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.layers.PoolingType;
|
import org.deeplearning4j.nn.conf.layers.PoolingType;
|
||||||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||||
|
@ -114,23 +115,29 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides,
|
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides,
|
||||||
int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
|
int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode,
|
||||||
|
int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
|
||||||
if(dilation[0] != 1 || dilation[1] != 1){
|
if(dilation[0] != 1 || dilation[1] != 1){
|
||||||
//CuDNN doesn't support dilated subsampling
|
//CuDNN doesn't support dilated subsampling
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
int chIdx = nchw ? 1 : 3;
|
||||||
|
int hIdx = nchw ? 2 : 1;
|
||||||
|
int wIdx = nchw ? 3 : 2;
|
||||||
|
|
||||||
//We require the output as one of the arguments for backprop here
|
//We require the output as one of the arguments for backprop here
|
||||||
//TODO we could add cache mode support here somehow...
|
//TODO we could add cache mode support here somehow...
|
||||||
INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, workspaceMgr);
|
INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr);
|
||||||
|
|
||||||
val miniBatch = input.size(0);
|
val miniBatch = input.size(0);
|
||||||
val depth = input.size(1);
|
val depth = input.size(chIdx);
|
||||||
|
|
||||||
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType);
|
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
|
||||||
input = args.getInput();
|
input = args.getInput();
|
||||||
val inH = input.size(2);
|
val inH = input.size(hIdx);
|
||||||
val inW = input.size(3);
|
val inW = input.size(wIdx);
|
||||||
val srcStride = input.stride();
|
val srcStride = input.stride();
|
||||||
int[] outSize = args.getOutSize();
|
int[] outSize = args.getOutSize();
|
||||||
int outH = outSize[0];
|
int outH = outSize[0];
|
||||||
|
@ -160,23 +167,26 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
|
||||||
epsilon = epsilon.dup('c');
|
epsilon = epsilon.dup('c');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
input = input.dup();
|
||||||
|
|
||||||
val deltaStride = epsilon.stride();
|
val deltaStride = epsilon.stride();
|
||||||
|
|
||||||
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
||||||
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
||||||
|
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
|
||||||
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]));
|
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW,
|
||||||
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]));
|
(int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
|
||||||
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
|
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
|
||||||
kernel[1], pad[0], pad[1], strides[0], strides[1]));
|
kernel[1], pad[0], pad[1], strides[0], strides[1]));
|
||||||
|
|
||||||
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c');
|
long[] outEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
|
||||||
|
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c');
|
||||||
|
|
||||||
val dstStride = outEpsilon.stride();
|
val dstStride = outEpsilon.stride();
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
|
||||||
(int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]));
|
(int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
|
||||||
|
|
||||||
Allocator allocator = AtomicAllocator.getInstance();
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
|
CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
|
||||||
|
@ -198,9 +208,16 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
|
||||||
//Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
|
//Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
|
||||||
// we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
|
// we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
|
||||||
if(args.isManualPadBottom() || args.isManualPadRight()) {
|
if(args.isManualPadBottom() || args.isManualPadRight()) {
|
||||||
outEpsilon = outEpsilon.get(all(), all(),
|
if(nchw){
|
||||||
interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)),
|
outEpsilon = outEpsilon.get(all(), all(),
|
||||||
interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0)));
|
interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)),
|
||||||
|
interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0)));
|
||||||
|
} else {
|
||||||
|
outEpsilon = outEpsilon.get(all(),
|
||||||
|
interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)),
|
||||||
|
interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)),
|
||||||
|
all());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Pair<>(retGradient, outEpsilon);
|
return new Pair<>(retGradient, outEpsilon);
|
||||||
|
@ -209,19 +226,24 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad,
|
public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad,
|
||||||
PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
|
PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
|
||||||
if(dilation[0] != 1 || dilation[1] != 1){
|
if(dilation[0] != 1 || dilation[1] != 1){
|
||||||
//CuDNN doesn't support dilated subsampling
|
//CuDNN doesn't support dilated subsampling
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
val miniBatch = input.size(0);
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
val inDepth = input.size(1);
|
int chIdx = nchw ? 1 : 3;
|
||||||
|
int hIdx = nchw ? 2 : 1;
|
||||||
|
int wIdx = nchw ? 3 : 2;
|
||||||
|
|
||||||
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType);
|
val miniBatch = input.size(0);
|
||||||
|
val inDepth = input.size(nchw ? 1 : 3);
|
||||||
|
|
||||||
|
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
|
||||||
input = args.getInput();
|
input = args.getInput();
|
||||||
val inH = input.size(2);
|
val inH = input.size(nchw ? 2 : 1);
|
||||||
val inW = input.size(3);
|
val inW = input.size(nchw ? 3 : 2);
|
||||||
val srcStride = input.stride();
|
val srcStride = input.stride();
|
||||||
val outSize = args.getOutSize();
|
val outSize = args.getOutSize();
|
||||||
int outH = outSize[0];
|
int outH = outSize[0];
|
||||||
|
@ -246,13 +268,14 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
|
||||||
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
|
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
|
||||||
kernel[1], pad[0], pad[1], strides[0], strides[1]));
|
kernel[1], pad[0], pad[1], strides[0], strides[1]));
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
|
||||||
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]));
|
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
|
||||||
|
|
||||||
INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {(int) miniBatch, (int) inDepth, outH, outW}, 'c');
|
long[] outShape = nchw ? new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth};
|
||||||
|
INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c');
|
||||||
|
|
||||||
val dstStride = reduced.stride();
|
val dstStride = reduced.stride();
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW,
|
||||||
(int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]));
|
(int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
|
||||||
|
|
||||||
Allocator allocator = AtomicAllocator.getInstance();
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
CudaContext context = allocator.getFlowController().prepareAction(input, reduced);
|
CudaContext context = allocator.getFlowController().prepareAction(input, reduced);
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.normalization;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
|
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
|
||||||
|
@ -124,12 +125,21 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
|
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
|
||||||
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
|
INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) {
|
||||||
|
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
this.eps = eps;
|
this.eps = eps;
|
||||||
|
|
||||||
|
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
|
||||||
|
int chIdx = nchw ? 1 : 3;
|
||||||
|
int hIdx = nchw ? 2 : 1;
|
||||||
|
int wIdx = nchw ? 3 : 2;
|
||||||
|
|
||||||
val miniBatch = (int) input.size(0);
|
val miniBatch = (int) input.size(0);
|
||||||
val depth = (int) input.size(1);
|
val depth = (int) input.size(chIdx);
|
||||||
val inH = (int) input.size(2);
|
val inH = (int) input.size(hIdx);
|
||||||
val inW = (int) input.size(3);
|
val inW = (int) input.size(wIdx);
|
||||||
|
|
||||||
final boolean isHalf = (input.dataType() == DataType.HALF);
|
final boolean isHalf = (input.dataType() == DataType.HALF);
|
||||||
INDArray gammaOrig = null;
|
INDArray gammaOrig = null;
|
||||||
|
@ -164,16 +174,17 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
|
||||||
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
||||||
|
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
|
||||||
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]));
|
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
|
||||||
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]));
|
(int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
|
||||||
|
|
||||||
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c');
|
long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
|
||||||
|
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c');
|
||||||
val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
|
val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
|
||||||
|
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
|
||||||
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
|
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
|
||||||
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
|
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
|
||||||
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
|
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
|
||||||
|
|
||||||
Allocator allocator = AtomicAllocator.getInstance();
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
|
@ -215,9 +226,15 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
|
public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
|
||||||
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
|
INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
|
||||||
|
int chIdx = nchw ? 1 : 3;
|
||||||
|
int hIdx = nchw ? 2 : 1;
|
||||||
|
int wIdx = nchw ? 3 : 2;
|
||||||
|
|
||||||
this.eps = eps;
|
this.eps = eps;
|
||||||
final boolean isHalf = (x.dataType() == DataType.HALF);
|
final boolean isHalf = (x.dataType() == DataType.FLOAT16);
|
||||||
INDArray origGamma = gamma;
|
INDArray origGamma = gamma;
|
||||||
INDArray origBeta = beta;
|
INDArray origBeta = beta;
|
||||||
INDArray origMean = mean;
|
INDArray origMean = mean;
|
||||||
|
@ -238,21 +255,22 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
|
||||||
decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled"
|
decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled"
|
||||||
|
|
||||||
val miniBatch = (int) x.size(0);
|
val miniBatch = (int) x.size(0);
|
||||||
val inDepth = (int) x.size(1);
|
val inDepth = (int) x.size(chIdx);
|
||||||
val inH = (int) x.size(2);
|
val inH = (int) x.size(hIdx);
|
||||||
val inW = (int) x.size(3);
|
val inW = (int) x.size(wIdx);
|
||||||
|
|
||||||
val srcStride = ArrayUtil.toInts(x.stride());
|
val srcStride = ArrayUtil.toInts(x.stride());
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
|
||||||
srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
|
srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx]));
|
||||||
|
|
||||||
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c');
|
long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth};
|
||||||
|
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c');
|
||||||
|
|
||||||
val dstStride = ArrayUtil.toInts(activations.stride());
|
val dstStride = ArrayUtil.toInts(activations.stride());
|
||||||
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
|
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
|
||||||
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
|
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
|
||||||
|
|
||||||
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0],
|
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0],
|
||||||
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
|
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
|
||||||
|
|
||||||
Allocator allocator = AtomicAllocator.getInstance();
|
Allocator allocator = AtomicAllocator.getInstance();
|
||||||
|
|
|
@ -16,74 +16,131 @@
|
||||||
|
|
||||||
package org.deeplearning4j;
|
package org.deeplearning4j;
|
||||||
|
|
||||||
|
import org.apache.commons.compress.utils.IOUtils;
|
||||||
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
|
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
|
||||||
|
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer;
|
||||||
|
import org.deeplearning4j.nn.layers.normalization.BatchNormalization;
|
||||||
|
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization;
|
||||||
|
import org.deeplearning4j.nn.layers.recurrent.LSTM;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||||
|
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||||
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.*;
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.lang.reflect.Field;
|
||||||
import java.io.IOException;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
|
||||||
public class TestUtils {
|
public class TestUtils {
|
||||||
|
|
||||||
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
|
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
|
||||||
|
|
||||||
|
MultiLayerNetwork restored;
|
||||||
try {
|
try {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
ModelSerializer.writeModel(net, baos, true);
|
ModelSerializer.writeModel(net, baos, true);
|
||||||
byte[] bytes = baos.toByteArray();
|
byte[] bytes = baos.toByteArray();
|
||||||
|
|
||||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
||||||
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
|
restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
|
||||||
|
|
||||||
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
||||||
assertEquals(net.params(), restored.params());
|
assertEquals(net.params(), restored.params());
|
||||||
|
|
||||||
return restored;
|
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
//Should never happen
|
//Should never happen
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Also check the MultiLayerConfiguration is serializable (required by Spark etc)
|
||||||
|
MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
|
||||||
|
serializeDeserializeJava(conf);
|
||||||
|
|
||||||
|
return restored;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static ComputationGraph testModelSerialization(ComputationGraph net){
|
public static ComputationGraph testModelSerialization(ComputationGraph net){
|
||||||
|
ComputationGraph restored;
|
||||||
try {
|
try {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
ModelSerializer.writeModel(net, baos, true);
|
ModelSerializer.writeModel(net, baos, true);
|
||||||
byte[] bytes = baos.toByteArray();
|
byte[] bytes = baos.toByteArray();
|
||||||
|
|
||||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
||||||
ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true);
|
restored = ModelSerializer.restoreComputationGraph(bais, true);
|
||||||
|
|
||||||
assertEquals(net.getConfiguration(), restored.getConfiguration());
|
assertEquals(net.getConfiguration(), restored.getConfiguration());
|
||||||
assertEquals(net.params(), restored.params());
|
assertEquals(net.params(), restored.params());
|
||||||
|
|
||||||
return restored;
|
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
//Should never happen
|
//Should never happen
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
|
||||||
|
ComputationGraphConfiguration conf = net.getConfiguration();
|
||||||
|
serializeDeserializeJava(conf);
|
||||||
|
|
||||||
|
return restored;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray randomOneHot(int examples, int nOut){
|
private static <T> T serializeDeserializeJava(T object){
|
||||||
|
byte[] bytes;
|
||||||
|
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
|
||||||
|
oos.writeObject(object);
|
||||||
|
oos.close();
|
||||||
|
bytes = baos.toByteArray();
|
||||||
|
} catch (IOException e){
|
||||||
|
//Should never happen
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
T out;
|
||||||
|
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){
|
||||||
|
out = (T)ois.readObject();
|
||||||
|
} catch (IOException | ClassNotFoundException e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(object, out);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static INDArray randomOneHot(long examples, long nOut){
|
||||||
return randomOneHot(examples, nOut, new Random(12345));
|
return randomOneHot(examples, nOut, new Random(12345));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray randomOneHot(int examples, int nOut, long rngSeed){
|
public static INDArray randomOneHot(DataType dataType, long examples, long nOut){
|
||||||
|
return randomOneHot(dataType, examples, nOut, new Random(12345));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static INDArray randomOneHot(long examples, long nOut, long rngSeed){
|
||||||
return randomOneHot(examples, nOut, new Random(rngSeed));
|
return randomOneHot(examples, nOut, new Random(rngSeed));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray randomOneHot(int examples, int nOut, Random rng){
|
public static INDArray randomOneHot(long examples, long nOut, Random rng) {
|
||||||
INDArray arr = Nd4j.create(examples, nOut);
|
return randomOneHot(Nd4j.defaultFloatingPointType(), examples,nOut, rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static INDArray randomOneHot(DataType dataType, long examples, long nOut, Random rng){
|
||||||
|
INDArray arr = Nd4j.create(dataType, examples, nOut);
|
||||||
for( int i=0; i<examples; i++ ){
|
for( int i=0; i<examples; i++ ){
|
||||||
arr.putScalar(i, rng.nextInt(nOut), 1.0);
|
arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
|
||||||
}
|
}
|
||||||
return arr;
|
return arr;
|
||||||
}
|
}
|
||||||
|
@ -115,4 +172,143 @@ public class TestUtils {
|
||||||
Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p));
|
Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p));
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static void writeStreamToFile(File out, InputStream is) throws IOException {
|
||||||
|
byte[] b = IOUtils.toByteArray(is);
|
||||||
|
try (OutputStream os = new BufferedOutputStream(new FileOutputStream(out))) {
|
||||||
|
os.write(b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static L1Regularization getL1Reg(List<Regularization> l){
|
||||||
|
for(Regularization r : l){
|
||||||
|
if(r instanceof L1Regularization){
|
||||||
|
return (L1Regularization) r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static L2Regularization getL2Reg(BaseLayer baseLayer){
|
||||||
|
return getL2Reg(baseLayer.getRegularization());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static L2Regularization getL2Reg(List<Regularization> l){
|
||||||
|
for(Regularization r : l){
|
||||||
|
if(r instanceof L2Regularization){
|
||||||
|
return (L2Regularization) r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static WeightDecay getWeightDecayReg(BaseLayer bl){
|
||||||
|
return getWeightDecayReg(bl.getRegularization());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static WeightDecay getWeightDecayReg(List<Regularization> l){
|
||||||
|
for(Regularization r : l){
|
||||||
|
if(r instanceof WeightDecay){
|
||||||
|
return (WeightDecay) r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double getL1(BaseLayer layer) {
|
||||||
|
List<Regularization> l = layer.getRegularization();
|
||||||
|
return getL1(l);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double getL1(List<Regularization> l){
|
||||||
|
L1Regularization l1Reg = null;
|
||||||
|
for(Regularization reg : l){
|
||||||
|
if(reg instanceof L1Regularization)
|
||||||
|
l1Reg = (L1Regularization) reg;
|
||||||
|
}
|
||||||
|
assertNotNull(l1Reg);
|
||||||
|
return l1Reg.getL1().valueAt(0,0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double getL2(BaseLayer layer) {
|
||||||
|
List<Regularization> l = layer.getRegularization();
|
||||||
|
return getL2(l);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double getL2(List<Regularization> l){
|
||||||
|
L2Regularization l2Reg = null;
|
||||||
|
for(Regularization reg : l){
|
||||||
|
if(reg instanceof L2Regularization)
|
||||||
|
l2Reg = (L2Regularization) reg;
|
||||||
|
}
|
||||||
|
assertNotNull(l2Reg);
|
||||||
|
return l2Reg.getL2().valueAt(0,0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double getL1(AbstractSameDiffLayer layer){
|
||||||
|
return getL1(layer.getRegularization());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double getL2(AbstractSameDiffLayer layer){
|
||||||
|
return getL2(layer.getRegularization());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double getWeightDecay(BaseLayer layer) {
|
||||||
|
return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void removeHelper(Layer layer) throws Exception {
|
||||||
|
removeHelpers(new Layer[]{layer});
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void removeHelpers(Layer[] layers) throws Exception {
|
||||||
|
for(Layer l : layers){
|
||||||
|
|
||||||
|
if(l instanceof ConvolutionLayer){
|
||||||
|
Field f1 = ConvolutionLayer.class.getDeclaredField("helper");
|
||||||
|
f1.setAccessible(true);
|
||||||
|
f1.set(l, null);
|
||||||
|
} else if(l instanceof SubsamplingLayer){
|
||||||
|
Field f2 = SubsamplingLayer.class.getDeclaredField("helper");
|
||||||
|
f2.setAccessible(true);
|
||||||
|
f2.set(l, null);
|
||||||
|
} else if(l instanceof BatchNormalization) {
|
||||||
|
Field f3 = BatchNormalization.class.getDeclaredField("helper");
|
||||||
|
f3.setAccessible(true);
|
||||||
|
f3.set(l, null);
|
||||||
|
} else if(l instanceof LSTM){
|
||||||
|
Field f4 = LSTM.class.getDeclaredField("helper");
|
||||||
|
f4.setAccessible(true);
|
||||||
|
f4.set(l, null);
|
||||||
|
} else if(l instanceof LocalResponseNormalization){
|
||||||
|
Field f5 = LocalResponseNormalization.class.getDeclaredField("helper");
|
||||||
|
f5.setAccessible(true);
|
||||||
|
f5.set(l, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if(l.getHelper() != null){
|
||||||
|
throw new IllegalStateException("Did not remove helper for layer: " + l.getClass().getSimpleName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void assertHelperPresent(Layer layer){
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void assertHelpersPresent(Layer[] layers) throws Exception {
|
||||||
|
for(Layer l : layers){
|
||||||
|
//Don't use instanceof here - there are sub conv subclasses
|
||||||
|
if(l.getClass() == ConvolutionLayer.class || l instanceof SubsamplingLayer || l instanceof BatchNormalization || l instanceof LSTM){
|
||||||
|
Preconditions.checkNotNull(l.getHelper(), l.conf().getLayer().getLayerName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void assertHelpersAbsent(Layer[] layers) throws Exception {
|
||||||
|
for(Layer l : layers){
|
||||||
|
Preconditions.checkState(l.getHelper() == null, l.conf().getLayer().getLayerName());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest {
|
||||||
ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
|
ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
|
||||||
model = model.convertDataType(DataType.DOUBLE);
|
model = model.convertDataType(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize});
|
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC
|
||||||
|
|
||||||
CuDNNTestUtils.assertHelpersPresent(model.getLayers());
|
CuDNNTestUtils.assertHelpersPresent(model.getLayers());
|
||||||
Map<String,INDArray> withCudnn = model.feedForward(in, false);
|
Map<String,INDArray> withCudnn = model.feedForward(in, false);
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.datasets.fetchers;
|
package org.deeplearning4j.datasets.fetchers;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.deeplearning4j.base.EmnistFetcher;
|
import org.deeplearning4j.base.EmnistFetcher;
|
||||||
|
@ -36,6 +37,7 @@ import java.util.Random;
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher {
|
public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher {
|
||||||
|
|
||||||
protected EmnistFetcher fetcher;
|
protected EmnistFetcher fetcher;
|
||||||
|
@ -64,7 +66,7 @@ public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetche
|
||||||
try {
|
try {
|
||||||
man = new MnistManager(images, labels, totalExamples);
|
man = new MnistManager(images, labels, totalExamples);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
FileUtils.deleteDirectory(new File(EMNIST_ROOT));
|
FileUtils.deleteDirectory(new File(EMNIST_ROOT));
|
||||||
new EmnistFetcher(dataSet).downloadAndUntar();
|
new EmnistFetcher(dataSet).downloadAndUntar();
|
||||||
man = new MnistManager(images, labels, totalExamples);
|
man = new MnistManager(images, labels, totalExamples);
|
||||||
|
|
|
@ -687,9 +687,7 @@ public class BarnesHutTsne implements Model {
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
public void saveAsFile(List<String> labels, String path) throws IOException {
|
public void saveAsFile(List<String> labels, String path) throws IOException {
|
||||||
BufferedWriter write = null;
|
try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
|
||||||
try {
|
|
||||||
write = new BufferedWriter(new FileWriter(new File(path)));
|
|
||||||
for (int i = 0; i < Y.rows(); i++) {
|
for (int i = 0; i < Y.rows(); i++) {
|
||||||
if (i >= labels.size())
|
if (i >= labels.size())
|
||||||
break;
|
break;
|
||||||
|
@ -711,17 +709,11 @@ public class BarnesHutTsne implements Model {
|
||||||
|
|
||||||
}
|
}
|
||||||
write.flush();
|
write.flush();
|
||||||
write.close();
|
|
||||||
} finally {
|
|
||||||
if (write != null)
|
|
||||||
write.close();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void saveAsFile(String path) throws IOException {
|
public void saveAsFile(String path) throws IOException {
|
||||||
BufferedWriter write = null;
|
try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
|
||||||
try {
|
|
||||||
write = new BufferedWriter(new FileWriter(new File(path)));
|
|
||||||
for (int i = 0; i < Y.rows(); i++) {
|
for (int i = 0; i < Y.rows(); i++) {
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
INDArray wordVector = Y.getRow(i);
|
INDArray wordVector = Y.getRow(i);
|
||||||
|
@ -734,10 +726,6 @@ public class BarnesHutTsne implements Model {
|
||||||
write.write(sb.toString());
|
write.write(sb.toString());
|
||||||
}
|
}
|
||||||
write.flush();
|
write.flush();
|
||||||
write.close();
|
|
||||||
} finally {
|
|
||||||
if (write != null)
|
|
||||||
write.close();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -113,6 +113,12 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.datavec</groupId>
|
||||||
|
<artifactId>datavec-python</artifactId>
|
||||||
|
<version>${datavec.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -60,7 +60,7 @@ public class Hdf5Archive implements Closeable {
|
||||||
/* This is necessary for the call to the BytePointer constructor below. */
|
/* This is necessary for the call to the BytePointer constructor below. */
|
||||||
Loader.load(org.bytedeco.hdf5.global.hdf5.class);
|
Loader.load(org.bytedeco.hdf5.global.hdf5.class);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
log.error("",e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,9 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||||
|
@ -121,27 +124,29 @@ public class KerasInput extends KerasLayer {
|
||||||
InputType myInputType;
|
InputType myInputType;
|
||||||
switch (this.inputShape.length) {
|
switch (this.inputShape.length) {
|
||||||
case 1:
|
case 1:
|
||||||
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
|
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
if(this.dimOrder != null) {
|
if(this.dimOrder != null) {
|
||||||
|
System.out.println("Dim order: " + this.dimOrder);
|
||||||
|
System.out.println("Input shape: " + ArrayUtils.toString(this.inputShape));
|
||||||
switch (this.dimOrder) {
|
switch (this.dimOrder) {
|
||||||
case TENSORFLOW: //NWC == channels_last
|
case TENSORFLOW: //NWC == channels_last
|
||||||
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
|
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
|
||||||
break;
|
break;
|
||||||
case THEANO: //NCW == channels_first
|
case THEANO: //NCW == channels_first
|
||||||
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
|
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
|
||||||
break;
|
break;
|
||||||
case NONE:
|
case NONE:
|
||||||
//Assume RNN in [mb, seqLen, size] format
|
//Assume RNN in [mb, seqLen, size] format
|
||||||
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
|
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
|
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//Assume RNN in [mb, seqLen, size] format
|
//Assume RNN in [mb, seqLen, size] format
|
||||||
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
|
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
|
||||||
}
|
}
|
||||||
|
|
||||||
break;
|
break;
|
||||||
|
@ -150,17 +155,17 @@ public class KerasInput extends KerasLayer {
|
||||||
case TENSORFLOW:
|
case TENSORFLOW:
|
||||||
/* TensorFlow convolutional input: # rows, # cols, # channels */
|
/* TensorFlow convolutional input: # rows, # cols, # channels */
|
||||||
myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
|
myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
|
||||||
this.inputShape[2]);
|
this.inputShape[2], CNN2DFormat.NHWC);
|
||||||
break;
|
break;
|
||||||
case THEANO:
|
case THEANO:
|
||||||
/* Theano convolutional input: # channels, # rows, # cols */
|
/* Theano convolutional input: # channels, # rows, # cols */
|
||||||
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
|
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
|
||||||
this.inputShape[0]);
|
this.inputShape[0], CNN2DFormat.NCHW);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
this.dimOrder = DimOrder.THEANO;
|
this.dimOrder = DimOrder.THEANO;
|
||||||
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
|
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
|
||||||
this.inputShape[0]);
|
this.inputShape[0], CNN2DFormat.NCHW);
|
||||||
log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
|
log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
|
||||||
"versions may come without specified backend, in which case we assume the model was " +
|
"versions may come without specified backend, in which case we assume the model was " +
|
||||||
"built with theano." );
|
"built with theano." );
|
||||||
|
|
|
@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||||
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
||||||
|
@ -65,6 +66,9 @@ public class TFOpLayer extends Layer {
|
||||||
long[] shape = inputType.getShape(true);
|
long[] shape = inputType.getShape(true);
|
||||||
TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
|
TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
|
||||||
long[] outputShape = tempLayer.getOutputShape(shape);
|
long[] outputShape = tempLayer.getOutputShape(shape);
|
||||||
|
if (outputShape.length == 3){
|
||||||
|
return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC);
|
||||||
|
}
|
||||||
return InputType.inferInputType(Nd4j.create(outputShape));
|
return InputType.inferInputType(Nd4j.create(outputShape));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
||||||
}
|
}
|
||||||
|
|
||||||
private INDArray runGraph(INDArray input){
|
private INDArray runGraph(INDArray input){
|
||||||
if (input.rank() == 3){
|
|
||||||
// TODO make this a preprocessor
|
|
||||||
input = input.permute(0, 2, 1);
|
|
||||||
}
|
|
||||||
Map<String, INDArray> inputMap = new HashMap<>();
|
Map<String, INDArray> inputMap = new HashMap<>();
|
||||||
inputMap.put(inputNames.get(0), input);
|
inputMap.put(inputNames.get(0), input);
|
||||||
INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
|
INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
|
||||||
if (out.rank() == 3){
|
|
||||||
out = out.permute(0, 2, 1); // TODO post-processing?
|
|
||||||
}
|
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||||
|
@ -94,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
|
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
|
|
||||||
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
|
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
|
@ -103,7 +103,7 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
|
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]).rnnDataFormat(dimOrder == DimOrder.TENSORFLOW? RNNFormat.NWC: RNNFormat.NCW);
|
||||||
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
|
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
|
||||||
if (hasBias)
|
if (hasBias)
|
||||||
builder.biasInit(0.0);
|
builder.biasInit(0.0);
|
||||||
|
@ -160,8 +160,8 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException {
|
public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException {
|
||||||
if (inputType.length > 1)
|
if (inputType.length > 1)
|
||||||
throw new InvalidKerasConfigurationException(
|
throw new InvalidKerasConfigurationException(
|
||||||
"Keras LSTM layer accepts only one input (received " + inputType.length + ")");
|
"Keras Conv1D layer accepts only one input (received " + inputType.length + ")");
|
||||||
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
|
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], RNNFormat.NCW,layerName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
|
@ -27,6 +28,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
|
import oshi.jna.platform.windows.PowrProf;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
@ -93,6 +95,7 @@ public class KerasConvolution2D extends KerasConvolution {
|
||||||
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
|
||||||
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
|
||||||
|
|
||||||
|
System.out.println("----" + dimOrder);
|
||||||
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
|
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
|
@ -101,7 +104,8 @@ public class KerasConvolution2D extends KerasConvolution {
|
||||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
.stride(getStrideFromConfig(layerConfig, 2, conf));
|
.stride(getStrideFromConfig(layerConfig, 2, conf))
|
||||||
|
.dataFormat((dimOrder==DimOrder.TENSORFLOW)? CNN2DFormat.NHWC:CNN2DFormat.NCHW);
|
||||||
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion);
|
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion);
|
||||||
if (hasBias)
|
if (hasBias)
|
||||||
builder.biasInit(0.0);
|
builder.biasInit(0.0);
|
||||||
|
|
|
@ -360,8 +360,19 @@ public class KerasConvolutionUtils {
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (dimension == 1) {
|
} else if (dimension == 1) {
|
||||||
int paddingInt = (int) innerConfig.get(layerField);
|
Object paddingObj = innerConfig.get(layerField);
|
||||||
padding = new int[]{paddingInt, paddingInt};
|
if (paddingObj instanceof List){
|
||||||
|
List<Integer> paddingList = (List)paddingObj;
|
||||||
|
padding = new int[]{
|
||||||
|
paddingList.get(0),
|
||||||
|
paddingList.get(1)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
int paddingInt = (int) innerConfig.get(layerField);
|
||||||
|
padding = new int[]{paddingInt, paddingInt};
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
throw new UnsupportedKerasConfigurationException(
|
throw new UnsupportedKerasConfigurationException(
|
||||||
"Keras padding layer not supported");
|
"Keras padding layer not supported");
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
|
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
|
||||||
|
@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
|
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer {
|
||||||
switch (this.getDimOrder()) {
|
switch (this.getDimOrder()) {
|
||||||
case NONE:
|
case NONE:
|
||||||
case THEANO:
|
case THEANO:
|
||||||
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels());
|
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW);
|
||||||
break;
|
break;
|
||||||
case TENSORFLOW:
|
case TENSORFLOW:
|
||||||
preprocessor = new TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(),
|
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC);
|
||||||
it.getChannels());
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder());
|
throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder());
|
||||||
|
@ -111,7 +110,7 @@ public class KerasFlatten extends KerasLayer {
|
||||||
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
|
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
|
||||||
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
|
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
|
||||||
val inputShape = new long[]{it.getSize()};
|
val inputShape = new long[]{it.getSize()};
|
||||||
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
|
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null);
|
||||||
}
|
}
|
||||||
return preprocessor;
|
return preprocessor;
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
|
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||||
|
@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer {
|
||||||
super(layerConfig, enforceTrainingConfig);
|
super(layerConfig, enforceTrainingConfig);
|
||||||
|
|
||||||
this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf))
|
this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf))
|
||||||
|
.dataFormat(RNNFormat.NWC)
|
||||||
.name(this.layerName).build();
|
.name(this.layerName).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
||||||
|
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||||
|
@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer {
|
||||||
} else {
|
} else {
|
||||||
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
|
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
|
||||||
}
|
}
|
||||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
|
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW);
|
||||||
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
|
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
|
||||||
if (inputShape[0] != targetShape[0])
|
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC);
|
||||||
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
|
|
||||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
|
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
|
||||||
|
@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer {
|
||||||
} else {
|
} else {
|
||||||
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
|
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
|
||||||
}
|
}
|
||||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
|
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
|
||||||
} else {
|
} else {
|
||||||
if (inputShape[0] != targetShape[0])
|
if (inputShape[0] != targetShape[0])
|
||||||
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
|
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
|
||||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
|
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
|
||||||
}
|
}
|
||||||
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
|
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
|
||||||
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
|
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
|
||||||
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
|
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
|
||||||
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
|
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
|
||||||
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
|
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
|
||||||
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
|
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
|
||||||
val inputShape = new long[]{it.getSize()};
|
val inputShape = new long[]{it.getSize()};
|
||||||
if (targetShape.length == 3) {
|
if (targetShape.length == 3) {
|
||||||
targetShape = targetShapeForDimOrder(inputShape, targetShape);
|
targetShape = targetShapeForDimOrder(inputShape, targetShape);
|
||||||
}
|
}
|
||||||
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
|
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
|
||||||
}
|
}
|
||||||
return preprocessor;
|
return preprocessor;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
|
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||||
|
@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer {
|
||||||
.biasInit(0.0)
|
.biasInit(0.0)
|
||||||
.l1(this.weightL1Regularization)
|
.l1(this.weightL1Regularization)
|
||||||
.l2(this.weightL2Regularization)
|
.l2(this.weightL2Regularization)
|
||||||
|
.outputDataFormat(RNNFormat.NWC)
|
||||||
.hasBias(false);
|
.hasBias(false);
|
||||||
if (embeddingConstraint != null)
|
if (embeddingConstraint != null)
|
||||||
builder.constrainWeights(embeddingConstraint);
|
builder.constrainWeights(embeddingConstraint);
|
||||||
|
|
|
@ -22,11 +22,9 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||||
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
|
||||||
|
@ -37,6 +35,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||||
import org.deeplearning4j.nn.params.LSTMParamInitializer;
|
import org.deeplearning4j.nn.params.LSTMParamInitializer;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
|
import org.deeplearning4j.util.TimeSeriesUtils;
|
||||||
import org.nd4j.linalg.activations.IActivation;
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -187,7 +186,7 @@ public class KerasLSTM extends KerasLayer {
|
||||||
.weightInitRecurrent(recurrentInit)
|
.weightInitRecurrent(recurrentInit)
|
||||||
.biasInit(0.0) // TODO: this is incorrect
|
.biasInit(0.0) // TODO: this is incorrect
|
||||||
.l1(this.weightL1Regularization)
|
.l1(this.weightL1Regularization)
|
||||||
.l2(this.weightL2Regularization);
|
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
|
||||||
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
|
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
|
||||||
if(nIn != null)
|
if(nIn != null)
|
||||||
builder.setNIn(nIn);
|
builder.setNIn(nIn);
|
||||||
|
@ -266,7 +265,8 @@ public class KerasLSTM extends KerasLayer {
|
||||||
throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one single input" +
|
throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one single input" +
|
||||||
"or three (input to LSTM and two states tensors, but " +
|
"or three (input to LSTM and two states tensors, but " +
|
||||||
"received " + inputType.length + ".");
|
"received " + inputType.length + ".");
|
||||||
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
|
RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
|
||||||
|
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f,layerName);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -21,7 +21,9 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||||
|
@ -36,6 +38,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||||
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
|
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
|
import org.deeplearning4j.util.TimeSeriesUtils;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
@ -155,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer {
|
||||||
.weightInitRecurrent(recurrentInit)
|
.weightInitRecurrent(recurrentInit)
|
||||||
.biasInit(0.0)
|
.biasInit(0.0)
|
||||||
.l1(this.weightL1Regularization)
|
.l1(this.weightL1Regularization)
|
||||||
.l2(this.weightL2Regularization);
|
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
|
||||||
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
|
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
|
||||||
if(nIn != null)
|
if(nIn != null)
|
||||||
builder.setNIn(nIn);
|
builder.setNIn(nIn);
|
||||||
|
@ -227,7 +230,8 @@ public class KerasSimpleRnn extends KerasLayer {
|
||||||
throw new InvalidKerasConfigurationException(
|
throw new InvalidKerasConfigurationException(
|
||||||
"Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")");
|
"Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")");
|
||||||
|
|
||||||
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
|
RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
|
||||||
|
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer {
|
||||||
break;
|
break;
|
||||||
case "SimpleRNN":
|
case "SimpleRNN":
|
||||||
kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
|
kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
|
||||||
SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
|
Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
|
||||||
this.layer = new Bidirectional(mode, rnnLayer);
|
this.layer = new Bidirectional(mode, rnnLayer);
|
||||||
layer.setLayerName(layerName);
|
layer.setLayerName(layerName);
|
||||||
break;
|
break;
|
||||||
|
@ -218,7 +218,7 @@ public class KerasBidirectional extends KerasLayer {
|
||||||
if (inputType.length > 1)
|
if (inputType.length > 1)
|
||||||
throw new InvalidKerasConfigurationException(
|
throw new InvalidKerasConfigurationException(
|
||||||
"Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
|
"Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
|
||||||
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
|
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], ((Bidirectional)layer).getRNNDataFormat(), layerName);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -21,6 +21,9 @@ import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
|
import org.deeplearning4j.nn.conf.DataFormat;
|
||||||
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
|
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
|
||||||
|
@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
private final long[] inputShape;
|
private final long[] inputShape;
|
||||||
private final long[] targetShape;
|
private final long[] targetShape;
|
||||||
private boolean hasMiniBatchDimension;
|
private boolean hasMiniBatchDimension;
|
||||||
|
private DataFormat format;
|
||||||
/**
|
|
||||||
* @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
|
|
||||||
this(inputShape, targetShape, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
|
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
|
||||||
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
|
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
|
||||||
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
|
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
|
||||||
*/
|
*/
|
||||||
|
public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) {
|
||||||
|
this(inputShape, targetShape, hasMiniBatchDimension, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
|
||||||
|
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
|
||||||
|
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
|
||||||
|
* @param dataFormat May be null. If non-null:
|
||||||
|
*/
|
||||||
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
|
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
|
||||||
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
|
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension,
|
||||||
|
@JsonProperty("dataFormat") DataFormat dataFormat) {
|
||||||
this.inputShape = inputShape;
|
this.inputShape = inputShape;
|
||||||
this.targetShape = targetShape;
|
this.targetShape = targetShape;
|
||||||
this.hasMiniBatchDimension = hasMiniBatchDimension;
|
this.hasMiniBatchDimension = hasMiniBatchDimension;
|
||||||
|
this.format = dataFormat;
|
||||||
}
|
}
|
||||||
|
|
||||||
private long[] getShape(long[] originalShape, long minibatch) {
|
private long[] getShape(long[] originalShape, long minibatch) {
|
||||||
|
@ -140,13 +148,26 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
ret = InputType.feedForward(shape[1]);
|
ret = InputType.feedForward(shape[1]);
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
ret = InputType.recurrent(shape[2], shape[1]);
|
RNNFormat format = RNNFormat.NCW;
|
||||||
|
if(this.format != null && this.format instanceof RNNFormat)
|
||||||
|
format = (RNNFormat)this.format;
|
||||||
|
|
||||||
|
ret = InputType.recurrent(shape[2], shape[1], format);
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
|
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
|
||||||
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
|
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
|
||||||
} else {
|
} else {
|
||||||
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
|
|
||||||
|
CNN2DFormat cnnFormat = CNN2DFormat.NCHW;
|
||||||
|
if (this.format != null && this.format instanceof CNN2DFormat)
|
||||||
|
cnnFormat = (CNN2DFormat) this.format;
|
||||||
|
|
||||||
|
if (cnnFormat == CNN2DFormat.NCHW) {
|
||||||
|
ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
|
||||||
|
} else {
|
||||||
|
ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialized CnnToFeedForwardInputPreProcessor for use with
|
* @deprecated Exists only for backward compatibility of older pretrained models. Should not be used.
|
||||||
* Convolutional layers imported from Keras using the TensorFlow
|
* Use {@link CnnToFeedForwardPreProcessor} for all new models instead.
|
||||||
* backend.
|
|
||||||
*
|
|
||||||
* @author dave@skymind.io
|
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j @Deprecated
|
||||||
public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor {
|
public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor {
|
||||||
|
|
||||||
@JsonCreator
|
@JsonCreator @Deprecated
|
||||||
public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight,
|
public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight,
|
||||||
@JsonProperty("inputWidth") long inputWidth,
|
@JsonProperty("inputWidth") long inputWidth,
|
||||||
@JsonProperty("numChannels") long numChannels) {
|
@JsonProperty("numChannels") long numChannels) {
|
||||||
super(inputHeight, inputWidth, numChannels);
|
super(inputHeight, inputWidth, numChannels);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
|
public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
|
||||||
super(inputHeight, inputWidth);
|
super(inputHeight, inputWidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public TensorFlowCnnToFeedForwardPreProcessor() {
|
public TensorFlowCnnToFeedForwardPreProcessor() {
|
||||||
super();
|
super();
|
||||||
}
|
}
|
||||||
|
@ -81,4 +80,4 @@ public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreP
|
||||||
public TensorFlowCnnToFeedForwardPreProcessor clone() {
|
public TensorFlowCnnToFeedForwardPreProcessor clone() {
|
||||||
return (TensorFlowCnnToFeedForwardPreProcessor) super.clone();
|
return (TensorFlowCnnToFeedForwardPreProcessor) super.clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
|
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
|
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
|
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
|
||||||
|
import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
|
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
|
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
|
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
|
||||||
|
@ -319,6 +320,8 @@ public class KerasLayerUtils {
|
||||||
layer = new KerasELU(layerConfig, enforceTrainingConfig);
|
layer = new KerasELU(layerConfig, enforceTrainingConfig);
|
||||||
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
|
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
|
||||||
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
|
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
|
||||||
|
} else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){
|
||||||
|
layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
|
||||||
} else if (conf instanceof Keras2LayerConfiguration){
|
} else if (conf instanceof Keras2LayerConfiguration){
|
||||||
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
|
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
|
||||||
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){
|
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){
|
||||||
|
|
|
@ -1,50 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2020 Konduit K.K.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.nn.modelimport.keras;
|
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.resources.Resources;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
public class TFKerasTests extends BaseDL4JTest{
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testModelWithTFOp1() throws Exception{
|
|
||||||
File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
|
|
||||||
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
|
|
||||||
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
|
|
||||||
Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testModelWithTFOp2() throws Exception{
|
|
||||||
File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
|
|
||||||
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
|
|
||||||
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
|
|
||||||
// dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed
|
|
||||||
long[] expectedShape = new long[]{12 * 2, 5};
|
|
||||||
Assert.assertArrayEquals(expectedShape, out.shape());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -0,0 +1,147 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.nn.modelimport.keras;
|
||||||
|
|
||||||
|
import org.apache.commons.io.FileUtils;
|
||||||
|
import org.datavec.python.keras.Model;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Rule;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.common.tests.ResourceUtils;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.resources.Resources;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
|
public class TestTFKerasModelImport extends BaseDL4JTest{
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
private String modelFile;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds(){
|
||||||
|
return 300000;
|
||||||
|
} // installing TF will take a while
|
||||||
|
|
||||||
|
|
||||||
|
@Parameterized.Parameters(name = "file={0}")
|
||||||
|
public static Object[] params() throws Exception {
|
||||||
|
List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false);
|
||||||
|
return paths.toArray(new String[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
public TestTFKerasModelImport(String modelFile){
|
||||||
|
this.modelFile = modelFile;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testModelImport() throws Exception{
|
||||||
|
testModelImportWithData(modelFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testModelImportWithData(String path) throws Exception{
|
||||||
|
System.out.println(path);
|
||||||
|
// TODO multi input/output
|
||||||
|
INDArray inputArray;
|
||||||
|
INDArray expectedOutputArray;
|
||||||
|
File f = Resources.asFile(path); //May in in JAR that HDF5 can't read from
|
||||||
|
File modelFile = new File(testDir.getRoot(), f.getName());
|
||||||
|
FileUtils.copyFile(f, modelFile);
|
||||||
|
|
||||||
|
synchronized (Hdf5Archive.LOCK_OBJECT){
|
||||||
|
Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath());
|
||||||
|
List<String> rootGroups = hdf5Archive.getGroups();
|
||||||
|
if (rootGroups.contains("data")){
|
||||||
|
String inputName = hdf5Archive.readAttributeAsString("input_names", "data");
|
||||||
|
String outputName = hdf5Archive.readAttributeAsString("output_names", "data");
|
||||||
|
inputArray = hdf5Archive.readDataSet(inputName, "data");
|
||||||
|
expectedOutputArray = hdf5Archive.readDataSet(outputName, "data");
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
hdf5Archive.close();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
hdf5Archive.close();
|
||||||
|
}
|
||||||
|
INDArray outputArray;
|
||||||
|
|
||||||
|
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
|
||||||
|
outputArray = dl4jModel.outputSingle(inputArray);
|
||||||
|
|
||||||
|
expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT);
|
||||||
|
outputArray = outputArray.castTo(DataType.FLOAT);
|
||||||
|
if (path.contains("misc_")){
|
||||||
|
//shape relaxation
|
||||||
|
expectedOutputArray = expectedOutputArray.reshape( -1);
|
||||||
|
outputArray = outputArray.reshape(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.out.println(outputArray.toString());
|
||||||
|
System.out.println(expectedOutputArray.toString());
|
||||||
|
Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape());
|
||||||
|
Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testModelImportWithKeras(String path) throws Exception{
|
||||||
|
Model kerasModel = new Model(path);
|
||||||
|
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
|
||||||
|
Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays());
|
||||||
|
Assert.assertEquals(kerasModel.numOutputs(), dl4jModel.getNumOutputArrays());
|
||||||
|
INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()];
|
||||||
|
INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()];
|
||||||
|
|
||||||
|
for (int i = 0; i < kerasInputArrays.length; i ++) {
|
||||||
|
long[] shape = kerasModel.inputShapeAt(i);
|
||||||
|
for (int j = 0; j < shape.length; j++) {
|
||||||
|
if (shape[j] < 0) {
|
||||||
|
shape[j] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kerasInputArrays[i] = Nd4j.rand(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
INDArray[] kerasOut = kerasModel.predict(kerasInputArrays);
|
||||||
|
INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays);
|
||||||
|
|
||||||
|
Assert.assertEquals(kerasOut.length, dl4jOut.length);
|
||||||
|
|
||||||
|
for (int i = 0; i < kerasOut.length; i++){
|
||||||
|
INDArray kerasOutArr = kerasOut[i];
|
||||||
|
kerasOutArr = kerasOutArr.reshape(1, -1);// bit of relaxation on shape
|
||||||
|
kerasOutArr= kerasOutArr.castTo(DataType.DOUBLE);
|
||||||
|
Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST);
|
||||||
|
INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1);
|
||||||
|
System.out.println(kerasOutArr.shapeInfoToString());
|
||||||
|
System.out.println(dl4jOutArr.shapeInfoToString());
|
||||||
|
Assert.assertEquals(kerasOutArr, dl4jOutArr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue