Merge pull request #8892 from KonduitAI/master

Merge latest development work
Alex Black 2020-04-28 20:38:56 +10:00 committed by GitHub
commit 1930d99908
475 changed files with 10073 additions and 2958 deletions

View File

@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 90000L; return 120_000L;
} }
@Test @Test
@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
.dataSource(ds, dsP) .dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave)) .modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction()) .scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(3)) new MaxCandidatesCondition(3))
.build(); .build();

View File

@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 45000L; return 120_000L;
} }
@Test @Test
@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
.dataSource(ds, dsP) .dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave)) .modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction()) .scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(10)) new MaxCandidatesCondition(3))
.build(); .build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator())); IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));

View File

@ -18,6 +18,7 @@ package org.datavec.api.records.reader.impl;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator; import org.apache.commons.io.LineIterator;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
@ -43,6 +44,7 @@ import java.util.*;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j
public class LineRecordReader extends BaseRecordReader { public class LineRecordReader extends BaseRecordReader {
@ -58,6 +60,13 @@ public class LineRecordReader extends BaseRecordReader {
@Override @Override
public void initialize(InputSplit split) throws IOException, InterruptedException { public void initialize(InputSplit split) throws IOException, InterruptedException {
super.initialize(split); super.initialize(split);
if(!(inputSplit instanceof StringSplit || inputSplit instanceof InputStreamInputSplit)){
final ArrayList<URI> uris = new ArrayList<>();
final Iterator<URI> uriIterator = inputSplit.locationsIterator();
while(uriIterator.hasNext()) uris.add(uriIterator.next());
this.locations = uris.toArray(new URI[0]);
}
this.iter = getIterator(0); this.iter = getIterator(0);
this.initialized = true; this.initialized = true;
} }
@ -66,7 +75,6 @@ public class LineRecordReader extends BaseRecordReader {
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
this.conf = conf; this.conf = conf;
initialize(split); initialize(split);
this.initialized = true;
} }
@Override @Override
@ -89,7 +97,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex); iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]); onLocationOpen(locations[splitIndex]);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
if (iter.hasNext()) { if (iter.hasNext()) {
@ -120,7 +128,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex); iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]); onLocationOpen(locations[splitIndex]);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return iter.hasNext(); return iter.hasNext();
@ -205,11 +213,6 @@ public class LineRecordReader extends BaseRecordReader {
} }
} }
} else { } else {
final ArrayList<URI> uris = new ArrayList<>();
final Iterator<URI> uriIterator = inputSplit.locationsIterator();
while(uriIterator.hasNext()) uris.add(uriIterator.next());
this.locations = uris.toArray(new URI[uris.size()]);
if (locations.length > 0) { if (locations.length > 0) {
InputStream inputStream = streamCreatorFn.apply(locations[location]); InputStream inputStream = streamCreatorFn.apply(locations[location]);
try { try {

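The net effect of the LineRecordReader change above: split locations are now collected once in initialize() (skipping StringSplit and InputStreamInputSplit, which carry no URI locations), instead of lazily inside getIterator(), and the redundant initialized flag in the Configuration overload is dropped. A minimal standalone sketch of the eager-materialization pattern, with illustrative class and method names (not part of DataVec):

import java.net.URI;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

class LocationCache {
    private URI[] locations;

    /** Walk the split's URI iterator once and keep the result for later index-based access. */
    void cacheLocations(Iterator<URI> locationsIterator) {
        List<URI> uris = new ArrayList<>();
        while (locationsIterator.hasNext()) {
            uris.add(locationsIterator.next());
        }
        this.locations = uris.toArray(new URI[0]); // same idiom as the new code above
    }

    URI locationAt(int splitIndex) {
        return locations[splitIndex];
    }
}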
View File

@ -16,6 +16,7 @@
package org.datavec.api.split; package org.datavec.api.split;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.util.files.UriFromPathIterator; import org.datavec.api.util.files.UriFromPathIterator;
import org.datavec.api.writable.WritableType; import org.datavec.api.writable.WritableType;
@ -34,6 +35,7 @@ import java.util.regex.Pattern;
* NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent * NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent
* the integer index. * the integer index.
*/ */
@Slf4j
public class NumberedFileInputSplit implements InputSplit { public class NumberedFileInputSplit implements InputSplit {
private final String baseString; private final String baseString;
private final int minIdx; private final int minIdx;
@ -93,7 +95,7 @@ public class NumberedFileInputSplit implements InputSplit {
try { try {
writeFile.createNewFile(); writeFile.createNewFile();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }

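The e.printStackTrace() to log.error("", e) replacement seen here recurs across most files in this merge, backed by Lombok's @Slf4j annotation added at class level. A minimal sketch of the pattern (the class below is illustrative, not part of DataVec):

import lombok.extern.slf4j.Slf4j;
import java.io.File;
import java.io.IOException;

@Slf4j // Lombok generates a static final org.slf4j.Logger field named "log" for this class
class CreateFileExample {
    void touch(File writeFile) {
        try {
            writeFile.createNewFile();
        } catch (IOException e) {
            log.error("", e); // replaces e.printStackTrace(), so output goes through SLF4J
        }
    }
}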
View File

@ -23,6 +23,7 @@ import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.Iterator; import java.util.Iterator;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import java.util.regex.Pattern;
/** /**
* A simple utility method to convert a {@code Iterator<String>} to an {@code Iterator<URI>}, where each * A simple utility method to convert a {@code Iterator<String>} to an {@code Iterator<URI>}, where each
@ -32,6 +33,7 @@ import java.util.NoSuchElementException;
*/ */
@AllArgsConstructor @AllArgsConstructor
public class UriFromPathIterator implements Iterator<URI> { public class UriFromPathIterator implements Iterator<URI> {
final Pattern schemaPattern = Pattern.compile("^.*?:/.*");
private final Iterator<String> paths; private final Iterator<String> paths;
@ -42,16 +44,17 @@ public class UriFromPathIterator implements Iterator<URI> {
@Override @Override
public URI next() { public URI next() {
if (!hasNext()) { if (!hasNext()) {
throw new NoSuchElementException("No next element"); throw new NoSuchElementException("No next element");
} }
try { try {
String s = paths.next(); String s = paths.next();
if(!s.matches(".*:/.*")){ if(schemaPattern.matcher(s).matches()){
return new URI(s);
} else {
//No scheme - assume file for backward compatibility //No scheme - assume file for backward compatibility
return new File(s).toURI(); return new File(s).toURI();
} else {
return new URI(s);
} }
} catch (URISyntaxException e) { } catch (URISyntaxException e) {

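The UriFromPathIterator fix above does two things: it precompiles the scheme-detection regex into a Pattern field instead of calling String.matches() per element, and it corrects the branch order so that strings that already carry a scheme are parsed with new URI(s), while scheme-less strings fall back to File.toURI(). A compact restatement of that logic as a standalone helper (class and method names are illustrative):

import java.io.File;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.regex.Pattern;

class UriFromPath {
    // Compiled once; matches e.g. "file:/tmp/x" or "hdfs://host/x", but not "/tmp/x"
    private static final Pattern SCHEMA_PATTERN = Pattern.compile("^.*?:/.*");

    static URI toUri(String s) throws URISyntaxException {
        if (SCHEMA_PATTERN.matcher(s).matches()) {
            return new URI(s);          // already has a scheme
        } else {
            return new File(s).toURI(); // no scheme - assume file for backward compatibility
        }
    }
}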
View File

@ -162,7 +162,6 @@ public class Text extends BinaryComparable implements WritableComparable<BinaryC
return -1; // not found return -1; // not found
} catch (CharacterCodingException e) { } catch (CharacterCodingException e) {
// can't get here // can't get here
e.printStackTrace();
return -1; return -1;
} }
} }

View File

@ -17,6 +17,7 @@
package org.datavec.arrow.recordreader; package org.datavec.arrow.recordreader;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record; import org.datavec.api.records.Record;
@ -50,6 +51,7 @@ import static org.datavec.arrow.ArrowConverter.readFromBytes;
* @author Adam Gibson * @author Adam Gibson
* *
*/ */
@Slf4j
public class ArrowRecordReader implements RecordReader { public class ArrowRecordReader implements RecordReader {
private InputSplit split; private InputSplit split;
@ -132,7 +134,7 @@ public class ArrowRecordReader implements RecordReader {
currIdx++; currIdx++;
this.currentPath = url; this.currentPath = url;
}catch(Exception e) { }catch(Exception e) {
e.printStackTrace(); log.error("",e);
} }
} }
@ -242,7 +244,7 @@ public class ArrowRecordReader implements RecordReader {
try { try {
currentBatch.close(); currentBatch.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
} }

View File

@ -16,6 +16,7 @@
package org.datavec.arrow; package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.memory.RootAllocator;
@ -56,7 +57,7 @@ import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@Slf4j
public class ArrowConverterTest extends BaseND4JTest { public class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@ -343,7 +344,7 @@ public class ArrowConverterTest extends BaseND4JTest {
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) { try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch(); arrowFileWriter.writeBatch();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
byte[] arr = byteArrayOutputStream.toByteArray(); byte[] arr = byteArrayOutputStream.toByteArray();

View File

@ -61,7 +61,7 @@ public class Wave implements Serializable {
initWaveWithInputStream(inputStream); initWaveWithInputStream(inputStream);
inputStream.close(); inputStream.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); System.out.println(e.toString());
} }
} }
@ -96,7 +96,7 @@ public class Wave implements Serializable {
data = new byte[inputStream.available()]; data = new byte[inputStream.available()];
inputStream.read(data); inputStream.read(data);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); System.err.println(e.toString());
} }
// end load data // end load data
} else { } else {

View File

@ -16,10 +16,12 @@
package org.datavec.audio; package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
@Slf4j
public class WaveFileManager { public class WaveFileManager {
private Wave wave; private Wave wave;
@ -78,7 +80,7 @@ public class WaveFileManager {
fos.write(wave.getBytes()); fos.write(wave.getBytes());
fos.close(); fos.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }

View File

@ -16,6 +16,8 @@
package org.datavec.audio; package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -25,6 +27,7 @@ import java.io.InputStream;
* *
* @author Jacquet Wong * @author Jacquet Wong
*/ */
@Slf4j
public class WaveHeader { public class WaveHeader {
public static final String RIFF_HEADER = "RIFF"; public static final String RIFF_HEADER = "RIFF";
@ -109,7 +112,7 @@ public class WaveHeader {
// dis.close(); // dis.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
return false; return false;
} }

View File

@ -17,6 +17,7 @@
package org.datavec.audio.fingerprint; package org.datavec.audio.fingerprint;
import lombok.extern.slf4j.Slf4j;
import org.datavec.audio.Wave; import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader; import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler; import org.datavec.audio.dsp.Resampler;
@ -38,6 +39,7 @@ import java.util.List;
* @author jacquet * @author jacquet
* *
*/ */
@Slf4j
public class FingerprintManager { public class FingerprintManager {
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
@ -153,7 +155,7 @@ public class FingerprintManager {
fingerprint = getFingerprintFromInputStream(fis); fingerprint = getFingerprintFromInputStream(fis);
fis.close(); fis.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return fingerprint; return fingerprint;
} }
@ -170,7 +172,7 @@ public class FingerprintManager {
fingerprint = new byte[inputStream.available()]; fingerprint = new byte[inputStream.available()];
inputStream.read(fingerprint); inputStream.read(fingerprint);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return fingerprint; return fingerprint;
} }
@ -190,7 +192,7 @@ public class FingerprintManager {
fileOutputStream.write(fingerprint); fileOutputStream.write(fingerprint);
fileOutputStream.close(); fileOutputStream.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }

View File

@ -81,7 +81,7 @@ public abstract class BaseImageLoader implements Serializable {
String fileName = file.toString(); String fileName = file.toString();
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz") if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|| fileName.endsWith(".zip")) || fileName.endsWith(".zip"))
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath()); ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath(), false);
} catch (IOException e) { } catch (IOException e) {
throw new IllegalStateException("Unable to fetch images", e); throw new IllegalStateException("Unable to fetch images", e);
} }

View File

@ -186,7 +186,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
labels.add(line); labels.add(line);
} }
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
@ -300,7 +300,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
try { try {
FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd); FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
meanStdStored = true; meanStdStored = true;
} else if (uMean == 0 && meanStdStored) { } else if (uMean == 0 && meanStdStored) {
@ -312,7 +312,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
vStd = Double.parseDouble(values[3]); vStd = Double.parseDouble(values[3]);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
for (int i = 0; i < result.numExamples(); i++) { for (int i = 0; i < result.numExamples(); i++) {
@ -356,12 +356,12 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
dataSets.add(new DataSet(asMatrix(matConversion.getSecond()), matConversion.getFirst())); dataSets.add(new DataSet(asMatrix(matConversion.getSecond()), matConversion.getFirst()));
batchNumCount++; batchNumCount++;
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
break; break;
} }
} }
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
if(dataSets.size() == 0){ if(dataSets.size() == 0){

View File

@ -235,7 +235,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1]; InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data); recordReader.initialize(data);
} catch (IOException | InterruptedException e) { } catch (IOException | InterruptedException e) {
e.printStackTrace(); log.error("",e);
} }
return recordReader; return recordReader;
} }
@ -250,7 +250,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1]; InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data); recordReader.initialize(data);
} catch (IOException | InterruptedException e) { } catch (IOException | InterruptedException e) {
e.printStackTrace(); log.error("",e);
} }
return recordReader; return recordReader;
} }

View File

@ -225,7 +225,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
finishedInputStreamSplit = true; finishedInputStreamSplit = true;
return Arrays.<Writable>asList(ndArrayWritable); return Arrays.<Writable>asList(ndArrayWritable);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
if (iter != null) { if (iter != null) {

View File

@ -16,6 +16,7 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader; import org.bytedeco.javacpp.Loader;
@ -53,6 +54,7 @@ import static org.junit.Assert.fail;
* *
* @author saudet * @author saudet
*/ */
@Slf4j
public class TestNativeImageLoader { public class TestNativeImageLoader {
static final long seed = 10; static final long seed = 10;
static final Random rng = new Random(seed); static final Random rng = new Random(seed);
@ -123,7 +125,7 @@ public class TestNativeImageLoader {
try { try {
array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
assertEquals(5, array6.rank()); assertEquals(5, array6.rank());
@ -156,7 +158,7 @@ public class TestNativeImageLoader {
try { try {
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
assertEquals(5, array8.rank()); assertEquals(5, array8.rank());
assertEquals(pages2, array8.size(0)); assertEquals(pages2, array8.size(0));
@ -172,7 +174,7 @@ public class TestNativeImageLoader {
try { try {
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath()); array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath());
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
assertEquals(5, array9.rank()); assertEquals(5, array9.rank());

View File

@ -66,7 +66,7 @@ public class UimaTokenizer implements Tokenizer {
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -17,6 +17,7 @@
package org.datavec.local.transforms.functions; package org.datavec.local.transforms.functions;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.linalg.function.Function; import org.nd4j.linalg.function.Function;
@ -32,6 +33,7 @@ import java.util.List;
* sequence data into a {@code List<List<Writable>>} * sequence data into a {@code List<List<Writable>>}
* @author Alex Black * @author Alex Black
*/ */
@Slf4j
public class SequenceRecordReaderFunction public class SequenceRecordReaderFunction
implements Function<Pair<String, InputStream>, List<List<Writable>>> { implements Function<Pair<String, InputStream>, List<List<Writable>>> {
protected SequenceRecordReader sequenceRecordReader; protected SequenceRecordReader sequenceRecordReader;
@ -46,7 +48,7 @@ public class SequenceRecordReaderFunction
try (DataInputStream dis = (DataInputStream) value.getRight()) { try (DataInputStream dis = (DataInputStream) value.getRight()) {
return sequenceRecordReader.sequenceRecord(uri, dis); return sequenceRecordReader.sequenceRecord(uri, dis);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
throw new IllegalStateException("Something went wrong"); throw new IllegalStateException("Something went wrong");

View File

@ -20,6 +20,7 @@ package org.datavec.python;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.bytedeco.cpython.global.python;
import org.bytedeco.numpy.global.numpy; import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -343,6 +344,19 @@ public class PythonExecutioner {
if (path == null) { if (path == null) {
log.info("Setting python default path"); log.info("Setting python default path");
File[] packages = numpy.cachePackages(); File[] packages = numpy.cachePackages();
//// TODO: fix in javacpp
File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
File[] packages2 = new File[packages.length + 1];
for (int i = 0;i < packages.length; i++){
//System.out.println(packages[i].getAbsolutePath());
packages2[i] = packages[i];
}
packages2[packages.length] = sitePackagesWindows;
//System.out.println(sitePackagesWindows.getAbsolutePath());
packages = packages2;
//////////
Py_SetPath(packages); Py_SetPath(packages);
} else { } else {
log.info("Setting python path " + path); log.info("Setting python path " + path);

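The loop added above works around a JavaCPP path issue by copying numpy.cachePackages() into a one-larger array and appending the cpython cache's site-packages directory before calling Py_SetPath. The same append could be expressed with Arrays.copyOf; a small sketch under that assumption (the helper name is illustrative, not part of the merged code):

import java.io.File;
import java.util.Arrays;

class PackagePathHelper {
    /** Return a copy of 'packages' with one extra directory appended at the end. */
    static File[] append(File[] packages, File extra) {
        File[] out = Arrays.copyOf(packages, packages.length + 1);
        out[packages.length] = extra;
        return out;
    }
}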
View File

@ -0,0 +1,132 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
@Slf4j
public class PythonProcess {
private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
Process process = pb.start();
String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
process.waitFor();
return out;
}
public static void run(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
pb.inheritIO().start().waitFor();
}
public static void pipInstall(String packageName) throws PythonException{
try{
run("-m", "pip", "install", packageName);
}catch(Exception e){
throw new PythonException("Error installing package " + packageName, e);
}
}
public static void pipInstall(String packageName, String version) throws PythonException{
pipInstall(packageName + "==" + version);
}
public static void pipUninstall(String packageName) throws PythonException{
try{
run("-m", "pip", "uninstall", packageName);
}catch(Exception e){
throw new PythonException("Error uninstalling package " + packageName, e);
}
}
public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{
if (!gitRepoUrl.contains("://")){
gitRepoUrl = "git://" + gitRepoUrl;
}
try{
run("-m", "pip", "install", "git+", gitRepoUrl);
}catch(Exception e){
throw new PythonException("Error installing package from " + gitRepoUrl, e);
}
}
public static String getPackageVersion(String packageName) throws PythonException{
String out;
try{
out = runAndReturn("-m", "pip", "show", packageName);
} catch (Exception e){
throw new PythonException("Error finding version for package " + packageName, e);
}
if (!out.contains("Version: ")){
throw new PythonException("Can't find package " + packageName);
}
String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
return pkgVersion;
}
public static boolean isPackageInstalled(String packageName)throws PythonException{
try{
String out = runAndReturn("-m", "pip", "show", packageName);
return !out.isEmpty();
}catch (Exception e){
throw new PythonException("Error checking if package is installed: " +packageName, e);
}
}
public static void pipInstallFromRequirementsTxt(String path) throws PythonException{
try{
run("-m", "pip", "install","-r", path);
}catch (Exception e){
throw new PythonException("Error installing packages from " + path, e);
}
}
public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException{
try{
run(path, inplace?"develop":"install");
}catch (Exception e){
throw new PythonException("Error installing package from " + path, e);
}
}
}

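A possible usage sketch for the new PythonProcess helper, using only methods defined in the file above; the package name and output handling here are illustrative:

import org.datavec.python.PythonProcess;

public class PythonProcessExample {
    public static void main(String[] args) throws Exception {
        // Install a package if missing, then query its version via "pip show"
        if (!PythonProcess.isPackageInstalled("numpy")) {
            PythonProcess.pipInstall("numpy");
        }
        System.out.println("numpy version: " + PythonProcess.getPackageVersion("numpy"));

        // Run a one-liner with the bundled interpreter and capture stdout
        String out = PythonProcess.runAndReturn("-c", "print('hello from python')");
        System.out.println(out);
    }
}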
View File

@ -0,0 +1,144 @@
package org.datavec.python.keras;
import org.datavec.python.Python;
import org.datavec.python.PythonException;
import org.datavec.python.PythonObject;
import org.datavec.python.PythonProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
public class Model {
private PythonObject pyModel;
private static PythonObject installAndImportTF() throws PythonException{
if (!PythonProcess.isPackageInstalled("tensorflow")){
PythonProcess.pipInstall("tensorflow");
}
return Python.importModule("tensorflow");
}
private static PythonObject getKerasModule() throws PythonException{
PythonObject tf = installAndImportTF();
PythonObject keras = tf.attr("keras");
tf.del();
return keras;
}
private static PythonObject loadModel(String s) throws PythonException{
PythonObject models = getKerasModule().attr("models");
PythonObject loadModelF = models.attr("load_model");
PythonObject model = loadModelF.call(s);
models.del();
loadModelF.del();
return model;
}
public Model(String path) throws PythonException{
pyModel = loadModel(path);
}
public INDArray[] predict(INDArray... inputs) throws PythonException{
PythonObject predictF = pyModel.attr("predict");
PythonObject inputList = new PythonObject(inputs);
PythonObject pyOut = predictF.call(inputList);
INDArray[] out;
if (Python.isinstance(pyOut, Python.listType())){
out = new INDArray[Python.len(pyOut).toInt()];
for(int i = 0; i < out.length; i++){
out[i] = pyOut.get(i).toNumpy().getNd4jArray();
}
}
else{
out = new INDArray[]{
pyOut.toNumpy().getNd4jArray()};
}
predictF.del();
inputList.del();
pyOut.del();
return out;
}
public int numInputs(){
PythonObject inputs = pyModel.attr("inputs");
PythonObject pyNumInputs = Python.len(inputs);
int ret = pyNumInputs.toInt();
inputs.del();
pyNumInputs.del();
return ret;
}
public int numOutputs(){
PythonObject outputs = pyModel.attr("outputs");
PythonObject pyNumOutputs = Python.len(outputs);
int ret = pyNumOutputs.toInt();
outputs.del();
pyNumOutputs.del();
return ret;
}
public long[][] inputShapes(){
long[][] ret = new long[numInputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = inputShapeAt(i);
}
return ret;
}
public long[][] outputShapes(){
long[][] ret = new long[numOutputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = outputShapeAt(i);
}
return ret;
}
public long[] inputShapeAt(int input){
PythonObject inputs = pyModel.attr("inputs");
PythonObject tensor = inputs.get(input);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
inputs.del();
return shape;
}
public long[] outputShapeAt(int output){
PythonObject inputs = pyModel.attr("outputs");
PythonObject tensor = inputs.get(output);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
inputs.del();
return shape;
}
}

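A possible usage sketch for the Keras Model wrapper above. The model path is a placeholder, and Nd4j.create(DataType.FLOAT, ...) is used only to build a dummy input matching the reported input shape, with any unknown (-1) dimensions replaced by 1:

import org.datavec.python.keras.Model;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class KerasModelExample {
    public static void main(String[] args) throws Exception {
        Model model = new Model("my_model.h5"); // placeholder path to a saved tf.keras model

        System.out.println("inputs: " + model.numInputs() + ", outputs: " + model.numOutputs());

        // Build a zero-filled batch matching the first input shape; -1 means "unknown", use 1
        long[] shape = model.inputShapeAt(0);
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] < 0) shape[i] = 1;
        }
        INDArray[] out = model.predict(Nd4j.create(DataType.FLOAT, shape));
        System.out.println(out[0]);
    }
}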
View File

@ -74,7 +74,6 @@ public class DataVecTransformClient implements DataVecTransformService {
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in setCSVTransformProcess()", e); log.error("Error in setCSVTransformProcess()", e);
e.printStackTrace();
} }
} }
@ -94,7 +93,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return TransformProcess.fromJson(s); return TransformProcess.fromJson(s);
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in getCSVTransformProcess()",e); log.error("Error in getCSVTransformProcess()",e);
e.printStackTrace();
} }
return null; return null;
@ -119,7 +117,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord; return singleCsvRecord;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformIncremental(SingleCSVRecord)",e); log.error("Error in transformIncremental(SingleCSVRecord)",e);
e.printStackTrace();
} }
return null; return null;
} }
@ -140,8 +137,7 @@ public class DataVecTransformClient implements DataVecTransformService {
.getBody(); .getBody();
return batchCSVRecord1; return batchCSVRecord1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e); log.error("",e);
e.printStackTrace();
} }
return null; return null;
@ -162,7 +158,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1; return batchCSVRecord1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e); log.error("Error in transform(BatchCSVRecord)", e);
e.printStackTrace();
} }
return null; return null;
@ -181,7 +176,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1; return batchArray1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformArray(BatchCSVRecord)",e); log.error("Error in transformArray(BatchCSVRecord)",e);
e.printStackTrace();
} }
return null; return null;
@ -200,7 +194,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array; return array;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformArrayIncremental(SingleCSVRecord)",e); log.error("Error in transformArrayIncremental(SingleCSVRecord)",e);
e.printStackTrace();
} }
return null; return null;
@ -231,7 +224,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array; return array;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformSequenceArrayIncremental",e); log.error("Error in transformSequenceArrayIncremental",e);
e.printStackTrace();
} }
return null; return null;
@ -252,7 +244,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1; return batchArray1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformSequenceArray",e); log.error("Error in transformSequenceArray",e);
e.printStackTrace();
} }
return null; return null;
@ -274,7 +265,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1; return batchCSVRecord1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformSequence"); log.error("Error in transformSequence");
e.printStackTrace();
} }
return null; return null;
@ -295,7 +285,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord; return singleCsvRecord;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformSequenceIncremental"); log.error("Error in transformSequenceIncremental");
e.printStackTrace();
} }
return null; return null;
} }

View File

@ -18,6 +18,7 @@ package org.datavec.spark.transform;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.memory.RootAllocator;
@ -53,7 +54,7 @@ import static org.datavec.local.transforms.LocalTransformExecutor.executeToSeque
* @author Adan Gibson * @author Adan Gibson
*/ */
@AllArgsConstructor @AllArgsConstructor
@Slf4j
public class CSVSparkTransform { public class CSVSparkTransform {
@Getter @Getter
private TransformProcess transformProcess; private TransformProcess transformProcess;
@ -252,7 +253,7 @@ public class CSVSparkTransform {
try { try {
return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return null; return null;

View File

@ -88,7 +88,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized"); log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType); return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -100,7 +100,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized"); log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -112,7 +112,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -130,7 +130,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -142,7 +142,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -169,7 +169,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });

View File

@ -16,6 +16,7 @@
package org.datavec.spark; package org.datavec.spark;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf; import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.junit.After; import org.junit.After;
@ -23,6 +24,7 @@ import org.junit.Before;
import java.io.Serializable; import java.io.Serializable;
@Slf4j
public abstract class BaseSparkTest implements Serializable { public abstract class BaseSparkTest implements Serializable {
protected static JavaSparkContext sc; protected static JavaSparkContext sc;
@ -40,7 +42,7 @@ public abstract class BaseSparkTest implements Serializable {
try { try {
Thread.sleep(100L); Thread.sleep(100L);
} catch (InterruptedException e) { } catch (InterruptedException e) {
e.printStackTrace(); log.error("",e);
} }
} else { } else {
break; break;

View File

@ -68,7 +68,7 @@ public abstract class BaseDL4JTest {
* Override this method to set the default timeout for methods in the test class * Override this method to set the default timeout for methods in the test class
*/ */
public long getTimeoutMilliseconds(){ public long getTimeoutMilliseconds(){
return 30000; return 60_000;
} }
/** /**

View File

@ -43,7 +43,8 @@ import java.util.concurrent.atomic.AtomicLong;
*/ */
@Slf4j @Slf4j
public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable { public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable {
private static final String ROUTE_IS_DOWN = "Info posted to RemoteUIStatsStorageRouter but router is shut down.";
private static final String MAX_WARNINGS_REACHED = "RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.";
/** /**
* Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar * Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar
*/ */
@ -163,10 +164,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) { if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement(); long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) { if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down."); log.warn(ROUTE_IS_DOWN);
} }
if (count == MAX_SHUTDOWN_WARN_COUNT) { if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced."); log.warn(MAX_WARNINGS_REACHED);
} }
} else { } else {
for (StorageMetaData m : storageMetaData) { for (StorageMetaData m : storageMetaData) {
@ -186,10 +187,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) { if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement(); long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) { if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down."); log.warn(ROUTE_IS_DOWN);
} }
if (count == MAX_SHUTDOWN_WARN_COUNT) { if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced."); log.warn(MAX_WARNINGS_REACHED);
} }
} else { } else {
for (Persistable p : staticInfo) { for (Persistable p : staticInfo) {
@ -209,10 +210,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) { if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement(); long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) { if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down."); log.warn(ROUTE_IS_DOWN);
} }
if (count == MAX_SHUTDOWN_WARN_COUNT) { if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced."); log.warn(MAX_WARNINGS_REACHED);
} }
} else { } else {
for (Persistable p : updates) { for (Persistable p : updates) {

View File

@ -67,10 +67,7 @@ public class AsyncIterator<T extends Object> implements Iterator<T> {
nextElement = buffer.take(); nextElement = buffer.take();
// same on this run // same on this run
if (nextElement == terminator) return (nextElement != terminator);
return false;
return true;
} catch (Exception e) { } catch (Exception e) {
log.error("Premature end of loop!"); log.error("Premature end of loop!");
return false; return false;

View File

@ -43,7 +43,7 @@ public class SystemInfoPrintListener implements TrainingListener {
private boolean printOnBackwardPass; private boolean printOnBackwardPass;
private boolean printOnGradientCalculation; private boolean printOnGradientCalculation;
private static final String SYSTEM_INFO = "System info on epoch end: ";
@Override @Override
public void iterationDone(Model model, int iteration, int epoch) { public void iterationDone(Model model, int iteration, int epoch) {
@ -65,7 +65,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
@ -75,7 +75,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
@ -85,7 +85,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
@ -95,7 +95,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
@ -104,7 +104,7 @@ public class SystemInfoPrintListener implements TrainingListener {
if(!printOnBackwardPass) if(!printOnBackwardPass)
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.perf.listener; package org.deeplearning4j.perf.listener;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import oshi.json.SystemInfo; import oshi.json.SystemInfo;
@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j
public class SystemPolling { public class SystemPolling {
private ScheduledExecutorService scheduledExecutorService; private ScheduledExecutorService scheduledExecutorService;
@ -66,7 +68,7 @@ public class SystemPolling {
try { try {
objectMapper.writeValue(hardwareFile,hardwareMetric); objectMapper.writeValue(hardwareFile,hardwareMetric);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
},0,pollEveryMillis, TimeUnit.MILLISECONDS); },0,pollEveryMillis, TimeUnit.MILLISECONDS);

View File

@ -27,7 +27,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import java.io.*; import java.io.*;
import java.nio.file.Files;
import java.util.UUID; import java.util.UUID;
/** /**

View File

@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
@ -153,11 +154,22 @@ public class TestUtils {
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed)); return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
} }
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){ public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) {
INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f'); return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng);
}
public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){
boolean ncw = format == RNNFormat.NCW;
long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize};
char order = ncw ? 'f' : 'c';
INDArray out = Nd4j.create(DataType.FLOAT, shape, order);
for( int i=0; i<minibatch; i++ ){ for( int i=0; i<minibatch; i++ ){
for( int j=0; j<tsLength; j++ ){ for( int j=0; j<tsLength; j++ ){
out.putScalar(i, rng.nextInt(outSize), j, 1.0); if(ncw){
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
} else {
out.putScalar(i, j, rng.nextInt(outSize), 1.0);
}
} }
} }
return out; return out;

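The TestUtils helper above now takes an RNNFormat, so tests can generate one-hot sequences in either layout: NCW gives shape [minibatch, size, timeSteps], NWC gives [minibatch, timeSteps, size]. A short usage sketch, assuming the modified TestUtils is visible from the calling test (NWC being the other value of the RNNFormat enum):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.Random;

class RandomOneHotShapes {
    static void demo() {
        Random rng = new Random(12345);
        INDArray ncw = TestUtils.randomOneHotTimeSeries(RNNFormat.NCW, 2, 5, 10, rng); // shape [2, 5, 10]
        INDArray nwc = TestUtils.randomOneHotTimeSeries(RNNFormat.NWC, 2, 5, 10, rng); // shape [2, 10, 5]
        System.out.println(Arrays.toString(ncw.shape()) + " vs " + Arrays.toString(nwc.shape()));
    }
}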
View File

@ -493,7 +493,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
out.println(); out.println();
} }
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return new Pair<>(dArr,temp); return new Pair<>(dArr,temp);

View File

@ -78,10 +78,10 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10); earlyEndIter.next(10);
assertEquals(earlyEndIter.hasNext(), false); assertEquals(false, earlyEndIter.hasNext());
earlyEndIter.reset(); earlyEndIter.reset();
assertEquals(earlyEndIter.hasNext(), true); assertEquals(true, earlyEndIter.hasNext());
} }

View File

@ -98,7 +98,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
while (multiIter.hasNext()) { while (multiIter.hasNext()) {
DataSet path = multiIter.next(10); DataSet path = multiIter.next(10);
assertNotNull(path); assertNotNull(path);
assertEquals(path.numExamples(), 10, 0.0); assertEquals(10, path.numExamples(), 0.0);
} }
assertEquals(epochs, multiIter.epochs); assertEquals(epochs, multiIter.epochs);

View File

@ -33,7 +33,7 @@ public class SamplingTest extends BaseDL4JTest {
DataSetIterator iter = new MnistDataSetIterator(10, 10); DataSetIterator iter = new MnistDataSetIterator(10, 10);
//batch size and total //batch size and total
DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10); DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
assertEquals(sampling.next().numExamples(), 10); assertEquals(10, sampling.next().numExamples());
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions; package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
@ -34,6 +35,7 @@ import static org.junit.Assert.fail;
/** /**
* A set of tests to ensure that useful exceptions are thrown on invalid network configurations * A set of tests to ensure that useful exceptions are thrown on invalid network configurations
*/ */
@Slf4j
public class TestInvalidConfigurations extends BaseDL4JTest { public class TestInvalidConfigurations extends BaseDL4JTest {
public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) { public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) {
@ -78,7 +80,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testDenseNin0(): " + e.getMessage()); System.out.println("testDenseNin0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -96,7 +98,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testDenseNout0(): " + e.getMessage()); System.out.println("testDenseNout0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -109,7 +111,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testOutputLayerNin0(): " + e.getMessage()); System.out.println("testOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -122,7 +124,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testRnnOutputLayerNin0(): " + e.getMessage()); System.out.println("testRnnOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -135,7 +137,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testLSTMNIn0(): " + e.getMessage()); System.out.println("testLSTMNIn0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -153,7 +155,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testLSTMNOut0(): " + e.getMessage()); System.out.println("testLSTMNOut0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -166,7 +168,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testConvolutionalNin0(): " + e.getMessage()); System.out.println("testConvolutionalNin0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -185,7 +187,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testConvolutionalNOut0(): " + e.getMessage()); System.out.println("testConvolutionalNOut0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -216,7 +218,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage()); System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -245,7 +247,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage()); System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -277,7 +279,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage()); System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -318,7 +320,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage()); System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -347,7 +349,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage()); System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions; package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -36,6 +37,7 @@ import static org.junit.Assert.*;
/** /**
* A set of tests to ensure that useful exceptions are thrown on invalid input * A set of tests to ensure that useful exceptions are thrown on invalid input
*/ */
@Slf4j
public class TestInvalidInput extends BaseDL4JTest { public class TestInvalidInput extends BaseDL4JTest {
@Test @Test
@ -53,7 +55,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchDense(): " + e.getMessage()); System.out.println("testInputNinMismatchDense(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -73,7 +75,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage()); System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -94,7 +96,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function //From loss function
System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage()); System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -115,7 +117,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function //From loss function
System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage()); System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -142,7 +144,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage()); System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -169,7 +171,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinRank2Convolutional(): " + e.getMessage()); System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -195,7 +197,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinRank2Subsampling(): " + e.getMessage()); System.out.println("testInputNinRank2Subsampling(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -217,7 +219,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchLSTM(): " + e.getMessage()); System.out.println("testInputNinMismatchLSTM(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -238,7 +240,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage()); System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
@ -260,7 +262,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage()); System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -305,7 +307,7 @@ public class TestInvalidInput extends BaseDL4JTest {
net.rnnTimeStep(Nd4j.create(5, 5, 10)); net.rnnTimeStep(Nd4j.create(5, 5, 10));
fail("Expected Exception - " + layerType); fail("Expected Exception - " + layerType);
} catch (Exception e) { } catch (Exception e) {
// e.printStackTrace(); log.error("",e);
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch")); assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
} }


@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -34,6 +35,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -51,6 +54,7 @@ import static org.junit.Assert.*;
/** /**
* Created by nyghtowl on 9/1/15. * Created by nyghtowl on 9/1/15.
*/ */
@RunWith(Parameterized.class)
public class CNNGradientCheckTest extends BaseDL4JTest { public class CNNGradientCheckTest extends BaseDL4JTest {
private static final boolean PRINT_RESULTS = true; private static final boolean PRINT_RESULTS = true;
private static final boolean RETURN_ON_FIRST_FAILURE = false; private static final boolean RETURN_ON_FIRST_FAILURE = false;
@ -62,6 +66,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
private CNN2DFormat format;
public CNNGradientCheckTest(CNN2DFormat format){
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return CNN2DFormat.values();
}
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 90000L; return 90000L;
@ -69,6 +84,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
public void testGradientCNNMLN() { public void testGradientCNNMLN() {
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
return;
//Parameterized test, testing combinations of: //Parameterized test, testing combinations of:
// (a) activation function // (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
@ -144,6 +162,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
public void testGradientCNNL1L2MLN() { public void testGradientCNNL1L2MLN() {
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
return;
//Parameterized test, testing combinations of: //Parameterized test, testing combinations of:
// (a) activation function // (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
@ -311,10 +332,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
boolean nchw = format == CNN2DFormat.NCHW;
for (String afn : activations) { for (String afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) { for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut); INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut);
for (int i = 0; i < 4 * minibatchSize; i++) { for (int i = 0; i < 4 * minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0); labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -330,13 +353,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX) .activation(Activation.SOFTMAX)
.nOut(nOut).build()) .nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)) .setInputType(InputType.convolutional(height, width, inputDepth, format))
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn; + afn;
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
@ -377,8 +400,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] padding = {0, 0}; int[] padding = {0, 0};
int size = 2; int size = 2;
boolean nchw = format == CNN2DFormat.NCHW;
for (int minibatchSize : minibatchSizes) { for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = MultiLayerConfiguration conf =
@ -393,8 +419,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(8 * 8 * 3) .activation(Activation.SOFTMAX).nIn(8 * 8 * 3)
.nOut(4).build()) .nOut(4).build())
.setInputType(InputType.convolutionalFlat(height, width, .setInputType(InputType.convolutional(height, width, inputDepth, format))
inputDepth))
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -438,10 +463,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
boolean nchw = format == CNN2DFormat.NCHW;
for (Activation afn : activations) { for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) { for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut); INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) { for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0); labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -461,14 +489,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3 * 3 * 3) .activation(Activation.SOFTMAX).nIn(3 * 3 * 3)
.nOut(4).build()) .nOut(4).build())
.setInputType(InputType.convolutionalFlat(height, width, .setInputType(InputType.convolutional(height, width, inputDepth, format))
inputDepth))
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn; + afn;
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
@ -508,10 +535,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
boolean nchw = format == CNN2DFormat.NCHW;
for (Activation afn : activations) { for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) { for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut); INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) { for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0); labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -533,8 +563,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2) .activation(Activation.SOFTMAX).nIn(2 * 2 * 2)
.nOut(4).build()) .nOut(4).build())
.setInputType(InputType.convolutionalFlat(height, width, .setInputType(InputType.convolutional(height, width, inputDepth, format))
inputDepth))
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -558,8 +587,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
public void testCnnLocallyConnected2D() { public void testCnnLocallyConnected2D() {
int nOut = 3; int nOut = 3;
int[] minibatchSizes = {2};
int width = 5; int width = 5;
int height = 5; int height = 5;
@ -569,11 +596,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS}; Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS};
int[] minibatch = {2, 1, 3}; int[] minibatch = {2, 1, 3};
boolean nchw = format == CNN2DFormat.NCHW;
for( int i=0; i<inputDepths.length; i++ ){ for( int i=0; i<inputDepths.length; i++ ){
int inputDepth = inputDepths[i]; int inputDepth = inputDepths[i];
Activation afn = activations[i]; Activation afn = activations[i];
int minibatchSize = minibatch[i]; int minibatchSize = minibatch[i];
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()) MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
@ -590,7 +621,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut) .activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
.build()) .build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); .setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
assertEquals(ConvolutionMode.Truncate, assertEquals(ConvolutionMode.Truncate,
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode()); ((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
@ -626,11 +657,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for (int inputDepth : inputDepths) { for (int inputDepth : inputDepths) {
for (Activation afn : activations) { for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) { for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut); INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) { for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0); labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -649,7 +684,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut) .activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
.build()) .build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); .setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
assertEquals(ConvolutionMode.Truncate, assertEquals(ConvolutionMode.Truncate,
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode()); ((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
@ -691,13 +726,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for( int i=0; i<minibatchSizes.length; i++ ){ for( int i=0; i<minibatchSizes.length; i++ ){
int inputDepth = inputDepths[i]; int inputDepth = inputDepths[i];
int minibatchSize = minibatchSizes[i]; int minibatchSize = minibatchSizes[i];
int height = heights[i]; int height = heights[i];
int k = kernelSizes[i]; int k = kernelSizes[i];
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
@ -713,7 +752,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.stride(1, 1).padding(0, 0).build()) .stride(1, 1).padding(0, 0).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build()) .activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); .setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
@ -748,13 +787,16 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for (int inputDepth : inputDepths) { for (int inputDepth : inputDepths) {
for (int minibatchSize : minibatchSizes) { for (int minibatchSize : minibatchSizes) {
for (int stride : strides) { for (int stride : strides) {
for (int k : kernelSizes) { for (int k : kernelSizes) {
for (boolean convFirst : new boolean[]{true, false}) { for (boolean convFirst : new boolean[]{true, false}) {
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
INDArray labels = Nd4j.zeros(minibatchSize, nOut); INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) { for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0); labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -775,7 +817,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(1, convFirst ? poolLayer : convLayer) .layer(1, convFirst ? poolLayer : convLayer)
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build()) .activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)) .setInputType(InputType.convolutional(height, width, inputDepth, format))
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -822,11 +864,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] inputDepths = {1, 3, 2}; int[] inputDepths = {1, 3, 2};
int[][] zeroPadLayer = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}}; int[][] zeroPadLayer = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}};
boolean nchw = format == CNN2DFormat.NCHW;
for( int i=0; i<minibatchSizes.length; i++ ){ for( int i=0; i<minibatchSizes.length; i++ ){
int minibatchSize = minibatchSizes[i]; int minibatchSize = minibatchSizes[i];
int inputDepth = inputDepths[i]; int inputDepth = inputDepths[i];
int[] zeroPad = zeroPadLayer[i]; int[] zeroPad = zeroPadLayer[i];
INDArray input = Nd4j.rand(DataType.DOUBLE, new int[]{minibatchSize, inputDepth, height, width});
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = MultiLayerConfiguration conf =
@ -840,7 +886,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
padding).nIn(3).nOut(3).build())//output: (6-2+0)/1+1 = 5 padding).nIn(3).nOut(3).build())//output: (6-2+0)/1+1 = 5
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(4).build()) .activation(Activation.SOFTMAX).nOut(4).build())
.setInputType(InputType.convolutional(height, width, inputDepth)) .setInputType(InputType.convolutional(height, width, inputDepth, format))
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -849,8 +895,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
//Check zero padding activation shape //Check zero padding activation shape
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer zpl = org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer zpl =
(org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer) net.getLayer(1); (org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer) net.getLayer(1);
val expShape = new long[]{minibatchSize, inputDepth, height + zeroPad[0] + zeroPad[1], long[] expShape;
width + zeroPad[2] + zeroPad[3]}; if(nchw){
expShape = new long[]{minibatchSize, inputDepth, height + zeroPad[0] + zeroPad[1],
width + zeroPad[2] + zeroPad[3]};
} else {
expShape = new long[]{minibatchSize, height + zeroPad[0] + zeroPad[1],
width + zeroPad[2] + zeroPad[3], inputDepth};
}
INDArray out = zpl.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); INDArray out = zpl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(expShape, out.shape()); assertArrayEquals(expShape, out.shape());
@ -888,6 +940,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for (int i = 0; i < minibatchSizes.length; i++) { for (int i = 0; i < minibatchSizes.length; i++) {
int minibatchSize = minibatchSizes[i]; int minibatchSize = minibatchSizes[i];
int k = kernelSizes[i]; int k = kernelSizes[i];
@ -900,7 +954,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int w = d * width; int w = d * width;
int h = d * height; int h = d * height;
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut); INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int j = 0; j < minibatchSize; j++) { for (int j = 0; j < minibatchSize; j++) {
labels.putScalar(new int[]{j, j % nOut}, 1.0); labels.putScalar(new int[]{j, j % nOut}, 1.0);
@ -920,7 +977,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build()) .activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build(); .setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
@ -945,8 +1002,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
public void testSeparableConv2D() { public void testSeparableConv2D() {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = new int[]{1, 3};
int width = 6; int width = 6;
int height = 6; int height = 6;
int inputDepth = 3; int inputDepth = 3;
@ -959,6 +1014,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionMode[] cms = new ConvolutionMode[]{Truncate, Truncate, Truncate, Truncate, Truncate}; ConvolutionMode[] cms = new ConvolutionMode[]{Truncate, Truncate, Truncate, Truncate, Truncate};
int[] mb = new int[]{1, 1, 1, 3, 3}; int[] mb = new int[]{1, 1, 1, 3, 3};
boolean nchw = format == CNN2DFormat.NCHW;
for (int t = 0; t < ks.length; t++) { for (int t = 0; t < ks.length; t++) {
int k = ks[t]; int k = ks[t];
@ -971,7 +1028,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int w = d * width; int w = d * width;
int h = d * height; int h = d * height;
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut); INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) { for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0); labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -992,7 +1050,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build()) .activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build(); .setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
@ -1017,7 +1075,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
public void testCnnDilated() { public void testCnnDilated() {
int nOut = 2; int nOut = 2;
int minibatchSize = 2; int minibatchSize = 2;
int width = 8; int width = 8;
int height = 8; int height = 8;
@ -1031,9 +1088,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] ds = new int[]{2, 2, 3, 3, 2}; int[] ds = new int[]{2, 2, 3, 3, 2};
ConvolutionMode[] cms = new ConvolutionMode[]{Same, Truncate, Truncate, Same, Truncate}; ConvolutionMode[] cms = new ConvolutionMode[]{Same, Truncate, Truncate, Same, Truncate};
boolean nchw = format == CNN2DFormat.NCHW;
for (int t = 0; t < sub.length; t++) { for (int t = 0; t < sub.length; t++) {
boolean subsampling = sub[t]; boolean subsampling = sub[t];
int s = stride[t]; int s = stride[t];
int k = kernel[t]; int k = kernel[t];
@ -1044,7 +1101,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int w = d * width; int w = d * width;
int h = d * height; int h = d * height;
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut); INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) { for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0); labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -1076,7 +1134,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build()) .activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build(); .setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
@ -1114,11 +1172,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] inputDepths = {1, 2, 3, 2}; int[] inputDepths = {1, 2, 3, 2};
int[] minibatchSizes = {2, 1, 3, 2}; int[] minibatchSizes = {2, 1, 3, 2};
boolean nchw = format == CNN2DFormat.NCHW;
for (int i = 0; i < cropTestCases.length; i++) { for (int i = 0; i < cropTestCases.length; i++) {
int inputDepth = inputDepths[i]; int inputDepth = inputDepths[i];
int minibatchSize = minibatchSizes[i]; int minibatchSize = minibatchSizes[i];
int[] crop = cropTestCases[i]; int[] crop = cropTestCases[i];
INDArray input = Nd4j.rand(new int[]{minibatchSize, inputDepth, height, width}); long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = MultiLayerConfiguration conf =
@ -1134,7 +1195,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build()) .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build()) .activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(height, width, inputDepth)) .setInputType(InputType.convolutional(height, width, inputDepth, format))
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -1143,12 +1204,18 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
//Check cropping activation shape //Check cropping activation shape
org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl = org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl =
(org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1); (org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1);
val expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1], long[] expShape;
width - crop[2] - crop[3]}; if(nchw){
expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
width - crop[2] - crop[3]};
} else {
expShape = new long[]{minibatchSize, height - crop[0] - crop[1],
width - crop[2] - crop[3], inputDepth};
}
INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(expShape, out.shape()); assertArrayEquals(expShape, out.shape());
String msg = "minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = " String msg = format + " - minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
+ Arrays.toString(crop); + Arrays.toString(crop);
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
@ -1181,6 +1248,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Truncate, Truncate, Truncate, Truncate, Truncate}; Truncate, Truncate, Truncate, Truncate, Truncate};
int[] mb = new int[]{1,1,1,3,3}; int[] mb = new int[]{1,1,1,3,3};
boolean nchw = format == CNN2DFormat.NCHW;
for( int t=0; t<ks.length; t++ ){ for( int t=0; t<ks.length; t++ ){
int k = ks[t]; int k = ks[t];
@ -1188,8 +1257,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionMode cm = cms[t]; ConvolutionMode cm = cms[t];
int minibatchSize = mb[t]; int minibatchSize = mb[t];
long[] inShape = nchw ? new long[]{minibatchSize, nIn, height, width} : new long[]{minibatchSize, height, width, nIn};
INDArray input = Nd4j.rand(minibatchSize, width * height * nIn); INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut); INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) { for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0); labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -1211,7 +1280,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build()) .activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, nIn)).build(); .setInputType(InputType.convolutional(height, width, nIn, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
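The recurring edit in the hunks above parameterizes the gradient checks over CNN2DFormat and builds the input array in the matching layout. A compact sketch of that pattern, with illustrative dimensions not taken from any particular test above:

    CNN2DFormat format = CNN2DFormat.NHWC;    // supplied per run by the Parameterized runner
    boolean nchw = format == CNN2DFormat.NCHW;
    int minibatchSize = 2, inputDepth = 3, height = 8, width = 8;

    // NCHW = [batch, channels, height, width]; NHWC = [batch, height, width, channels]
    long[] inShape = nchw
            ? new long[]{minibatchSize, inputDepth, height, width}
            : new long[]{minibatchSize, height, width, inputDepth};
    INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);

    // The configuration is told which layout to expect via:
    //     .setInputType(InputType.convolutional(height, width, inputDepth, format))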


@ -18,6 +18,7 @@ package org.deeplearning4j.gradientcheck;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -115,55 +116,57 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
//Basic test of global pooling w/ CNN //Basic test of global pooling w/ CNN
Nd4j.getRandom().setSeed(12345L); Nd4j.getRandom().setSeed(12345L);
int inputDepth = 3; for(boolean nchw : new boolean[]{true, false}) {
int inputH = 5;
int inputW = 4;
int layerDepth = 4;
int nOut = 2;
int[] minibatchSizes = new int[] {1, 3}; int inputDepth = 3;
PoolingType[] poolingTypes = int inputH = 5;
new PoolingType[] {PoolingType.AVG, PoolingType.SUM, PoolingType.MAX, PoolingType.PNORM}; int inputW = 4;
int layerDepth = 4;
int nOut = 2;
for (int miniBatchSize : minibatchSizes) { int[] minibatchSizes = new int[]{1, 3};
for (PoolingType pt : poolingTypes) { PoolingType[] poolingTypes =
new PoolingType[]{PoolingType.AVG, PoolingType.SUM, PoolingType.MAX, PoolingType.PNORM};
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() for (int miniBatchSize : minibatchSizes) {
.dataType(DataType.DOUBLE) for (PoolingType pt : poolingTypes) {
.updater(new NoOp())
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(layerDepth)
.build())
.layer(1, new GlobalPoolingLayer.Builder().poolingType(pt).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(inputH, inputW, inputDepth)).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(layerDepth)
.build())
.layer(1, new GlobalPoolingLayer.Builder().poolingType(pt).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(inputH, inputW, inputDepth, nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf); MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init(); mln.init();
Random r = new Random(12345L); Random r = new Random(12345L);
INDArray input = Nd4j.rand(new int[] {miniBatchSize, inputDepth, inputH, inputW}).subi(0.5); long[] inShape = nchw ? new long[]{miniBatchSize, inputDepth, inputH, inputW} : new long[]{miniBatchSize, inputH, inputW, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape).subi(0.5);
INDArray labels = Nd4j.zeros(miniBatchSize, nOut); INDArray labels = Nd4j.zeros(miniBatchSize, nOut);
for (int i = 0; i < miniBatchSize; i++) { for (int i = 0; i < miniBatchSize; i++) {
int idx = r.nextInt(nOut); int idx = r.nextInt(nOut);
labels.putScalar(i, idx, 1.0); labels.putScalar(i, idx, 1.0);
} }
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
System.out.println( System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize + " - " + (nchw ? "NCHW" : "NHWC"));
"testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize);
// for (int j = 0; j < mln.getnLayers(); j++) // for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
} }
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
} }
} }
} }


@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
@ -560,75 +561,81 @@ public class GradientCheckTests extends BaseDL4JTest {
public void testEmbeddingSequenceLayer(){ public void testEmbeddingSequenceLayer(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for(boolean maskArray : new boolean[]{false, true}){ for(RNNFormat seqOutputFormat : RNNFormat.values()) {
for(int inputRank : new int[]{2,3}) { for (boolean maskArray : new boolean[]{false, true}) {
for (int inputRank : new int[]{2, 3}) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE) .dataType(DataType.DOUBLE)
.seed(12345) .seed(12345)
.updater(new NoOp()) .updater(new NoOp())
.weightInit(new NormalDistribution(0, 1)) .weightInit(new NormalDistribution(0, 1))
.list() .list()
.layer(new EmbeddingSequenceLayer.Builder() .layer(new EmbeddingSequenceLayer.Builder()
.nIn(8) .nIn(8)
.nOut(4) .nOut(4)
.build()) .outputDataFormat(seqOutputFormat)
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH) .build())
.lossFunction(LossFunction.MSE).build()) .layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
.build(); .dataFormat(seqOutputFormat)
.lossFunction(LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive boolean ncw = seqOutputFormat == RNNFormat.NCW;
INDArray label = Nd4j.rand(new int[]{3, 3, 6});
if(inputRank == 3){ INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
//Reshape from [3,6] to [3,1,6] INDArray label = Nd4j.rand(DataType.FLOAT, ncw ? new int[]{3, 3, 6} : new int[]{3,6,3});
in = in.reshape('c', 3, 1, 6);
}
INDArray fMask = null; if (inputRank == 3) {
if (maskArray) { //Reshape from [3,6] to [3,1,6]
fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1}, in = in.reshape('c', 3, 1, 6);
{1, 1, 0, 0, 0, 0},
{1, 0, 0, 0, 0, 0}});
}
String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(label).inputMask(fMask));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
//Also: if mask is present, double check that the masked steps don't impact score
if (maskArray) {
DataSet ds = new DataSet(in, label, fMask, null);
double score = net.score(ds);
if(inputRank == 2){
in.putScalar(1, 2, 0);
in.putScalar(2, 1, 0);
in.putScalar(2, 2, 0);
} else {
in.putScalar(1, 0, 2, 0);
in.putScalar(2, 0, 1, 0);
in.putScalar(2, 0, 2, 0);
} }
double score2 = net.score(ds);
assertEquals(score, score2, 1e-6); INDArray fMask = null;
if(inputRank == 2){ if (maskArray) {
in.putScalar(1, 2, 1); fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
in.putScalar(2, 1, 1); {1, 1, 0, 0, 0, 0},
in.putScalar(2, 2, 1); {1, 0, 0, 0, 0, 0}});
} else {
in.putScalar(1, 0, 2, 1); }
in.putScalar(2, 0, 1, 1);
in.putScalar(2, 0, 2, 1); String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(label).inputMask(fMask));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
//Also: if mask is present, double check that the masked steps don't impact score
if (maskArray) {
DataSet ds = new DataSet(in, label, fMask, null);
double score = net.score(ds);
if (inputRank == 2) {
in.putScalar(1, 2, 0);
in.putScalar(2, 1, 0);
in.putScalar(2, 2, 0);
} else {
in.putScalar(1, 0, 2, 0);
in.putScalar(2, 0, 1, 0);
in.putScalar(2, 0, 2, 0);
}
double score2 = net.score(ds);
assertEquals(score, score2, 1e-6);
if (inputRank == 2) {
in.putScalar(1, 2, 1);
in.putScalar(2, 1, 1);
in.putScalar(2, 2, 1);
} else {
in.putScalar(1, 0, 2, 1);
in.putScalar(2, 0, 1, 1);
in.putScalar(2, 0, 2, 1);
}
double score3 = net.score(ds);
assertEquals(score, score3, 1e-6);
} }
double score3 = net.score(ds);
assertEquals(score, score3, 1e-6);
} }
} }
} }
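The recurrent-layer tests receive the analogous treatment over RNNFormat, where NCW stores sequences as [batch, features, time] and NWC as [batch, time, features]. A small sketch of the two layouts (dimensions chosen only for illustration):

    RNNFormat format = RNNFormat.NWC;         // or RNNFormat.NCW
    boolean ncw = format == RNNFormat.NCW;
    int minibatch = 3, featureSize = 4, seqLength = 6;

    // NCW = [batch, features, time]; NWC = [batch, time, features]
    long[] labelShape = ncw
            ? new long[]{minibatch, featureSize, seqLength}
            : new long[]{minibatch, seqLength, featureSize};
    INDArray labels = Nd4j.rand(DataType.FLOAT, labelShape);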


@ -21,9 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution; import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
@ -341,104 +339,112 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
@Test @Test
public void testCnnDepthMerge() { public void testCnnDepthMerge() {
Nd4j.getRandom().setSeed(12345); for(CNN2DFormat format : CNN2DFormat.values()) {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new NormalDistribution(0, 0.1))
.updater(new NoOp()).graphBuilder().addInputs("input")
.addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addVertex("merge", new MergeVertex(), "l1", "l2")
.addLayer("outputLayer",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
.build(),
"merge")
.setOutputs("outputLayer")
.inputPreProcessor("outputLayer", new CnnToFeedForwardPreProcessor(5, 5, 4))
.build();
ComputationGraph graph = new ComputationGraph(conf); String msg = "testCnnDepthMerge - " + format;
graph.init();
Random r = new Random(12345); Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(new int[] {5, 2, 6, 6}); //Order: examples, channels, height, width ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
INDArray labels = Nd4j.zeros(5, 3); .dataType(DataType.DOUBLE)
for (int i = 0; i < 5; i++) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
labels.putScalar(new int[] {i, r.nextInt(3)}, 1.0); .dist(new NormalDistribution(0, 0.1))
.updater(new NoOp()).graphBuilder().addInputs("input")
.addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addVertex("merge", new MergeVertex(), "l1", "l2")
.addLayer("outputLayer",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
.build(),
"merge")
.setOutputs("outputLayer")
.setInputTypes(InputType.convolutional(6, 6, 2, format))
.build();
if (PRINT_RESULTS) { ComputationGraph graph = new ComputationGraph(conf);
System.out.println("testCnnDepthMerge()"); graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, format == CNN2DFormat.NCHW ? new long[]{5,2,6,6} : new long[]{5,6,6,2});
INDArray labels = Nd4j.zeros(5, 3);
for (int i = 0; i < 5; i++)
labels.putScalar(new int[]{i, r.nextInt(3)}, 1.0);
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++) // for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); // System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
} }
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testCnnDepthMerge()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
} }
@Test @Test
public void testRNNWithMerging() { public void testRNNWithMerging() {
Nd4j.getRandom().setSeed(12345); for(RNNFormat format : RNNFormat.values()) {
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new UniformDistribution(0.2, 0.6))
.updater(new NoOp()).graphBuilder().addInputs("input")
.setOutputs("out")
.addLayer("lstm1",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"input")
.addLayer("lstm2",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"lstm1")
.addLayer("dense1",
new DenseLayer.Builder().nIn(3).nOut(3)
.activation(Activation.SIGMOID).build(),
"lstm1")
.addLayer("lstm3",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"dense1")
.addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
.addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
"merge")
.inputPreProcessor("dense1", new RnnToFeedForwardPreProcessor())
.inputPreProcessor("lstm3", new FeedForwardToRnnPreProcessor())
.build();
ComputationGraph graph = new ComputationGraph(conf); String msg = "testLSTMWithMerging - " + format;
graph.init();
Random r = new Random(12345); Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(new int[] {2, 3, 4}); ComputationGraphConfiguration conf =
INDArray labels = TestUtils.randomOneHotTimeSeries(2, 3, 4); new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new UniformDistribution(0.2, 0.6))
.updater(new NoOp()).graphBuilder().addInputs("input")
.setOutputs("out")
.addLayer("lstm1",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"input")
.addLayer("lstm2",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"lstm1")
.addLayer("dense1",
new DenseLayer.Builder().nIn(3).nOut(3)
.activation(Activation.SIGMOID).build(),
"lstm1")
.addLayer("lstm3",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"dense1")
.addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
.addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
"merge")
.setInputTypes(InputType.recurrent(4, format))
.build();
if (PRINT_RESULTS) { ComputationGraph graph = new ComputationGraph(conf);
System.out.println("testLSTMWithMerging()"); graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, format == RNNFormat.NCW ? new long[]{2, 3, 4} : new long[]{2,4,3});
INDArray labels = TestUtils.randomOneHotTimeSeries(format, 2, 3, 4, new Random(12345));
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++) // for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); // System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
} }
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testLSTMWithMerging()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
} }
@Test @Test


@ -216,7 +216,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
failed.add(testName + "\t" + "EXCEPTION"); failed.add(testName + "\t" + "EXCEPTION");
continue; continue;
} }
@ -383,7 +383,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
failed.add(testName + "\t" + "EXCEPTION"); failed.add(testName + "\t" + "EXCEPTION");
continue; continue;
} }
@ -693,7 +693,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
failed.add(testName + "\t" + "EXCEPTION"); failed.add(testName + "\t" + "EXCEPTION");
continue; continue;
} }

View File

@ -338,7 +338,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.list() .list()
.layer(new LSTM.Builder().nOut(layerSize).build()) .layer(new LSTM.Builder().nOut(layerSize).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build(), 2)) .layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build()))
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) .layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn)) .setInputType(InputType.recurrent(nIn))

View File

@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -47,6 +48,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphConfigurationTest extends BaseDL4JTest { public class ComputationGraphConfigurationTest extends BaseDL4JTest {
@Test @Test
@ -150,7 +152,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
// Use appendLayer on first layer // Use appendLayer on first layer
@ -162,7 +164,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test no network inputs //Test no network inputs
@ -174,7 +176,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test no network outputs //Test no network outputs
@ -185,7 +187,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test: invalid input //Test: invalid input
@ -197,7 +199,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test: graph with cycles //Test: graph with cycles
@ -215,7 +217,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test: input != inputType count mismatch //Test: input != inputType count mismatch
@ -241,7 +243,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.conf; package org.deeplearning4j.nn.conf;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
@ -46,6 +47,7 @@ import static org.junit.Assert.*;
/** /**
* Created by agibsonccc on 11/27/14. * Created by agibsonccc on 11/27/14.
*/ */
@Slf4j
public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
@Rule @Rule
@ -272,9 +274,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK //OK
e.printStackTrace(); log.error("",e);
} catch (Throwable e) { } catch (Throwable e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception thrown for invalid config"); fail("Unexpected exception thrown for invalid config");
} }
@ -288,9 +290,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK //OK
e.printStackTrace(); log.info(e.toString());
} catch (Throwable e) { } catch (Throwable e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception thrown for invalid config"); fail("Unexpected exception thrown for invalid config");
} }
@ -304,9 +306,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK //OK
e.printStackTrace(); log.info(e.toString());
} catch (Throwable e) { } catch (Throwable e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception thrown for invalid config"); fail("Unexpected exception thrown for invalid config");
} }
} }

View File

@ -96,8 +96,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) { } else if (lc instanceof NonNegativeConstraint) {
assertTrue(RW0.minNumber().doubleValue() >= 0.0); assertTrue(RW0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(RW0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, RW0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(RW0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, RW0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -149,8 +149,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) { } else if (lc instanceof NonNegativeConstraint) {
assertTrue(b0.minNumber().doubleValue() >= 0.0); assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -201,8 +201,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) { } else if (lc instanceof NonNegativeConstraint) {
assertTrue(w0.minNumber().doubleValue() >= 0.0); assertTrue(w0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -259,10 +259,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0); assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0); assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -320,10 +320,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0); assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0); assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -378,10 +378,10 @@ public class TestConstraints extends BaseDL4JTest {
} else if(lc instanceof NonNegativeConstraint ){ } else if(lc instanceof NonNegativeConstraint ){
assertTrue(w0.minNumber().doubleValue() >= 0.0 ); assertTrue(w0.minNumber().doubleValue() >= 0.0 );
} else if(lc instanceof UnitNormConstraint ){ } else if(lc instanceof UnitNormConstraint ){
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 ); assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 ); assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6 );
assertEquals(w1.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 ); assertEquals(1.0, w1.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(w1.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 ); assertEquals(1.0, w1.norm2(1).maxNumber().doubleValue(), 1e-6 );
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);

View File

@ -156,7 +156,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
checkSerialization(glstm); checkSerialization(glstm);
assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0); assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0);
assertEquals(glstm.nIn, numIn); assertEquals(glstm.nIn, numIn);
assertEquals(glstm.nOut, numOut); assertEquals(glstm.nOut, numOut);
assertTrue(glstm.getActivationFn() instanceof ActivationTanH); assertTrue(glstm.getActivationFn() instanceof ActivationTanH);

View File

@ -17,7 +17,9 @@
package org.deeplearning4j.nn.dtypes; package org.deeplearning4j.nn.dtypes;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.preprocessor.*;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer; import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.nd4j.shade.guava.collect.ImmutableSet; import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath; import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -51,16 +53,11 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer; import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.util.IdentityLayer; import org.deeplearning4j.nn.layers.util.IdentityLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitDistribution;
@ -97,7 +94,8 @@ public class DTypeTests extends BaseDL4JTest {
Pooling2D.class, //Alias for SubsamplingLayer Pooling2D.class, //Alias for SubsamplingLayer
Convolution2D.class, //Alias for ConvolutionLayer Convolution2D.class, //Alias for ConvolutionLayer
Pooling1D.class, //Alias for Subsampling1D Pooling1D.class, //Alias for Subsampling1D
Convolution1D.class //Alias for Convolution1DLayer Convolution1D.class, //Alias for Convolution1DLayer
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
)); ));
@Override @Override
@ -819,7 +817,7 @@ public class DTypeTests extends BaseDL4JTest {
.layer(new DenseLayer.Builder().nOut(5).build()) .layer(new DenseLayer.Builder().nOut(5).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()) .layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
.layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())) .layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()))
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2)) .layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build()) .layer(new SimpleRnn.Builder().nIn(5).nOut(5).build())
.layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build()) .layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
.layer(secondLast) .layer(secondLast)
@ -1078,7 +1076,7 @@ public class DTypeTests extends BaseDL4JTest {
.addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in") .addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in")
.addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l") .addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l")
.addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc") .addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc")
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16})), "preproc2") .addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16}, false)), "preproc2")
.addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3") .addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3")
.setInputTypes(InputType.feedForward(5)) .setInputTypes(InputType.feedForward(5))
.setOutputs("out"); .setOutputs("out");
@ -1150,7 +1148,7 @@ public class DTypeTests extends BaseDL4JTest {
case 7: case 7:
b.addInputs("in") b.addInputs("in")
.addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in") .addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
.addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5)), "1") .addVertex("2", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 5)), "1")
.addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2") .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
.setOutputs("out") .setOutputs("out")
.setInputTypes(InputType.convolutional(28, 28, 1)); .setInputTypes(InputType.convolutional(28, 28, 1));

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.graph; package org.deeplearning4j.nn.graph;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
@ -54,6 +55,7 @@ import java.util.Map;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphTestRNN extends BaseDL4JTest { public class ComputationGraphTestRNN extends BaseDL4JTest {
@Test @Test
@ -618,7 +620,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
.build(); .build();
fail("Exception expected"); fail("Exception expected");
} catch (IllegalStateException e){ } catch (IllegalStateException e){
// e.printStackTrace(); log.error("",e);
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig")); assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
} }
} }

View File

@ -1394,7 +1394,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
} catch (Exception e) { } catch (Exception e) {
//e.printStackTrace(); log.error("",e);
if(allowDisconnected){ if(allowDisconnected){
fail("No exception expected"); fail("No exception expected");
} else { } else {

View File

@ -416,8 +416,8 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(), INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j)); NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) { for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0); assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0); assertEquals(0.0, outRow2.getDouble(k), 0.0);
} }
} }
} }

View File

@ -60,7 +60,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test @Test
public void testMergeNode() { public void testMergeNode() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType()); GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4); INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4);
INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100); INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100);
@ -82,7 +82,7 @@ public class TestGraphNodes extends BaseDL4JTest {
public void testMergeNodeRNN() { public void testMergeNodeRNN() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType()); GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5); INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5);
INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100); INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100);
@ -103,7 +103,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test @Test
public void testCnnDepthMerge() { public void testCnnDepthMerge() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType()); GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2); INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2);
INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10); INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10);

View File

@ -0,0 +1,974 @@
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.layers.convolution;
import lombok.*;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class ConvDataFormatTests extends BaseDL4JTest {
private final DataType dataType;
public ConvDataFormatTests(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return new DataType[]{DataType.FLOAT, DataType.DOUBLE};
}
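//Each test below follows the same pattern: build four networks that should behave identically -
//NCHW and NHWC variants, each with and without the data format set explicitly on the layer under
//test - copy net1's parameters into the others, then compare activations, outputs, gradients,
//fitted parameters and serialized copies via testHelper.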
@Test
public void testConv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getConv2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSubsampling2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testDepthwiseConv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSeparableConv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testDeconv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLRN() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLrnLayer(CNN2DFormat.NCHW, true, cm))
.net2(getLrnLayer(CNN2DFormat.NCHW, false, cm))
.net3(getLrnLayer(CNN2DFormat.NHWC, true, cm))
.net4(getLrnLayer(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testZeroPaddingLayer(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getZeroPaddingNet(CNN2DFormat.NCHW, true))
.net2(getZeroPaddingNet(CNN2DFormat.NCHW, false))
.net3(getZeroPaddingNet(CNN2DFormat.NHWC, true))
.net4(getZeroPaddingNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testCropping2DLayer(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getCropping2dNet(CNN2DFormat.NCHW, true))
.net2(getCropping2dNet(CNN2DFormat.NCHW, false))
.net3(getCropping2dNet(CNN2DFormat.NHWC, true))
.net4(getCropping2dNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testUpsampling2d(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getUpsamplingNet(CNN2DFormat.NCHW, true))
.net2(getUpsamplingNet(CNN2DFormat.NCHW, false))
.net3(getUpsamplingNet(CNN2DFormat.NHWC, true))
.net4(getUpsamplingNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testBatchNormNet(){
try {
for(boolean useLogStd : new boolean[]{true, false}) {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std");
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true))
.net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false))
.net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true))
.net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testCnnLossLayer() {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3);
labelsNHWC = labelsNHWC.reshape(2,6,6,3);
INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup();
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getCnnLossNet(CNN2DFormat.NCHW, true, ConvolutionMode.Same))
.net2(getCnnLossNet(CNN2DFormat.NCHW, false, ConvolutionMode.Same))
.net3(getCnnLossNet(CNN2DFormat.NHWC, true, ConvolutionMode.Same))
.net4(getCnnLossNet(CNN2DFormat.NHWC, false, ConvolutionMode.Same))
.inNCHW(inNCHW)
.labelsNCHW(labelsNCHW)
.labelsNHWC(labelsNHWC)
.testLayerIdx(1)
.nhwcOutput(true)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSpaceToDepthNet(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSpaceToBatchNet(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16);
INDArray labels = TestUtils.randomOneHot(8, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLocallyConnected() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm))
.net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm))
.net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm))
.net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGlobalPooling() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
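//Factory methods: each builds the small test network with the given layer type. When setOnLayerAlso
//is true, the CNN2DFormat is also set explicitly on the layer's builder; otherwise only the
//network-level input type carries the format.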
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocalResponseNormalization.Builder()
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocalResponseNormalization.Builder()
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(),
format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Cropping2D.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Cropping2D.Builder(2,2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Upsampling2D.Builder(2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Upsampling2D.Builder(2)
.build(), format, ConvolutionMode.Same, null);
}
}
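//Note: both branches below build the same network - no per-layer data format is set for
//Deconvolution2D in this test, so only the network-level input type distinguishes the formats.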
private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.stride(2,2)
.build(), format, cm, null);
} else {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.stride(2,2)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new BatchNormalization.Builder()
.useLogStd(logStdev)
.dataFormat(format)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new BatchNormalization.Builder()
.useLogStd(logStdev)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
.blocks(2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
.blocks(2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} else {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
}
}
private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.build(), format, cm, null);
}
}
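//Common builder: a fixed convolution layer, then the layer under test (index 1, matching
//testLayerIdx in the test cases), then an output layer. Input type defaults to 12x12x3 in the
//given format unless one is supplied.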
private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.dataType(this.dataType)
.seed(12345)
.convolutionMode(cm)
.list()
.layer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build())
.layer(layer)
.layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
.setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));
if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
//Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
//DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
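//Illustrative example: a [minibatch, h, w, c] NHWC activation is first permuted to NCHW so that
//flattening yields [all of channel 0, all of channel 1, ...] - the same element order the NCHW
//networks pass to the output layer.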
builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
}
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
.convolutionMode(cm)
.list()
.layer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build());
if(setOnLayerAlso){
builder.layer(new CnnLossLayer.Builder().format(format).activation(Activation.SOFTMAX).build());
} else {
builder.layer(new CnnLossLayer.Builder().activation(Activation.SOFTMAX).build());
}
builder.setInputType(InputType.convolutional(12, 12, 3, format));
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
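//Holder for one comparison: four equivalent networks (net1/net2 = NCHW with/without the format set
//on the layer under test, net3/net4 = the NHWC counterparts), the NCHW input, and labels in both layouts.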
@AllArgsConstructor
@Data
@NoArgsConstructor
@Builder
private static class TestCase {
private String msg;
private MultiLayerNetwork net1;
private MultiLayerNetwork net2;
private MultiLayerNetwork net3;
private MultiLayerNetwork net4;
private INDArray inNCHW;
private INDArray labelsNCHW;
private INDArray labelsNHWC;
private int testLayerIdx;
private boolean nhwcOutput;
}
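//Checks that all four networks agree: activations at the layer under test, outputs, input and
//parameter gradients must match (NHWC results are permuted back to NCHW where needed), and the
//networks must remain identical after fitting and after a serialization round trip.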
public static void testHelper(TestCase tc) {
tc.net2.params().assign(tc.net1.params());
tc.net3.params().assign(tc.net1.params());
tc.net4.params().assign(tc.net1.params());
//Test forward pass:
INDArray inNCHW = tc.inNCHW;
INDArray inNHWC = tc.inNCHW.permute(0, 2, 3, 1).dup();
INDArray l0_1 = tc.net1.feedForward(inNCHW).get(tc.testLayerIdx + 1);
INDArray l0_2 = tc.net2.feedForward(inNCHW).get(tc.testLayerIdx + 1);
INDArray l0_3 = tc.net3.feedForward(inNHWC).get(tc.testLayerIdx + 1);
INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1);
assertEquals(tc.msg, l0_1, l0_2);
if(l0_1.rank() == 4) {
assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
} else {
assertEquals(tc.msg, l0_1, l0_3);
assertEquals(tc.msg, l0_1, l0_4);
}
INDArray out1 = tc.net1.output(inNCHW);
INDArray out2 = tc.net2.output(inNCHW);
INDArray out3 = tc.net3.output(inNHWC);
INDArray out4 = tc.net4.output(inNHWC);
assertEquals(tc.msg, out1, out2);
if(!tc.nhwcOutput) {
assertEquals(tc.msg, out1, out3);
assertEquals(tc.msg, out1, out4);
} else {
assertEquals(tc.msg, out1, out3.permute(0,3,1,2)); //NHWC to NCHW
assertEquals(tc.msg, out1, out4.permute(0,3,1,2));
}
//Test backprop
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCHW, tc.labelsNCHW, null, null);
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCHW, tc.labelsNCHW, null, null);
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNHWC, tc.labelsNHWC, null, null);
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNHWC, tc.labelsNHWC, null, null);
//Input gradients
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0,3,1,2)); //Input gradients for NHWC input are also in NHWC format
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0,3,1,2));
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
tc.net1.fit(inNCHW, tc.labelsNCHW);
tc.net2.fit(inNCHW, tc.labelsNCHW);
tc.net3.fit(inNHWC, tc.labelsNHWC);
tc.net4.fit(inNHWC, tc.labelsNHWC);
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
//Test serialization
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
out1 = tc.net1.output(inNCHW);
assertEquals(tc.msg, out1, net1a.output(inNCHW));
assertEquals(tc.msg, out1, net2a.output(inNCHW));
if(!tc.nhwcOutput) {
assertEquals(tc.msg, out1, net3a.output(inNHWC));
assertEquals(tc.msg, out1, net4a.output(inNHWC));
} else {
assertEquals(tc.msg, out1, net3a.output(inNHWC).permute(0,3,1,2)); //NHWC to NCHW
assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2));
}
}
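//Returns the names of any gradient variables that differ between the two gradient maps; used above
//to build a readable failure message.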
private static List<String> differentGrads(Gradient g1, Gradient g2){
List<String> differs = new ArrayList<>();
Map<String,INDArray> m1 = g1.gradientForVariable();
Map<String,INDArray> m2 = g2.gradientForVariable();
for(String s : m1.keySet()){
INDArray a1 = m1.get(s);
INDArray a2 = m2.get(s);
if(!a1.equals(a2)){
differs.add(s);
}
}
return differs;
}
//Converts NHWC to NCHW activations
@EqualsAndHashCode
private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
}
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
}
@Override
public InputPreProcessor clone() {
return this;
}
@Override
public InputType getOutputType(InputType inputType) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
}
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
return null;
}
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.convolution; package org.deeplearning4j.nn.layers.convolution;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@ -45,6 +46,7 @@ import static org.junit.Assert.*;
/** /**
* Created by Alex on 15/11/2016. * Created by Alex on 15/11/2016.
*/ */
@Slf4j
public class TestConvolutionModes extends BaseDL4JTest { public class TestConvolutionModes extends BaseDL4JTest {
@Test @Test
@ -106,12 +108,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
} }
} catch (DL4JException e) { } catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) { if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception"); fail("Unexpected exception");
} }
continue; //Expected exception continue; //Expected exception
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception"); fail("Unexpected exception");
} }
@ -184,12 +186,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
} }
} catch (DL4JException e) { } catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) { if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception"); fail("Unexpected exception");
} }
continue; //Expected exception continue; //Expected exception
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception"); fail("Unexpected exception");
} }

View File

@ -24,10 +24,7 @@ import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator; import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer; import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM; import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.GravesLSTM;
@ -45,6 +42,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils; import org.deeplearning4j.util.TimeSeriesUtils;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -61,12 +60,22 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class BidirectionalTest extends BaseDL4JTest { public class BidirectionalTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public BidirectionalTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test @Test
public void compareImplementations(){ public void compareImplementations(){
for(WorkspaceMode wsm : WorkspaceMode.values()) { for(WorkspaceMode wsm : WorkspaceMode.values()) {
@ -82,9 +91,9 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm)
.updater(new Adam()) .updater(new Adam())
.list() .list()
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build()) .nIn(10).nOut(10).build())
.build(); .build();
@ -95,9 +104,9 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm)
.updater(new Adam()) .updater(new Adam())
.list() .list()
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build()) .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build()) .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build()) .nIn(10).nOut(10).build())
.build(); .build();
@ -116,15 +125,24 @@ public class BidirectionalTest extends BaseDL4JTest {
net2.setParams(net1.params()); //Assuming exact same layout here... net2.setParams(net1.params()); //Assuming exact same layout here...
INDArray in = Nd4j.rand(new int[]{3, 10, 5}); INDArray in;
if (rnnDataFormat == NCW){
in = Nd4j.rand(new int[]{3, 10, 5});
}else{
in = Nd4j.rand(new int[]{3, 5, 10});
}
INDArray out1 = net1.output(in); INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in); INDArray out2 = net2.output(in);
assertEquals(out1, out2); assertEquals(out1, out2);
INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); INDArray labels;
if (rnnDataFormat == NCW){
labels = Nd4j.rand(new int[]{3, 10, 5});
}else{
labels = Nd4j.rand(new int[]{3, 5, 10});
}
net1.setInput(in); net1.setInput(in);
net1.setLabels(labels); net1.setLabels(labels);
@ -276,17 +294,22 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm)
.updater(new Adam()) .updater(new Adam())
.list() .list()
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.nIn(10).nOut(10).build()) .nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build(); .build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init(); net1.init();
INDArray in = Nd4j.rand(new int[]{3, 10, 5}); INDArray in;
INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); INDArray labels;
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 5} : new long[]{3, 5, 10};
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
net1.fit(in, labels); net1.fit(in, labels);
@ -300,8 +323,8 @@ public class BidirectionalTest extends BaseDL4JTest {
MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true); MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true);
in = Nd4j.rand(new int[]{3, 10, 5}); in = Nd4j.rand(inshape);
labels = Nd4j.rand(new int[]{3, 10, 5}); labels = Nd4j.rand(inshape);
INDArray out1 = net1.output(in); INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in); INDArray out2 = net2.output(in);
@ -338,18 +361,18 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam()) .updater(new Adam())
.graphBuilder() .graphBuilder()
.addInputs("in") .addInputs("in")
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in") .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0") .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0")
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build(), "1") .nIn(10).nOut(10).build(), "1")
.setOutputs("2") .setOutputs("2")
.build(); .build();
ComputationGraph net1 = new ComputationGraph(conf1); ComputationGraph net1 = new ComputationGraph(conf1);
net1.init(); net1.init();
long[] inshape = (rnnDataFormat == NCW)? new long[]{3, 10, 5}: new long[]{3, 5, 10};
INDArray in = Nd4j.rand(new int[]{3, 10, 5}); INDArray in = Nd4j.rand(inshape);
INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); INDArray labels = Nd4j.rand(inshape);
net1.fit(new DataSet(in, labels)); net1.fit(new DataSet(in, labels));
@ -363,8 +386,8 @@ public class BidirectionalTest extends BaseDL4JTest {
ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true); ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true);
in = Nd4j.rand(new int[]{3, 10, 5}); in = Nd4j.rand(inshape);
labels = Nd4j.rand(new int[]{3, 10, 5}); labels = Nd4j.rand(inshape);
INDArray out1 = net1.outputSingle(in); INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in); INDArray out2 = net2.outputSingle(in);
@ -394,8 +417,8 @@ public class BidirectionalTest extends BaseDL4JTest {
Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD, Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD,
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL}; Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
INDArray in = Nd4j.rand(new int[]{3, 10, 6}); INDArray in = Nd4j.rand(inshape);
for (Bidirectional.Mode m : modes) { for (Bidirectional.Mode m : modes) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -406,7 +429,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm)
.updater(new Adam()) .updater(new Adam())
.list() .list()
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build())) .layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.build(); .build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
@ -418,7 +441,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.updater(new Adam()) .updater(new Adam())
.list() .list()
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build()) .layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build(); .build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone()); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone());
@ -434,11 +457,10 @@ public class BidirectionalTest extends BaseDL4JTest {
net3.setParam("0_RW", net1.getParam("0_bRW")); net3.setParam("0_RW", net1.getParam("0_bRW"));
net3.setParam("0_b", net1.getParam("0_bb")); net3.setParam("0_b", net1.getParam("0_bb"));
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
INDArray out1 = net1.output(in); INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in); INDArray out2 = net2.output(in);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
INDArray outExp; INDArray outExp;
switch (m) { switch (m) {
@ -452,7 +474,7 @@ public class BidirectionalTest extends BaseDL4JTest {
outExp = out2.add(out3).muli(0.5); outExp = out2.add(out3).muli(0.5);
break; break;
case CONCAT: case CONCAT:
outExp = Nd4j.concat(1, out2, out3); outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
break; break;
default: default:
throw new RuntimeException(); throw new RuntimeException();
@ -464,25 +486,25 @@ public class BidirectionalTest extends BaseDL4JTest {
//Check gradients: //Check gradients:
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) { if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
INDArray eps = Nd4j.rand(new int[]{3, 10, 6}); INDArray eps = Nd4j.rand(inshape);
INDArray eps1; INDArray eps1;
if (m == Bidirectional.Mode.CONCAT) { if (m == Bidirectional.Mode.CONCAT) {
eps1 = Nd4j.concat(1, eps, eps); eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
} else { } else {
eps1 = eps; eps1 = eps;
} }
net1.setInput(in); net1.setInput(in);
net2.setInput(in); net2.setInput(in);
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)); net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat));
net1.feedForward(true, false); net1.feedForward(true, false);
net2.feedForward(true, false); net2.feedForward(true, false);
net3.feedForward(true, false); net3.feedForward(true, false);
Pair<Gradient, INDArray> p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces()); Pair<Gradient, INDArray> p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); Pair<Gradient, INDArray> p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT), LayerWorkspaceMgr.noWorkspaces()); Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat), LayerWorkspaceMgr.noWorkspaces());
Gradient g1 = p1.getFirst(); Gradient g1 = p1.getFirst();
Gradient g2 = p2.getFirst(); Gradient g2 = p2.getFirst();
Gradient g3 = p3.getFirst(); Gradient g3 = p3.getFirst();
@ -520,7 +542,9 @@ public class BidirectionalTest extends BaseDL4JTest {
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL}; Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
INDArray in = Nd4j.rand(new int[]{3, 10, 6}); long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
INDArray in = Nd4j.rand(inshape);
for (Bidirectional.Mode m : modes) { for (Bidirectional.Mode m : modes) {
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -532,7 +556,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam()) .updater(new Adam())
.graphBuilder() .graphBuilder()
.addInputs("in") .addInputs("in")
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()), "in") .layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
.setOutputs("0") .setOutputs("0")
.build(); .build();
@ -546,7 +570,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam()) .updater(new Adam())
.graphBuilder() .graphBuilder()
.addInputs("in") .addInputs("in")
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in") .layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in")
.setOutputs("0") .setOutputs("0")
.build(); .build();
@ -566,9 +590,20 @@ public class BidirectionalTest extends BaseDL4JTest {
INDArray out1 = net1.outputSingle(in); INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in); INDArray out2 = net2.outputSingle(in);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.outputSingle( INDArray out3;
TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)), INDArray inReverse;
LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); if (rnnDataFormat == RNNFormat.NWC){
inReverse = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
out3 = net3.outputSingle(inReverse);
out3 = TimeSeriesUtils.reverseTimeSeries(out3.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
}
else{
inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
out3 = net3.outputSingle(inReverse);
out3 = TimeSeriesUtils.reverseTimeSeries(out3, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
}
INDArray outExp; INDArray outExp;
switch (m) { switch (m) {
@ -582,7 +617,9 @@ public class BidirectionalTest extends BaseDL4JTest {
outExp = out2.add(out3).muli(0.5); outExp = out2.add(out3).muli(0.5);
break; break;
case CONCAT: case CONCAT:
outExp = Nd4j.concat(1, out2, out3); System.out.println(out2.shapeInfoToString());
System.out.println(out3.shapeInfoToString());
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
break; break;
default: default:
throw new RuntimeException(); throw new RuntimeException();
@ -594,22 +631,26 @@ public class BidirectionalTest extends BaseDL4JTest {
//Check gradients: //Check gradients:
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) { if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
INDArray eps = Nd4j.rand(new int[]{3, 10, 6}); INDArray eps = Nd4j.rand(inshape);
INDArray eps1; INDArray eps1;
if (m == Bidirectional.Mode.CONCAT) { if (m == Bidirectional.Mode.CONCAT) {
eps1 = Nd4j.concat(1, eps, eps); eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
} else { } else {
eps1 = eps; eps1 = eps;
} }
INDArray epsReversed = (rnnDataFormat == NCW)?
TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT):
TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)
.permute(0, 2, 1);
net1.outputSingle(true, false, in); net1.outputSingle(true, false, in);
net2.outputSingle(true, false, in); net2.outputSingle(true, false, in);
net3.outputSingle(true, false, TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)); net3.outputSingle(true, false, inReverse);
Gradient g1 = net1.backpropGradient(eps1); Gradient g1 = net1.backpropGradient(eps1);
Gradient g2 = net2.backpropGradient(eps); Gradient g2 = net2.backpropGradient(eps);
Gradient g3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)); Gradient g3 = net3.backpropGradient(epsReversed);
for (boolean updates : new boolean[]{false, true}) { for (boolean updates : new boolean[]{false, true}) {
if (updates) { if (updates) {
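Note: the BidirectionalTest changes above follow the pattern used throughout this PR: the test class is parameterized over RNNFormat, and input/label shapes switch between NCW ([miniBatch, size, timeSteps]) and NWC ([miniBatch, timeSteps, size]). A stripped-down sketch of that pattern, with an illustrative class name and array sizes:

import org.deeplearning4j.nn.conf.RNNFormat;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.assertEquals;

@RunWith(Parameterized.class)
public class RnnFormatSketchTest {
    private final RNNFormat rnnDataFormat;

    public RnnFormatSketchTest(RNNFormat rnnDataFormat) {
        this.rnnDataFormat = rnnDataFormat;
    }

    @Parameterized.Parameters
    public static Object[] params() {
        return RNNFormat.values();   // one run per format: NCW, NWC
    }

    @Test
    public void shapeFollowsFormat() {
        // NCW = [miniBatch, size, timeSteps]; NWC = [miniBatch, timeSteps, size]
        long[] shape = (rnnDataFormat == RNNFormat.NCW) ? new long[]{3, 10, 5} : new long[]{3, 5, 10};
        INDArray in = Nd4j.rand(shape);
        assertEquals(3, in.size(0));  // the miniBatch axis is the same in both layouts
    }
}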

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@ -31,6 +32,8 @@ import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,10 +45,18 @@ import org.nd4j.linalg.primitives.Pair;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@RunWith(Parameterized.class)
public class GravesBidirectionalLSTMTest extends BaseDL4JTest { public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
private double score = 0.0; private double score = 0.0;
private RNNFormat rnnDataFormat;
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test @Test
public void testBidirectionalLSTMGravesForwardBasic() { public void testBidirectionalLSTMGravesForwardBasic() {
//Very basic test of forward prop. of LSTM layer with a time series. //Very basic test of forward prop. of LSTM layer with a time series.
@ -55,7 +66,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(nHiddenUnits).activation(Activation.TANH).build()) .nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build())
.build(); .build();
val numParams = conf.getLayer().initializer().numParams(conf); val numParams = conf.getLayer().initializer().numParams(conf);
@ -65,22 +76,41 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
//Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Data: has shape [miniBatchSize,nIn,timeSeriesLength];
//Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength];
if (rnnDataFormat == RNNFormat.NCW){
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1});
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1); final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1}); assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1});
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1); final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12);
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1}); assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12});
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12); final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15);
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12}); assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
}
else{
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, 1, nIn);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, 1, nHiddenUnits});
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, 1, nIn);
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations2.shape(), new long[] {10, 1, nHiddenUnits});
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, 12, nIn);
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations3.shape(), new long[] {1, 12, nHiddenUnits});
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, 15, nIn);
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations4.shape(), new long[] {10, 15, nHiddenUnits});
}
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15);
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
} }
@Test @Test
@ -94,14 +124,15 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
} }
private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
int timeSeriesLength) { int timeSeriesLength) {
INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength); INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength):
Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(lstmNHiddenUnits) .nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
.build(); .build();
@ -114,7 +145,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces());
assertNotNull(lstm.input()); assertNotNull(lstm.input());
INDArray epsilon = Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength); INDArray epsilon =(rnnDataFormat == RNNFormat.NCW)? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength):
Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits);
Pair<Gradient, INDArray> out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); Pair<Gradient, INDArray> out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
Gradient outGradient = out.getFirst(); Gradient outGradient = out.getFirst();
@ -147,7 +179,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3}); assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3});
assertNotNull(nextEpsilon); assertNotNull(nextEpsilon);
assertArrayEquals(nextEpsilon.shape(), new long[] {miniBatchSize, nIn, timeSeriesLength}); if (rnnDataFormat == RNNFormat.NCW) {
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, nIn, timeSeriesLength});
}else{
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, timeSeriesLength, nIn });
}
//Check update: //Check update:
for (String s : outGradient.gradientForVariable().keySet()) { for (String s : outGradient.gradientForVariable().keySet()) {
@ -226,7 +262,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder() final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(layerSize) .nOut(layerSize).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build()) .dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build())
.build(); .build();
@ -237,7 +273,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.instantiate(confBidirectional, null, 0, params, true, params.dataType()); .instantiate(confBidirectional, null, 0, params, true, params.dataType());
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()); final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());
@ -265,13 +302,13 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration confBidirectional = final NeuralNetConfiguration confBidirectional =
new NeuralNetConfiguration.Builder() new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
.nIn(nIn).nOut(layerSize) .nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(-0.1, 0.1)) .dist(new UniformDistribution(-0.1, 0.1))
.activation(Activation.TANH).updater(new NoOp()).build()) .activation(Activation.TANH).updater(new NoOp()).build())
.build(); .build();
final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder() final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize) .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.weightInit(WeightInit.ZERO).activation(Activation.TANH).build()) .weightInit(WeightInit.ZERO).activation(Activation.TANH).build())
.build(); .build();
@ -290,9 +327,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards))); Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards)));
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
final INDArray sigb = sig.dup(); final INDArray sigb = sig.dup();
reverseColumnsInPlace(sigb.slice(0));
if (rnnDataFormat == RNNFormat.NCW) {
reverseColumnsInPlace(sigb.slice(0));
}
else{
reverseColumnsInPlace(sigb.slice(0).permute(1, 0));
}
final INDArray recurrentWeightsF = bidirectionalLSTM final INDArray recurrentWeightsF = bidirectionalLSTM
.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS); .getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS);
@ -345,10 +389,14 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f); assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f);
final INDArray randSig = Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}); final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}):
final INDArray randSigBackwards = randSig.dup(); Nd4j.rand(new int[] {1, timeSeriesLength, layerSize});
reverseColumnsInPlace(randSigBackwards.slice(0)); INDArray randSigBackwards = randSig.dup();
if (rnnDataFormat == RNNFormat.NCW){
reverseColumnsInPlace(randSigBackwards.slice(0));
}else{
reverseColumnsInPlace(randSigBackwards.slice(0).permute(1, 0));
}
final Pair<Gradient, INDArray> backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces()); final Pair<Gradient, INDArray> backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
final Pair<Gradient, INDArray> backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces()); final Pair<Gradient, INDArray> backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
@ -399,10 +447,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0); final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0);
final INDArray activation3Reverse = activation3.dup(); final INDArray activation3Reverse = activation3.dup();
reverseColumnsInPlace(activation3Reverse); if (rnnDataFormat == RNNFormat.NCW){
reverseColumnsInPlace(activation3Reverse);
}
else{
reverseColumnsInPlace(activation3Reverse.permute(1, 0));
}
assertEquals(activation3Reverse, activation1);
assertArrayEquals(activation3Reverse.shape(), activation1.shape()); assertArrayEquals(activation3Reverse.shape(), activation1.shape());
assertEquals(activation3Reverse, activation1);
//test backprop now //test backprop now
final INDArray refBackGradientReccurrent = final INDArray refBackGradientReccurrent =
@ -434,7 +488,12 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray refEpsilon = backprop1.getSecond().dup(); final INDArray refEpsilon = backprop1.getSecond().dup();
final INDArray backEpsilon = backprop3.getSecond().dup(); final INDArray backEpsilon = backprop3.getSecond().dup();
reverseColumnsInPlace(refEpsilon.slice(0)); if (rnnDataFormat == RNNFormat.NCW) {
reverseColumnsInPlace(refEpsilon.slice(0));
}
else{
reverseColumnsInPlace(refEpsilon.slice(0).permute(1, 0));
}
assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6); assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6);
} }
@ -477,10 +536,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.seed(12345).list() .seed(12345).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2) .gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat)
.build()) .build())
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder() .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2) .lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat)
.activation(Activation.TANH).build()) .activation(Activation.TANH).build())
.build(); .build();
@ -492,7 +551,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
INDArray in = Nd4j.rand(new int[] {3, 2, 5}); INDArray in = Nd4j.rand(new int[] {3, 2, 5});
INDArray labels = Nd4j.rand(new int[] {3, 2, 5}); INDArray labels = Nd4j.rand(new int[] {3, 2, 5});
if (rnnDataFormat == RNNFormat.NWC){
in = in.permute(0, 2, 1);
labels = labels.permute(0, 2, 1);
}
net.fit(in, labels); net.fit(in, labels);
} }
} }
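Note: several hunks above reverse the time axis for the backwards LSTM pass; under NWC this is done by permuting time into the NCW position, reversing, and permuting back. A hedged sketch of that idea using the existing three-argument TimeSeriesUtils.reverseTimeSeries overload (the class and method names are illustrative, and the ArrayType import location is assumed):

import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;

public class NwcTimeReversalSketch {
    /** Reverse the time axis of an NWC array ([miniBatch, timeSteps, size]). */
    static INDArray reverseTimeNwc(INDArray inNwc) {
        // Move time to the NCW position, reverse it, then move it back to NWC.
        return TimeSeriesUtils.reverseTimeSeries(
                        inNwc.permute(0, 2, 1),
                        LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)
                .permute(0, 2, 1);
    }
}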

View File

@ -21,11 +21,14 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -36,9 +39,17 @@ import java.util.Collections;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class MaskZeroLayerTest extends BaseDL4JTest { public class MaskZeroLayerTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public MaskZeroLayerTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test @Test
public void activate() { public void activate() {
@ -57,7 +68,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
.activation(Activation.IDENTITY) .activation(Activation.IDENTITY)
.gateActivationFunction(Activation.IDENTITY) .gateActivationFunction(Activation.IDENTITY)
.nIn(2) .nIn(2)
.nOut(1) .nOut(1).dataFormat(rnnDataFormat)
.build(); .build();
NeuralNetConfiguration conf = new NeuralNetConfiguration(); NeuralNetConfiguration conf = new NeuralNetConfiguration();
conf.setLayer(underlying); conf.setLayer(underlying);
@ -72,20 +83,25 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue); MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue);
INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3}); INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3});
if (rnnDataFormat == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
//WHEN //WHEN
INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces()); INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces());
if (rnnDataFormat == RNNFormat.NWC){
out = out.permute(0, 2,1);
}
//THEN output should only be incremented for the non-zero timesteps //THEN output should only be incremented for the non-zero timesteps
INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()); INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()); INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
assertEquals(firstExampleOutput.getDouble(0), 0.0, 1e-6); assertEquals(0.0, firstExampleOutput.getDouble(0), 1e-6);
assertEquals(firstExampleOutput.getDouble(1), 1.0, 1e-6); assertEquals(1.0, firstExampleOutput.getDouble(1), 1e-6);
assertEquals(firstExampleOutput.getDouble(2), 2.0, 1e-6); assertEquals(2.0, firstExampleOutput.getDouble(2), 1e-6);
assertEquals(secondExampleOutput.getDouble(0), 0.0, 1e-6); assertEquals(0.0, secondExampleOutput.getDouble(0), 1e-6);
assertEquals(secondExampleOutput.getDouble(1), 0.0, 1e-6); assertEquals(0.0, secondExampleOutput.getDouble(1), 1e-6);
assertEquals(secondExampleOutput.getDouble(2), 1.0, 1e-6); assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
} }
@ -94,7 +110,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list() .list()
.layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder() .layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder()
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).build()).build()) .setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();

View File

@ -0,0 +1,394 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.layers.recurrent;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
@AllArgsConstructor
public class RnnDataFormatTests extends BaseDL4JTest {
private boolean helpers;
private boolean lastTimeStep;
private boolean maskZeros;
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}")
public static List params(){
List<Object[]> ret = new ArrayList<>();
for (boolean helpers: new boolean[]{true, false})
for (boolean lastTimeStep: new boolean[]{true, false})
for (boolean maskZero: new boolean[]{true, false})
ret.add(new Object[]{helpers, lastTimeStep, maskZero});
return ret;
}
@Test
public void testSimpleRnn() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSimpleRnnNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getSimpleRnnNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getSimpleRnnNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getSimpleRnnNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGraveLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGravesLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getGravesLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getGravesLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getGravesLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGraveBiLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGravesBidirectionalLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getGravesBidirectionalLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getGravesBidirectionalLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getGravesBidirectionalLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
private MultiLayerNetwork getGravesBidirectionalLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new GravesLSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new GravesLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new LSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new LSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getSimpleRnnNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new SimpleRnn.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new SimpleRnn.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getNetWithLayer(Layer layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
if (maskZeros){
layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build();
}
if(lastTimeStep){
layer = new LastTimeStep(layer);
}
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(new LSTM.Builder()
.nIn(3)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build())
.layer(layer)
.layer(
(lastTimeStep)?new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build():
new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).dataFormat(format).build()
)
.setInputType(InputType.recurrent(3, 12, format));
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
@AllArgsConstructor
@Data
@NoArgsConstructor
@Builder
private static class TestCase {
private String msg;
private MultiLayerNetwork net1;
private MultiLayerNetwork net2;
private MultiLayerNetwork net3;
private MultiLayerNetwork net4;
private INDArray inNCW;
private INDArray labelsNCW;
private INDArray labelsNWC;
private int testLayerIdx;
private boolean nwcOutput;
public static void testHelper(TestCase tc) {
tc.net2.params().assign(tc.net1.params());
tc.net3.params().assign(tc.net1.params());
tc.net4.params().assign(tc.net1.params());
INDArray inNCW = tc.inNCW;
INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup();
INDArray l0_1 = tc.net1.feedForward(inNCW).get(tc.testLayerIdx + 1);
INDArray l0_2 = tc.net2.feedForward(inNCW).get(tc.testLayerIdx + 1);
INDArray l0_3 = tc.net3.feedForward(inNWC).get(tc.testLayerIdx + 1);
INDArray l0_4 = tc.net4.feedForward(inNWC).get(tc.testLayerIdx + 1);
boolean rank3Out = tc.labelsNCW.rank() == 3;
assertEquals(tc.msg, l0_1, l0_2);
if (rank3Out){
assertEquals(tc.msg, l0_1, l0_3.permute(0, 2, 1));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 2, 1));
}
else{
assertEquals(tc.msg, l0_1, l0_3);
assertEquals(tc.msg, l0_1, l0_4);
}
INDArray out1 = tc.net1.output(inNCW);
INDArray out2 = tc.net2.output(inNCW);
INDArray out3 = tc.net3.output(inNWC);
INDArray out4 = tc.net4.output(inNWC);
assertEquals(tc.msg, out1, out2);
if (rank3Out){
assertEquals(tc.msg, out1, out3.permute(0, 2, 1)); //NWC to NCW
assertEquals(tc.msg, out1, out4.permute(0, 2, 1));
}
else{
assertEquals(tc.msg, out1, out3); //NWC to NCW
assertEquals(tc.msg, out1, out4);
}
//Test backprop
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCW, tc.labelsNCW, null, null);
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCW, tc.labelsNCW, null, null);
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNWC, tc.labelsNWC, null, null);
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNWC, tc.labelsNWC, null, null);
            //Input gradients
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0, 2, 1)); //Input gradients for NWC input are also in NWC format
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0, 2, 1));
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
tc.net1.fit(inNCW, tc.labelsNCW);
tc.net2.fit(inNCW, tc.labelsNCW);
tc.net3.fit(inNWC, tc.labelsNWC);
tc.net4.fit(inNWC, tc.labelsNWC);
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
//Test serialization
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
out1 = tc.net1.output(inNCW);
assertEquals(tc.msg, out1, net1a.output(inNCW));
assertEquals(tc.msg, out1, net2a.output(inNCW));
if (rank3Out) {
assertEquals(tc.msg, out1, net3a.output(inNWC).permute(0, 2, 1)); //NWC to NCW
assertEquals(tc.msg, out1, net4a.output(inNWC).permute(0, 2, 1));
}
else{
assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW
assertEquals(tc.msg, out1, net4a.output(inNWC));
}
}
}
private static List<String> differentGrads(Gradient g1, Gradient g2){
List<String> differs = new ArrayList<>();
Map<String,INDArray> m1 = g1.gradientForVariable();
Map<String,INDArray> m2 = g2.gradientForVariable();
for(String s : m1.keySet()){
INDArray a1 = m1.get(s);
INDArray a2 = m2.get(s);
if(!a1.equals(a2)){
differs.add(s);
}
}
return differs;
}
}
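Note: the new RnnDataFormatTests builds each network with the dataFormat(...) builder methods and InputType.recurrent(size, timeSeriesLength, format) used in this PR. A minimal stand-alone sketch of configuring and running an NWC network the same way (class name and layer sizes are illustrative):

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class NwcNetworkSketch {
    public static void main(String[] args) {
        RNNFormat format = RNNFormat.NWC;   // time axis last: [miniBatch, timeSteps, size]
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .list()
                .layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(format).activation(Activation.TANH).build())
                .layer(new RnnOutputLayer.Builder().nOut(10).dataFormat(format).activation(Activation.SOFTMAX).build())
                .setInputType(InputType.recurrent(3, 12, format))
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        INDArray inNwc = Nd4j.rand(DataType.FLOAT, 2, 12, 3);   // NWC input
        System.out.println(java.util.Arrays.toString(net.output(inNwc).shape()));
    }
}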

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
@ -29,6 +30,8 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -42,14 +45,25 @@ import static org.nd4j.linalg.activations.Activation.IDENTITY;
import static org.nd4j.linalg.activations.Activation.TANH; import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE; import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
@RunWith(Parameterized.class)
public class TestLastTimeStepLayer extends BaseDL4JTest { public class TestLastTimeStepLayer extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters(name="{0}")
public static Object[] params(){
return RNNFormat.values();
}
@Test @Test
public void testLastTimeStepVertex() { public void testLastTimeStepVertex() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder() .addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
.nIn(5).nOut(6).build()), "in") .nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in")
.setOutputs("lastTS") .setOutputs("lastTS")
.build(); .build();
@ -59,9 +73,22 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
//First: test without input mask array //First: test without input mask array
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
Layer l = graph.getLayer("lastTS"); Layer l = graph.getLayer("lastTS");
INDArray in = Nd4j.rand(new int[]{3, 5, 6}); INDArray in;
if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(3, 5, 6);
}
else{
in = Nd4j.rand(3, 6, 5);
}
INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces()); INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces());
INDArray expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5)); INDArray expOut;
if (rnnDataFormat == RNNFormat.NCW){
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
}
else{
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.point(5), NDArrayIndex.all());
}
//Forward pass: //Forward pass:
@ -76,9 +103,17 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
graph.setLayerMaskArrays(new INDArray[]{inMask}, null); graph.setLayerMaskArrays(new INDArray[]{inMask}, null);
expOut = Nd4j.zeros(3, 6); expOut = Nd4j.zeros(3, 6);
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2))); if (rnnDataFormat == RNNFormat.NCW){
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3))); expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4))); expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));
}
else{
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all()));
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.point(3), NDArrayIndex.all()));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.point(4), NDArrayIndex.all()));
}
outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces()); outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces());
assertEquals(expOut, outFwd); assertEquals(expOut, outFwd);
@ -97,9 +132,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
.seed(1234) .seed(1234)
.graphBuilder() .graphBuilder()
.addInputs("in") .addInputs("in")
.setInputTypes(InputType.recurrent(1)) .setInputTypes(InputType.recurrent(1, rnnDataFormat))
.addLayer("RNN", new LastTimeStep(new LSTM.Builder() .addLayer("RNN", new LastTimeStep(new LSTM.Builder()
.nOut(10) .nOut(10).dataFormat(rnnDataFormat)
.build()), "in") .build()), "in")
.addLayer("dense", new DenseLayer.Builder() .addLayer("dense", new DenseLayer.Builder()
.nOut(10) .nOut(10)
@ -120,7 +155,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
INDArray fm2 = Nd4j.zeros(1,24); INDArray fm2 = Nd4j.zeros(1,24);
INDArray fm3 = Nd4j.zeros(1,24); INDArray fm3 = Nd4j.zeros(1,24);
fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1); fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1);
if (rnnDataFormat == RNNFormat.NWC){
f = f.permute(0, 2, 1);
}
INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1}); INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1});
try { try {
cg.output(false, new INDArray[]{f}, new INDArray[]{fm2}); cg.output(false, new INDArray[]{f}, new INDArray[]{fm2});

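For orientation, a minimal standalone sketch (class and variable names are illustrative, not part of the change) of how the last time step is selected under each RNNFormat, mirroring the NCW/NWC branches added above:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class LastTimeStepIndexSketch {
    public static void main(String[] args) {
        int mb = 3, nOut = 6, tsLength = 7;
        // NCW layout: [minibatch, features, time] -> last step is point(tsLength - 1) on dimension 2
        INDArray ncw = Nd4j.rand(mb, nOut, tsLength);
        INDArray lastNcw = ncw.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(tsLength - 1));
        // NWC layout: [minibatch, time, features] -> last step is point(tsLength - 1) on dimension 1
        INDArray nwc = Nd4j.rand(mb, tsLength, nOut);
        INDArray lastNwc = nwc.get(NDArrayIndex.all(), NDArrayIndex.point(tsLength - 1), NDArrayIndex.all());
        // Both selections yield a [minibatch, features] matrix
        System.out.println(java.util.Arrays.toString(lastNcw.shape()) + " " + java.util.Arrays.toString(lastNwc.shape()));
    }
}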
View File

@ -20,6 +20,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.dropout.TestDropout; import org.deeplearning4j.nn.conf.dropout.TestDropout;
import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
@ -31,6 +32,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,13 +44,24 @@ import org.nd4j.linalg.primitives.Pair;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@RunWith(Parameterized.class)
public class TestRnnLayers extends BaseDL4JTest { public class TestRnnLayers extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestRnnLayers(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test @Test
public void testTimeStepIs3Dimensional() { public void testTimeStepIs3Dimensional() {
@ -58,8 +72,8 @@ public class TestRnnLayers extends BaseDL4JTest {
.updater(new NoOp()) .updater(new NoOp())
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.list() .list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).build()) .layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new LSTM.Builder().nIn(3).nOut(5).build()) .layer(new LSTM.Builder().nIn(3).nOut(5).dataFormat(rnnDataFormat).build())
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build()) .layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build())
.build(); .build();
@ -70,9 +84,9 @@ public class TestRnnLayers extends BaseDL4JTest {
org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn = org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn =
(org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0); (org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0);
INDArray rnnInput3d = Nd4j.create(10, 12, 1); INDArray rnnInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10,12, 1):Nd4j.create(10, 1, 12);
INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces()); INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces());
assertTrue(Arrays.equals(simpleOut.shape(), new long[] {10, 3, 1})); assertTrue(Arrays.equals(simpleOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 3, 1}:new long[]{10, 1, 3}));
INDArray rnnInput2d = Nd4j.create(10, 12); INDArray rnnInput2d = Nd4j.create(10, 12);
try { try {
@ -84,9 +98,9 @@ public class TestRnnLayers extends BaseDL4JTest {
org.deeplearning4j.nn.layers.recurrent.LSTM lstm = org.deeplearning4j.nn.layers.recurrent.LSTM lstm =
(org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1); (org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1);
INDArray lstmInput3d = Nd4j.create(10, 3, 1); INDArray lstmInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10, 3, 1):Nd4j.create(10, 1, 3);
INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces()); INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces());
assertTrue(Arrays.equals(lstmOut.shape(), new long[] {10, 5, 1})); assertTrue(Arrays.equals(lstmOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 5, 1}:new long[]{10, 1, 5}));
INDArray lstmInput2d = Nd4j.create(10, 3); INDArray lstmInput2d = Nd4j.create(10, 3);
try { try {
@ -112,19 +126,19 @@ public class TestRnnLayers extends BaseDL4JTest {
TestDropout.CustomDropout cd = new TestDropout.CustomDropout(); TestDropout.CustomDropout cd = new TestDropout.CustomDropout();
switch (s){ switch (s){
case "graves": case "graves":
layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(); layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build(); layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build(); layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break; break;
case "lstm": case "lstm":
layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(); layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build(); layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build(); layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break; break;
case "simple": case "simple":
layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(); layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build(); layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build(); layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break; break;
default: default:
throw new RuntimeException(s); throw new RuntimeException(s);
@ -134,21 +148,21 @@ public class TestRnnLayers extends BaseDL4JTest {
.seed(12345) .seed(12345)
.list() .list()
.layer(layer) .layer(layer)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build(); .build();
MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder() MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder()
.seed(12345) .seed(12345)
.list() .list()
.layer(layerD) .layer(layerD)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build(); .build();
MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder() MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder()
.seed(12345) .seed(12345)
.list() .list()
.layer(layerD2) .layer(layerD2)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -178,7 +192,6 @@ public class TestRnnLayers extends BaseDL4JTest {
assertNotEquals(s, out2, out2D); assertNotEquals(s, out2, out2D);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345); INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);
net.fit(f.dup(), l); net.fit(f.dup(), l);
netD.fit(f.dup(), l); netD.fit(f.dup(), l);
assertNotEquals(s, net.params(), netD.params()); assertNotEquals(s, net.params(), netD.params());
@ -205,14 +218,14 @@ public class TestRnnLayers extends BaseDL4JTest {
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder() NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
.list() .list()
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build()); .layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
switch (i){ switch (i){
case 0: case 0:
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build()); lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
break; break;
case 1: case 1:
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()); lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).dataFormat(rnnDataFormat).build());
break; break;
default: default:
throw new RuntimeException(); throw new RuntimeException();
@ -223,14 +236,14 @@ public class TestRnnLayers extends BaseDL4JTest {
net.init(); net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5); INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10); INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new Random(12345));
try{ try{
net.fit(in,l); net.fit(in,l);
} catch (Throwable t){ } catch (Throwable t){
String msg = t.getMessage(); String msg = t.getMessage();
if(msg == null) if(msg == null)
t.printStackTrace(); t.printStackTrace();
System.out.println(i);
assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label")); assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
} }
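As a reading aid for the shape assertions above, a small hypothetical helper (not part of the commit) that captures how the expected 3D activation shape flips between layouts:

import org.deeplearning4j.nn.conf.RNNFormat;

public class RnnOutputShapeSketch {
    // Expected rnnTimeStep output shape for a layer with nOut units, given the data format
    static long[] expectedShape(RNNFormat format, long minibatch, long nOut, long tsLength) {
        return format == RNNFormat.NCW
                ? new long[]{minibatch, nOut, tsLength}   // [mb, features, time]
                : new long[]{minibatch, tsLength, nOut};  // [mb, time, features]
    }

    public static void main(String[] args) {
        // Matches the assertions above: SimpleRnn with nOut = 3, a single time step
        System.out.println(java.util.Arrays.toString(expectedShape(RNNFormat.NCW, 10, 3, 1))); // [10, 3, 1]
        System.out.println(java.util.Arrays.toString(expectedShape(RNNFormat.NWC, 10, 3, 1))); // [10, 1, 3]
    }
}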

View File

@ -20,10 +20,13 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -36,8 +39,18 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point; import static org.nd4j.linalg.indexing.NDArrayIndex.point;
@RunWith(Parameterized.class)
public class TestSimpleRnn extends BaseDL4JTest { public class TestSimpleRnn extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestSimpleRnn(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test @Test
public void testSimpleRnn(){ public void testSimpleRnn(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -46,15 +59,21 @@ public class TestSimpleRnn extends BaseDL4JTest {
int nIn = 5; int nIn = 5;
int layerSize = 6; int layerSize = 6;
int tsLength = 7; int tsLength = 7;
INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength}); INDArray in;
// in.get(all(), all(), interval(1,tsLength)).assign(0); if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
}
else{
in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new NoOp()) .updater(new NoOp())
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.activation(Activation.TANH) .activation(Activation.TANH)
.list() .list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).build()) .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -68,7 +87,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
INDArray outLast = null; INDArray outLast = null;
for( int i=0; i<tsLength; i++ ){ for( int i=0; i<tsLength; i++ ){
INDArray inCurrent = in.get(all(), all(), point(i)); INDArray inCurrent;
if (rnnDataFormat == RNNFormat.NCW){
inCurrent = in.get(all(), all(), point(i));
}
else{
inCurrent = in.get(all(), point(i), all());
}
INDArray outExpCurrent = inCurrent.mmul(w); INDArray outExpCurrent = inCurrent.mmul(w);
if(outLast != null){ if(outLast != null){
@ -79,7 +104,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
Transforms.tanh(outExpCurrent, false); Transforms.tanh(outExpCurrent, false);
INDArray outActCurrent = out.get(all(), all(), point(i)); INDArray outActCurrent;
if (rnnDataFormat == RNNFormat.NCW){
outActCurrent = out.get(all(), all(), point(i));
}
else{
outActCurrent = out.get(all(), point(i), all());
}
assertEquals(String.valueOf(i), outExpCurrent, outActCurrent); assertEquals(String.valueOf(i), outExpCurrent, outActCurrent);
outLast = outExpCurrent; outLast = outExpCurrent;
@ -100,7 +131,7 @@ public class TestSimpleRnn extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.activation(Activation.TANH) .activation(Activation.TANH)
.list() .list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize) .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.biasInit(100) .biasInit(100)
.build()) .build())
.build(); .build();

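A rough standalone sketch of the recurrence the test above checks step by step, with format-dependent time slicing; the weight arrays here are random placeholders rather than the layer's actual parameters:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;

public class SimpleRnnRecurrenceSketch {
    public static void main(String[] args) {
        int m = 3, nIn = 5, layerSize = 6, tsLength = 7;
        boolean ncw = true; // flip to false for the NWC layout

        INDArray in = ncw ? Nd4j.rand(DataType.FLOAT, m, nIn, tsLength)
                          : Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
        INDArray w = Nd4j.rand(DataType.FLOAT, nIn, layerSize);        // input weights (placeholder)
        INDArray rw = Nd4j.rand(DataType.FLOAT, layerSize, layerSize); // recurrent weights (placeholder)
        INDArray b = Nd4j.rand(DataType.FLOAT, 1, layerSize);          // bias (placeholder)

        INDArray outLast = null;
        for (int i = 0; i < tsLength; i++) {
            // Time slice: dimension 2 for NCW, dimension 1 for NWC; either way a [m, nIn] matrix
            INDArray x = ncw ? in.get(all(), all(), point(i)) : in.get(all(), point(i), all());
            INDArray out = x.mmul(w).addiRowVector(b);
            if (outLast != null) {
                out.addi(outLast.mmul(rw)); // out_t = tanh(x_t*W + out_{t-1}*R + b)
            }
            Transforms.tanh(out, false);
            outLast = out;
        }
        System.out.println(java.util.Arrays.toString(outLast.shape())); // [3, 6]
    }
}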
View File

@ -4,14 +4,21 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -22,8 +29,18 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class TestTimeDistributed extends BaseDL4JTest { public class TestTimeDistributed extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestTimeDistributed(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test @Test
public void testTimeDistributed(){ public void testTimeDistributed(){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) { for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
@ -34,11 +51,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
.seed(12345) .seed(12345)
.updater(new Adam(0.1)) .updater(new Adam(0.1))
.list() .list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build()) .layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build()) .layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build())
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX) .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3)) .setInputType(InputType.recurrent(3, rnnDataFormat))
.build(); .build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
@ -47,11 +64,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
.seed(12345) .seed(12345)
.updater(new Adam(0.1)) .updater(new Adam(0.1))
.list() .list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build()) .layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), 2)) .layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), rnnDataFormat))
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX) .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3)) .setInputType(InputType.recurrent(3, rnnDataFormat))
.build(); .build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
@ -62,13 +79,21 @@ public class TestTimeDistributed extends BaseDL4JTest {
for( int mb : new int[]{1, 5}) { for( int mb : new int[]{1, 5}) {
for(char inLabelOrder : new char[]{'c', 'f'}) { for(char inLabelOrder : new char[]{'c', 'f'}) {
INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder); INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder);
if (rnnDataFormat == RNNFormat.NWC){
in = in.permute(0, 2, 1);
}
INDArray out1 = net1.output(in); INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in); INDArray out2 = net2.output(in);
assertEquals(out1, out2); assertEquals(out1, out2);
INDArray labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder); INDArray labels ;
if (rnnDataFormat == RNNFormat.NCW) {
labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
}else{
labels = TestUtils.randomOneHotTimeSeries(mb, 5, 3).dup(inLabelOrder);
}
DataSet ds = new DataSet(in, labels); DataSet ds = new DataSet(in, labels);
net1.fit(ds); net1.fit(ds);
@ -85,4 +110,73 @@ public class TestTimeDistributed extends BaseDL4JTest {
} }
} }
} }
@Test
public void testTimeDistributedDense(){
for( int rnnType=0; rnnType<3; rnnType++ ) {
for( int ffType=0; ffType<3; ffType++ ) {
Layer l0, l2;
switch (rnnType) {
case 0:
l0 = new LSTM.Builder().nOut(5).build();
l2 = new LSTM.Builder().nOut(5).build();
break;
case 1:
l0 = new SimpleRnn.Builder().nOut(5).build();
l2 = new SimpleRnn.Builder().nOut(5).build();
break;
case 2:
l0 = new Bidirectional(new LSTM.Builder().nOut(5).build());
l2 = new Bidirectional(new LSTM.Builder().nOut(5).build());
break;
default:
throw new RuntimeException("Not implemented: " + rnnType);
}
Layer l1;
switch (ffType){
case 0:
l1 = new DenseLayer.Builder().nOut(5).build();
break;
case 1:
l1 = new VariationalAutoencoder.Builder().nOut(5).encoderLayerSizes(5).decoderLayerSizes(5).build();
break;
case 2:
l1 = new AutoEncoder.Builder().nOut(5).build();
break;
default:
throw new RuntimeException("Not implemented: " + ffType);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.list()
.layer(l0)
.layer(l1)
.layer(l2)
.setInputType(InputType.recurrent(5, 9, rnnDataFormat))
.build();
BaseRecurrentLayer l0a;
BaseRecurrentLayer l2a;
if (rnnType < 2) {
l0a = (BaseRecurrentLayer) l0;
l2a = (BaseRecurrentLayer) l2;
} else {
l0a = (BaseRecurrentLayer) ((Bidirectional) l0).getFwd();
l2a = (BaseRecurrentLayer) ((Bidirectional) l2).getFwd();
}
assertEquals(rnnDataFormat, l0a.getRnnDataFormat());
assertEquals(rnnDataFormat, l2a.getRnnDataFormat());
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, rnnDataFormat == RNNFormat.NCW ? new long[]{2, 5, 9} : new long[]{2, 9, 5} );
net.output(in);
}
}
}
} }
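A condensed sketch of the TimeDistributed configuration being compared above, parameterized on the data format; it mirrors the second network from the test and is illustrative only:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TimeDistributedConfigSketch {
    static MultiLayerConfiguration build(RNNFormat format) {
        return new NeuralNetConfiguration.Builder()
                .list()
                .layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(format).build())
                // TimeDistributed applies the wrapped feed-forward layer independently at every
                // time step, and now takes the RNN data format instead of a hard-coded time dimension
                .layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3)
                        .activation(Activation.TANH).build(), format))
                .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
                        .dataFormat(format).lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .setInputType(InputType.recurrent(3, format))
                .build();
    }
}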

View File

@ -771,7 +771,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
.build(); .build();
fail("Exception expected"); fail("Exception expected");
} catch (IllegalStateException e){ } catch (IllegalStateException e){
// e.printStackTrace(); log.info(e.toString());
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig")); assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
} }
} }

View File

@ -408,8 +408,8 @@ public class TestVariableLengthTS extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(), INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j)); NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) { for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0); assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0); assertEquals(0.0, outRow2.getDouble(k), 0.0);
} }
} }
} }
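These assertion rewrites (here and in the updater and regression tests below) only reorder arguments: JUnit's assertEquals expects (expected, actual), so a swapped call still passes or fails the same way but reports the values the wrong way round on failure. A minimal illustration, with a hypothetical value under test:

import static org.junit.Assert.assertEquals;

public class AssertOrderSketch {
    public static void main(String[] args) {
        double actual = 0.0; // stand-in for a value produced by the code under test
        // Correct order: expected value first, actual value second, then the delta for doubles
        assertEquals(0.0, actual, 0.0);
    }
}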

View File

@ -336,7 +336,7 @@ public class TestUpdaters extends BaseDL4JTest {
actualM[i] = Math.round(actualM[i] * 1e2) / 1e2; actualM[i] = Math.round(actualM[i] * 1e2) / 1e2;
} }
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(actualM, expectedM), true); assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(expectedM, actualM), true);
} }

View File

@ -159,14 +159,14 @@ public class RegressionTest050 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer(); SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride()); assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding()); assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType()); assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer(); OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString()); assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,14 +162,14 @@ public class RegressionTest060 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer(); SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride()); assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding()); assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType()); assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer(); OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString()); assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,7 +162,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same); assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer(); SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());

View File

@ -176,14 +176,14 @@ public class RegressionTest080 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same); assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer(); SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride()); assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding()); assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType()); assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Same); assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer(); OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertTrue(l2.getActivationFn() instanceof ActivationSigmoid); assertTrue(l2.getActivationFn() instanceof ActivationSigmoid);

View File

@ -178,8 +178,6 @@ public abstract class BaseCudnnHelper {
} }
} }
protected static final int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
protected final DataType nd4jDataType; protected final DataType nd4jDataType;
protected final int dataType; protected final int dataType;
protected final int dataTypeSize; protected final int dataTypeSize;

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -22,6 +23,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import com.jakewharton.byteunits.BinaryByteUnit; import com.jakewharton.byteunits.BinaryByteUnit;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo;
@ -86,7 +88,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
} }
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(), private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct(); biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct();
private cudnnFilterStruct filterDesc = new cudnnFilterStruct(); private cudnnFilterStruct filterDesc = new cudnnFilterStruct();
private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct(); private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct();
private cudnnActivationStruct activationDesc = new cudnnActivationStruct(); private cudnnActivationStruct activationDesc = new cudnnActivationStruct();
@ -138,7 +140,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel, public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel,
int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn, int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo,
ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations); however, I have been unable to get it working
// correctly on NHWC data, even after updating all descriptors, tensor format, etc.
//Therefore: all computation here is done in NCHW format only (see the permute sketch after this method)
//In a future (next?) release we'll likely switch to C++ for cuDNN support
boolean origNHWC = false;
if(format == CNN2DFormat.NHWC){
input = input.permute(0,3,1,2); //NHWC to NCHW
delta = delta.permute(0,3,1,2);
origNHWC = true;
}
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
int code; int code;
val miniBatch = input.size(0); val miniBatch = input.size(0);
@ -147,7 +163,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
val kH = weights.size(2); val kH = weights.size(2);
val kW = weights.size(3); val kW = weights.size(3);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null); CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
input = args.getInput(); input = args.getInput();
val inH = input.size(2); val inH = input.size(2);
val inW = input.size(3); val inW = input.size(3);
@ -176,7 +192,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]); (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]);
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0], code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
dilation[1], CUDNN_CROSS_CORRELATION, dataType); dilation[1], CUDNN_CROSS_CORRELATION, dataType);
checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW); code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
@ -238,16 +254,16 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
} }
} else { } else {
code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, : CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
0, algo1); 0, algo1);
checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc, code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, : CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
0, algo2); 0, algo2);
checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
} }
@ -263,7 +279,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
Allocator allocator = AtomicAllocator.getInstance(); Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView, CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView,
biasGradView, delta, epsNext); biasGradView, delta, epsNext);
Pointer srcData = allocator.getPointer(input, context); Pointer srcData = allocator.getPointer(input, context);
Pointer filterData = allocator.getPointer(weights, context); Pointer filterData = allocator.getPointer(weights, context);
Pointer filterGradData = allocator.getPointer(weightGradView, context); Pointer filterGradData = allocator.getPointer(weightGradView, context);
@ -279,14 +295,14 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc, code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0], cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0],
sizeInBytes); sizeInBytes);
checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
long sizeInBytes1 = sizeInBytes.get(0); long sizeInBytes1 = sizeInBytes.get(0);
code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc, code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0], cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0],
sizeInBytes); sizeInBytes);
checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
@ -313,21 +329,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta, code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta,
cudnnContext.biasTensorDesc, biasGradData); cudnnContext.biasTensorDesc, biasGradData);
checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData, code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace, cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace,
workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData); workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData);
checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData, code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData,
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace, cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace,
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView, allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView,
delta, epsNext); delta, epsNext);
Gradient retGradient = new DefaultGradient(); Gradient retGradient = new DefaultGradient();
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView); retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView);
@ -344,12 +360,30 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0))); interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0)));
} }
if(origNHWC){
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
}
return new Pair<>(retGradient, epsNext); return new Pair<>(retGradient, epsNext);
} }
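A small standalone sketch (array sizes are arbitrary) of the permute-and-restore workaround described in the comments above: NHWC activations are rearranged to NCHW before the cuDNN calls, and the result is permuted back for the caller.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class NhwcToNchwSketch {
    public static void main(String[] args) {
        // NHWC activations: [minibatch, height, width, channels]
        INDArray nhwc = Nd4j.rand(2, 8, 8, 3);
        // NHWC -> NCHW ([minibatch, channels, height, width]) before running the NCHW-only cuDNN path
        INDArray nchw = nhwc.permute(0, 3, 1, 2);
        // ... NCHW computation would happen here ...
        // NCHW -> NHWC so the caller gets back the layout it passed in
        INDArray restored = nchw.permute(0, 2, 3, 1);
        System.out.println(java.util.Arrays.toString(restored.shape())); // [2, 8, 8, 3]
    }
}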
@Override @Override
public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad, public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format,
LayerWorkspaceMgr workspaceMgr) {
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations); however, I have been unable to get it working
// correctly on NHWC data, even after updating all descriptors, tensor format, etc.
//Therefore: all computation here is done in NCHW format only
//In a future (next?) release we'll likely switch to C++ for cuDNN support
boolean origNHWC = false;
if(format == CNN2DFormat.NHWC){
input = input.permute(0,3,1,2); //NHWC to NCHW
origNHWC = true;
}
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
int code; int code;
val miniBatch = input.size(0); val miniBatch = input.size(0);
@ -358,7 +392,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
val kH = weights.size(2); val kH = weights.size(2);
val kW = weights.size(3); val kW = weights.size(3);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null); CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
input = args.getInput(); input = args.getInput();
val inH = input.size(2); val inH = input.size(2);
val inW = input.size(3); val inW = input.size(3);
@ -378,7 +412,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0], code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
dilation[1], CUDNN_CROSS_CORRELATION, dataType); dilation[1], CUDNN_CROSS_CORRELATION, dataType);
checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
@ -460,8 +494,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc, code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0], cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0],
sizeInBytes); sizeInBytes);
checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
@ -482,8 +516,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace); workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
} }
code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData, code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace, cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace,
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
@ -491,7 +525,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha, code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha,
cudnnContext.dstTensorDesc, dstData); cudnnContext.dstTensorDesc, dstData);
checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
allocator.registerAction(context, z, input, weights, bias); allocator.registerAction(context, z, input, weights, bias);
@ -499,6 +533,10 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
if (CudaEnvironment.getInstance().getConfiguration().isDebug()) if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream(); context.syncOldStream();
if(origNHWC){
z = z.permute(0,2,3,1); //NCHW to NHWC
}
return z; return z;
} }
@ -552,29 +590,29 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
break; break;
case "sigmoid": case "sigmoid":
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID, checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID,
CUDNN_PROPAGATE_NAN, 0)); CUDNN_PROPAGATE_NAN, 0));
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break; break;
case "relu": case "relu":
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU, checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU,
CUDNN_PROPAGATE_NAN, 0)); CUDNN_PROPAGATE_NAN, 0));
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break; break;
case "tanh": case "tanh":
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH, checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH,
CUDNN_PROPAGATE_NAN, 0)); CUDNN_PROPAGATE_NAN, 0));
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break; break;
case "softmax": case "softmax":
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha, checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break; break;
case "logsoftmax": case "logsoftmax":
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha, checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break; break;
default: default:
activation = null; activation = null;
@ -593,7 +631,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
* @return * @return
*/ */
public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation, public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation,
ConvolutionMode convolutionMode, PoolingType poolingType){ ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){
INDArray origInput = input; INDArray origInput = input;
//Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have has issues if strides //Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have has issues if strides
@ -602,16 +640,19 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
input = input.dup('c'); input = input.dup('c');
} }
boolean nchw = format == CNN2DFormat.NCHW;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val inH = input.size(2); val inH = input.size(hIdx);
val inW = input.size(3); val inW = input.size(wIdx);
boolean manualPadBottom = false; boolean manualPadBottom = false;
boolean manualPadRight = false; boolean manualPadRight = false;
int[] outSize; int[] outSize;
if (convolutionMode == ConvolutionMode.Same) { if (convolutionMode == ConvolutionMode.Same) {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation); padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation); int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
if(!Arrays.equals(padding, padBottomRight)){ if(!Arrays.equals(padding, padBottomRight)){
@ -626,9 +667,17 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
manualPadRight = (padding[1] != padBottomRight[1]); manualPadRight = (padding[1] != padBottomRight[1]);
//NCHW format //NCHW format
val newShape = new long[]{input.size(0), input.size(1), long[] newShape;
input.size(2) + (manualPadBottom ? 1 : 0), if(nchw){
input.size(3) + (manualPadRight ? 1 : 0)}; newShape = new long[]{input.size(0), input.size(1),
input.size(2) + (manualPadBottom ? 1 : 0),
input.size(3) + (manualPadRight ? 1 : 0)};
} else {
newShape = new long[]{input.size(0),
input.size(1) + (manualPadBottom ? 1 : 0),
input.size(2) + (manualPadRight ? 1 : 0),
input.size(3)};
}
INDArray newInput; INDArray newInput;
if(poolingType == null || poolingType != PoolingType.MAX){ if(poolingType == null || poolingType != PoolingType.MAX){
newInput = Nd4j.create(input.dataType(), newShape); newInput = Nd4j.create(input.dataType(), newShape);
@ -638,15 +687,22 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
// if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value // if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value
newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType()); newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType());
} }
newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)),
interval(0, input.size(3))}, input); if(nchw){
newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)),
interval(0, input.size(3))}, input);
} else {
newInput.put(new INDArrayIndex[]{all(), interval(0,input.size(1)),
interval(0, input.size(2)), all()}, input);
}
input = newInput; input = newInput;
//Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we //Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we
// now have the same amount of padding required for top/bottom, and left/right - which we'll let // now have the same amount of padding required for top/bottom, and left/right - which we'll let
// CuDNN handle // CuDNN handle
} }
} else { } else {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation); //Also performs validation outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
} }
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize); return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.subsampling;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.DefaultGradient;
@ -114,23 +115,29 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
@Override @Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides,
int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode,
int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
if(dilation[0] != 1 || dilation[1] != 1){ if(dilation[0] != 1 || dilation[1] != 1){
//CuDNN doesn't support dilated subsampling //CuDNN doesn't support dilated subsampling
return null; return null;
} }
boolean nchw = format == CNN2DFormat.NCHW;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
//We require the output as one of the arguments for backprop here //We require the output as one of the arguments for backprop here
//TODO we could add cache mode support here somehow... //TODO we could add cache mode support here somehow...
INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, workspaceMgr); INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr);
val miniBatch = input.size(0); val miniBatch = input.size(0);
val depth = input.size(1); val depth = input.size(chIdx);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType); CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
input = args.getInput(); input = args.getInput();
val inH = input.size(2); val inH = input.size(hIdx);
val inW = input.size(3); val inW = input.size(wIdx);
val srcStride = input.stride(); val srcStride = input.stride();
int[] outSize = args.getOutSize(); int[] outSize = args.getOutSize();
int outH = outSize[0]; int outH = outSize[0];
@ -160,23 +167,26 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
epsilon = epsilon.dup('c'); epsilon = epsilon.dup('c');
} }
input = input.dup();
val deltaStride = epsilon.stride(); val deltaStride = epsilon.stride();
if (Nd4j.getExecutioner() instanceof GridExecutioner) if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW,
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3])); (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0], checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
kernel[1], pad[0], pad[1], strides[0], strides[1])); kernel[1], pad[0], pad[1], strides[0], strides[1]));
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c'); long[] outEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c');
val dstStride = outEpsilon.stride(); val dstStride = outEpsilon.stride();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3])); (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
Allocator allocator = AtomicAllocator.getInstance(); Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon); CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
@ -198,9 +208,16 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
//Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon //Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
// we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input. // we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
if(args.isManualPadBottom() || args.isManualPadRight()) { if(args.isManualPadBottom() || args.isManualPadRight()) {
outEpsilon = outEpsilon.get(all(), all(), if(nchw){
interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)), outEpsilon = outEpsilon.get(all(), all(),
interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0))); interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0)));
} else {
outEpsilon = outEpsilon.get(all(),
interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)),
all());
}
} }
return new Pair<>(retGradient, outEpsilon); return new Pair<>(retGradient, outEpsilon);
@ -209,19 +226,24 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
@Override @Override
public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad,
PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
if(dilation[0] != 1 || dilation[1] != 1){ if(dilation[0] != 1 || dilation[1] != 1){
//CuDNN doesn't support dilated subsampling //CuDNN doesn't support dilated subsampling
return null; return null;
} }
val miniBatch = input.size(0); boolean nchw = format == CNN2DFormat.NCHW;
val inDepth = input.size(1); int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType); val miniBatch = input.size(0);
val inDepth = input.size(nchw ? 1 : 3);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
input = args.getInput(); input = args.getInput();
val inH = input.size(2); val inH = input.size(nchw ? 2 : 1);
val inW = input.size(3); val inW = input.size(nchw ? 3 : 2);
val srcStride = input.stride(); val srcStride = input.stride();
val outSize = args.getOutSize(); val outSize = args.getOutSize();
int outH = outSize[0]; int outH = outSize[0];
@ -246,13 +268,14 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0], checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
kernel[1], pad[0], pad[1], strides[0], strides[1])); kernel[1], pad[0], pad[1], strides[0], strides[1]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {(int) miniBatch, (int) inDepth, outH, outW}, 'c'); long[] outShape = nchw ? new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth};
INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c');
val dstStride = reduced.stride(); val dstStride = reduced.stride();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW,
(int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3])); (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
Allocator allocator = AtomicAllocator.getInstance(); Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(input, reduced); CudaContext context = allocator.getFlowController().prepareAction(input, reduced);

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.normalization;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper; import org.deeplearning4j.nn.layers.BaseCudnnHelper;
@ -124,12 +125,21 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
@Override @Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) { INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
this.eps = eps; this.eps = eps;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val miniBatch = (int) input.size(0); val miniBatch = (int) input.size(0);
val depth = (int) input.size(1); val depth = (int) input.size(chIdx);
val inH = (int) input.size(2); val inH = (int) input.size(hIdx);
val inW = (int) input.size(3); val inW = (int) input.size(wIdx);
final boolean isHalf = (input.dataType() == DataType.HALF); final boolean isHalf = (input.dataType() == DataType.HALF);
INDArray gammaOrig = null; INDArray gammaOrig = null;
@ -164,16 +174,17 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3])); (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c'); long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c');
val dstStride = ArrayUtil.toInts(nextEpsilon.stride()); val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3])); dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0], checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1)); (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance(); Allocator allocator = AtomicAllocator.getInstance();
@ -215,9 +226,15 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
@Override @Override
public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) { INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
this.eps = eps; this.eps = eps;
final boolean isHalf = (x.dataType() == DataType.HALF); final boolean isHalf = (x.dataType() == DataType.FLOAT16);
INDArray origGamma = gamma; INDArray origGamma = gamma;
INDArray origBeta = beta; INDArray origBeta = beta;
INDArray origMean = mean; INDArray origMean = mean;
@ -238,21 +255,22 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled" decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled"
val miniBatch = (int) x.size(0); val miniBatch = (int) x.size(0);
val inDepth = (int) x.size(1); val inDepth = (int) x.size(chIdx);
val inH = (int) x.size(2); val inH = (int) x.size(hIdx);
val inW = (int) x.size(3); val inW = (int) x.size(wIdx);
val srcStride = ArrayUtil.toInts(x.stride()); val srcStride = ArrayUtil.toInts(x.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
srcStride[0], srcStride[1], srcStride[2], srcStride[3])); srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx]));
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c'); long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth};
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c');
val dstStride = ArrayUtil.toInts(activations.stride()); val dstStride = ArrayUtil.toInts(activations.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3])); dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0], checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1)); (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance(); Allocator allocator = AtomicAllocator.getInstance();

View File

@ -16,74 +16,131 @@
package org.deeplearning4j; package org.deeplearning4j;
import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer;
import org.deeplearning4j.nn.layers.normalization.BatchNormalization;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import java.io.ByteArrayInputStream; import java.io.*;
import java.io.ByteArrayOutputStream; import java.lang.reflect.Field;
import java.io.IOException; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class TestUtils { public class TestUtils {
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){ public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
MultiLayerNetwork restored;
try { try {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true); ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray(); byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
assertEquals(net.params(), restored.params()); assertEquals(net.params(), restored.params());
return restored;
} catch (IOException e){ } catch (IOException e){
//Should never happen //Should never happen
throw new RuntimeException(e); throw new RuntimeException(e);
} }
//Also check the MultiLayerConfiguration is serializable (required by Spark etc)
MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
serializeDeserializeJava(conf);
return restored;
} }
public static ComputationGraph testModelSerialization(ComputationGraph net){ public static ComputationGraph testModelSerialization(ComputationGraph net){
ComputationGraph restored;
try { try {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true); ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray(); byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true); restored = ModelSerializer.restoreComputationGraph(bais, true);
assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.getConfiguration(), restored.getConfiguration());
assertEquals(net.params(), restored.params()); assertEquals(net.params(), restored.params());
return restored;
} catch (IOException e){ } catch (IOException e){
//Should never happen //Should never happen
throw new RuntimeException(e); throw new RuntimeException(e);
} }
//Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
ComputationGraphConfiguration conf = net.getConfiguration();
serializeDeserializeJava(conf);
return restored;
} }
public static INDArray randomOneHot(int examples, int nOut){ private static <T> T serializeDeserializeJava(T object){
byte[] bytes;
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
oos.writeObject(object);
oos.close();
bytes = baos.toByteArray();
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
T out;
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){
out = (T)ois.readObject();
} catch (IOException | ClassNotFoundException e){
throw new RuntimeException(e);
}
assertEquals(object, out);
return out;
}
public static INDArray randomOneHot(long examples, long nOut){
return randomOneHot(examples, nOut, new Random(12345)); return randomOneHot(examples, nOut, new Random(12345));
} }
public static INDArray randomOneHot(int examples, int nOut, long rngSeed){ public static INDArray randomOneHot(DataType dataType, long examples, long nOut){
return randomOneHot(dataType, examples, nOut, new Random(12345));
}
public static INDArray randomOneHot(long examples, long nOut, long rngSeed){
return randomOneHot(examples, nOut, new Random(rngSeed)); return randomOneHot(examples, nOut, new Random(rngSeed));
} }
public static INDArray randomOneHot(int examples, int nOut, Random rng){ public static INDArray randomOneHot(long examples, long nOut, Random rng) {
INDArray arr = Nd4j.create(examples, nOut); return randomOneHot(Nd4j.defaultFloatingPointType(), examples,nOut, rng);
}
public static INDArray randomOneHot(DataType dataType, long examples, long nOut, Random rng){
INDArray arr = Nd4j.create(dataType, examples, nOut);
for( int i=0; i<examples; i++ ){ for( int i=0; i<examples; i++ ){
arr.putScalar(i, rng.nextInt(nOut), 1.0); arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
} }
return arr; return arr;
} }
@ -115,4 +172,143 @@ public class TestUtils {
Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p)); Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p));
return ret; return ret;
} }
public static void writeStreamToFile(File out, InputStream is) throws IOException {
byte[] b = IOUtils.toByteArray(is);
try (OutputStream os = new BufferedOutputStream(new FileOutputStream(out))) {
os.write(b);
}
}
public static L1Regularization getL1Reg(List<Regularization> l){
for(Regularization r : l){
if(r instanceof L1Regularization){
return (L1Regularization) r;
}
}
return null;
}
public static L2Regularization getL2Reg(BaseLayer baseLayer){
return getL2Reg(baseLayer.getRegularization());
}
public static L2Regularization getL2Reg(List<Regularization> l){
for(Regularization r : l){
if(r instanceof L2Regularization){
return (L2Regularization) r;
}
}
return null;
}
public static WeightDecay getWeightDecayReg(BaseLayer bl){
return getWeightDecayReg(bl.getRegularization());
}
public static WeightDecay getWeightDecayReg(List<Regularization> l){
for(Regularization r : l){
if(r instanceof WeightDecay){
return (WeightDecay) r;
}
}
return null;
}
public static double getL1(BaseLayer layer) {
List<Regularization> l = layer.getRegularization();
return getL1(l);
}
public static double getL1(List<Regularization> l){
L1Regularization l1Reg = null;
for(Regularization reg : l){
if(reg instanceof L1Regularization)
l1Reg = (L1Regularization) reg;
}
assertNotNull(l1Reg);
return l1Reg.getL1().valueAt(0,0);
}
public static double getL2(BaseLayer layer) {
List<Regularization> l = layer.getRegularization();
return getL2(l);
}
public static double getL2(List<Regularization> l){
L2Regularization l2Reg = null;
for(Regularization reg : l){
if(reg instanceof L2Regularization)
l2Reg = (L2Regularization) reg;
}
assertNotNull(l2Reg);
return l2Reg.getL2().valueAt(0,0);
}
public static double getL1(AbstractSameDiffLayer layer){
return getL1(layer.getRegularization());
}
public static double getL2(AbstractSameDiffLayer layer){
return getL2(layer.getRegularization());
}
public static double getWeightDecay(BaseLayer layer) {
return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0);
}
public static void removeHelper(Layer layer) throws Exception {
removeHelpers(new Layer[]{layer});
}
public static void removeHelpers(Layer[] layers) throws Exception {
for(Layer l : layers){
if(l instanceof ConvolutionLayer){
Field f1 = ConvolutionLayer.class.getDeclaredField("helper");
f1.setAccessible(true);
f1.set(l, null);
} else if(l instanceof SubsamplingLayer){
Field f2 = SubsamplingLayer.class.getDeclaredField("helper");
f2.setAccessible(true);
f2.set(l, null);
} else if(l instanceof BatchNormalization) {
Field f3 = BatchNormalization.class.getDeclaredField("helper");
f3.setAccessible(true);
f3.set(l, null);
} else if(l instanceof LSTM){
Field f4 = LSTM.class.getDeclaredField("helper");
f4.setAccessible(true);
f4.set(l, null);
} else if(l instanceof LocalResponseNormalization){
Field f5 = LocalResponseNormalization.class.getDeclaredField("helper");
f5.setAccessible(true);
f5.set(l, null);
}
if(l.getHelper() != null){
throw new IllegalStateException("Did not remove helper for layer: " + l.getClass().getSimpleName());
}
}
}
public static void assertHelperPresent(Layer layer){
}
public static void assertHelpersPresent(Layer[] layers) throws Exception {
for(Layer l : layers){
//Don't use instanceof here - there are sub conv subclasses
if(l.getClass() == ConvolutionLayer.class || l instanceof SubsamplingLayer || l instanceof BatchNormalization || l instanceof LSTM){
Preconditions.checkNotNull(l.getHelper(), l.conf().getLayer().getLayerName());
}
}
}
public static void assertHelpersAbsent(Layer[] layers) throws Exception {
for(Layer l : layers){
Preconditions.checkState(l.getHelper() == null, l.conf().getLayer().getLayerName());
}
}
} }

View File

@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest {
ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false); ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
model = model.convertDataType(DataType.DOUBLE); model = model.convertDataType(DataType.DOUBLE);
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize}); INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC
CuDNNTestUtils.assertHelpersPresent(model.getLayers()); CuDNNTestUtils.assertHelpersPresent(model.getLayers());
Map<String,INDArray> withCudnn = model.feedForward(in, false); Map<String,INDArray> withCudnn = model.feedForward(in, false);

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.datasets.fetchers; package org.deeplearning4j.datasets.fetchers;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.base.EmnistFetcher; import org.deeplearning4j.base.EmnistFetcher;
@ -36,6 +37,7 @@ import java.util.Random;
* @author Alex Black * @author Alex Black
* *
*/ */
@Slf4j
public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher { public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher {
protected EmnistFetcher fetcher; protected EmnistFetcher fetcher;
@ -64,7 +66,7 @@ public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetche
try { try {
man = new MnistManager(images, labels, totalExamples); man = new MnistManager(images, labels, totalExamples);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
FileUtils.deleteDirectory(new File(EMNIST_ROOT)); FileUtils.deleteDirectory(new File(EMNIST_ROOT));
new EmnistFetcher(dataSet).downloadAndUntar(); new EmnistFetcher(dataSet).downloadAndUntar();
man = new MnistManager(images, labels, totalExamples); man = new MnistManager(images, labels, totalExamples);

View File

@ -687,9 +687,7 @@ public class BarnesHutTsne implements Model {
* @throws IOException * @throws IOException
*/ */
public void saveAsFile(List<String> labels, String path) throws IOException { public void saveAsFile(List<String> labels, String path) throws IOException {
BufferedWriter write = null; try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
try {
write = new BufferedWriter(new FileWriter(new File(path)));
for (int i = 0; i < Y.rows(); i++) { for (int i = 0; i < Y.rows(); i++) {
if (i >= labels.size()) if (i >= labels.size())
break; break;
@ -711,17 +709,11 @@ public class BarnesHutTsne implements Model {
} }
write.flush(); write.flush();
write.close();
} finally {
if (write != null)
write.close();
} }
} }
public void saveAsFile(String path) throws IOException { public void saveAsFile(String path) throws IOException {
BufferedWriter write = null; try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
try {
write = new BufferedWriter(new FileWriter(new File(path)));
for (int i = 0; i < Y.rows(); i++) { for (int i = 0; i < Y.rows(); i++) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
INDArray wordVector = Y.getRow(i); INDArray wordVector = Y.getRow(i);
@ -734,10 +726,6 @@ public class BarnesHutTsne implements Model {
write.write(sb.toString()); write.write(sb.toString());
} }
write.flush(); write.flush();
write.close();
} finally {
if (write != null)
write.close();
} }
} }
/** /**

View File

@ -113,6 +113,12 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-python</artifactId>
<version>${datavec.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -60,7 +60,7 @@ public class Hdf5Archive implements Closeable {
/* This is necessary for the call to the BytePointer constructor below. */ /* This is necessary for the call to the BytePointer constructor below. */
Loader.load(org.bytedeco.hdf5.global.hdf5.class); Loader.load(org.bytedeco.hdf5.global.hdf5.class);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
} }
} }

View File

@ -19,6 +19,9 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -121,27 +124,29 @@ public class KerasInput extends KerasLayer {
InputType myInputType; InputType myInputType;
switch (this.inputShape.length) { switch (this.inputShape.length) {
case 1: case 1:
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]); myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
break; break;
case 2: case 2:
if(this.dimOrder != null) { if(this.dimOrder != null) {
System.out.println("Dim order: " + this.dimOrder);
System.out.println("Input shape: " + ArrayUtils.toString(this.inputShape));
switch (this.dimOrder) { switch (this.dimOrder) {
case TENSORFLOW: //NWC == channels_last case TENSORFLOW: //NWC == channels_last
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break; break;
case THEANO: //NCW == channels_first case THEANO: //NCW == channels_first
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
break; break;
case NONE: case NONE:
//Assume RNN in [mb, seqLen, size] format //Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break; break;
default: default:
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder); throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
} }
} else { } else {
//Assume RNN in [mb, seqLen, size] format //Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
} }
break; break;
@ -150,17 +155,17 @@ public class KerasInput extends KerasLayer {
case TENSORFLOW: case TENSORFLOW:
/* TensorFlow convolutional input: # rows, # cols, # channels */ /* TensorFlow convolutional input: # rows, # cols, # channels */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1], myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
this.inputShape[2]); this.inputShape[2], CNN2DFormat.NHWC);
break; break;
case THEANO: case THEANO:
/* Theano convolutional input: # channels, # rows, # cols */ /* Theano convolutional input: # channels, # rows, # cols */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2], myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]); this.inputShape[0], CNN2DFormat.NCHW);
break; break;
default: default:
this.dimOrder = DimOrder.THEANO; this.dimOrder = DimOrder.THEANO;
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2], myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]); this.inputShape[0], CNN2DFormat.NCHW);
log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " + log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
"versions may come without specified backend, in which case we assume the model was " + "versions may come without specified backend, in which case we assume the model was " +
"built with theano." ); "built with theano." );

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
@ -65,6 +66,9 @@ public class TFOpLayer extends Layer {
long[] shape = inputType.getShape(true); long[] shape = inputType.getShape(true);
TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null); TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
long[] outputShape = tempLayer.getOutputShape(shape); long[] outputShape = tempLayer.getOutputShape(shape);
if (outputShape.length == 3){
return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC);
}
return InputType.inferInputType(Nd4j.create(outputShape)); return InputType.inferInputType(Nd4j.create(outputShape));
} }

View File

@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
} }
private INDArray runGraph(INDArray input){ private INDArray runGraph(INDArray input){
if (input.rank() == 3){
// TODO make this a preprocessor
input = input.permute(0, 2, 1);
}
Map<String, INDArray> inputMap = new HashMap<>(); Map<String, INDArray> inputMap = new HashMap<>();
inputMap.put(inputNames.get(0), input); inputMap.put(inputNames.get(0), input);
INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0]; INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
if (out.rank() == 3){
out = out.permute(0, 2, 1); // TODO post-processing?
}
return out; return out;
} }

View File

@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
@ -94,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution {
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion); enforceTrainingConfig, conf, kerasMajorVersion);
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName) Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
@ -103,7 +103,7 @@ public class KerasConvolution1D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
.hasBias(hasBias) .hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]); .stride(getStrideFromConfig(layerConfig, 1, conf)[0]).rnnDataFormat(dimOrder == DimOrder.TENSORFLOW? RNNFormat.NWC: RNNFormat.NCW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
if (hasBias) if (hasBias)
builder.biasInit(0.0); builder.biasInit(0.0);
@ -160,8 +160,8 @@ public class KerasConvolution1D extends KerasConvolution {
public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException { public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1) if (inputType.length > 1)
throw new InvalidKerasConfigurationException( throw new InvalidKerasConfigurationException(
"Keras LSTM layer accepts only one input (received " + inputType.length + ")"); "Keras Conv1D layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], RNNFormat.NCW,layerName);
} }

View File

@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
@ -27,6 +28,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import oshi.jna.platform.windows.PowrProf;
import java.util.Map; import java.util.Map;
@ -93,6 +95,7 @@ public class KerasConvolution2D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig( LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion); layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
System.out.println("----" + dimOrder);
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName) ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf)) .activation(getIActivationFromConfig(layerConfig, conf))
@ -101,7 +104,8 @@ public class KerasConvolution2D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
.hasBias(hasBias) .hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 2, conf)); .stride(getStrideFromConfig(layerConfig, 2, conf))
.dataFormat((dimOrder==DimOrder.TENSORFLOW)? CNN2DFormat.NHWC:CNN2DFormat.NCHW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion);
if (hasBias) if (hasBias)
builder.biasInit(0.0); builder.biasInit(0.0);

View File

@ -360,8 +360,19 @@ public class KerasConvolutionUtils {
} }
} else if (dimension == 1) { } else if (dimension == 1) {
int paddingInt = (int) innerConfig.get(layerField); Object paddingObj = innerConfig.get(layerField);
padding = new int[]{paddingInt, paddingInt}; if (paddingObj instanceof List){
List<Integer> paddingList = (List)paddingObj;
padding = new int[]{
paddingList.get(0),
paddingList.get(1)
};
}
else{
int paddingInt = (int) innerConfig.get(layerField);
padding = new int[]{paddingInt, paddingInt};
}
} else { } else {
throw new UnsupportedKerasConfigurationException( throw new UnsupportedKerasConfigurationException(
"Keras padding layer not supported"); "Keras padding layer not supported");

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional; import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import java.util.Map; import java.util.Map;
@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer {
switch (this.getDimOrder()) { switch (this.getDimOrder()) {
case NONE: case NONE:
case THEANO: case THEANO:
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels()); preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW);
break; break;
case TENSORFLOW: case TENSORFLOW:
preprocessor = new TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC);
it.getChannels());
break; break;
default: default:
throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder()); throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder());
@ -111,7 +110,7 @@ public class KerasFlatten extends KerasLayer {
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten). // to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()}; val inputShape = new long[]{it.getSize()};
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false); preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null);
} }
return preprocessor; return preprocessor;
} }

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core; package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer {
super(layerConfig, enforceTrainingConfig); super(layerConfig, enforceTrainingConfig);
this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf)) this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf))
.dataFormat(RNNFormat.NWC)
.name(this.layerName).build(); .name(this.layerName).build();
} }

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer {
} else { } else {
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]}; targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
} }
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW);
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2) } else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
if (inputShape[0] != targetShape[0]) preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC);
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
} }
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) { } else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer {
} else { } else {
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] }; targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
} }
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
} else { } else {
if (inputShape[0] != targetShape[0]) if (inputShape[0] != targetShape[0])
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] }; targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
} }
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) { } else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0]; InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()}; val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false); preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) { } else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()}; val inputShape = new long[]{it.getSize()};
if (targetShape.length == 3) { if (targetShape.length == 3) {
targetShape = targetShapeForDimOrder(inputShape, targetShape); targetShape = targetShapeForDimOrder(inputShape, targetShape);
} }
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false); preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
} }
return preprocessor; return preprocessor;
} }
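
A minimal sketch of the four-argument ReshapePreprocessor form used throughout this hunk; the ReshapeFormatSketch class name and the shapes are illustrative, and passing null for the format keeps the previous default interpretation:

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;

public class ReshapeFormatSketch {
    public static void main(String[] args) {
        long[] inputShape = {784};   // no leading minibatch dimension (hasMiniBatchDimension = false)
        // Channels-last (TensorFlow-style) reshape to [height, width, channels]
        ReshapePreprocessor nhwc = new ReshapePreprocessor(inputShape, new long[]{28, 28, 1}, false, CNN2DFormat.NHWC);
        // A null format falls back to the pre-existing channels-first defaults
        ReshapePreprocessor legacy = new ReshapePreprocessor(inputShape, new long[]{1, 28, 28}, false, null);
        System.out.println(nhwc + " / " + legacy);
    }
}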

View File

@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer; import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer {
.biasInit(0.0) .biasInit(0.0)
.l1(this.weightL1Regularization) .l1(this.weightL1Regularization)
.l2(this.weightL2Regularization) .l2(this.weightL2Regularization)
.outputDataFormat(RNNFormat.NWC)
.hasBias(false); .hasBias(false);
if (embeddingConstraint != null) if (embeddingConstraint != null)
builder.constrainWeights(embeddingConstraint); builder.constrainWeights(embeddingConstraint);
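
A standalone sketch of an embedding layer configured the way the import above now builds it, so the output matches the Keras [minibatch, sequenceLength, embeddingDim] layout; the EmbeddingFormatSketch class name, vocabulary size, and embedding dimension are illustrative:

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;

public class EmbeddingFormatSketch {
    public static void main(String[] args) {
        EmbeddingSequenceLayer embedding = new EmbeddingSequenceLayer.Builder()
                .nIn(10_000)                      // vocabulary size (illustrative)
                .nOut(128)                        // embedding dimension (illustrative)
                .outputDataFormat(RNNFormat.NWC)  // time-last output, matching Keras
                .hasBias(false)
                .build();
        System.out.println(embedding);
    }
}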

View File

@ -22,11 +22,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
@ -37,6 +35,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -187,7 +186,7 @@ public class KerasLSTM extends KerasLayer {
.weightInitRecurrent(recurrentInit) .weightInitRecurrent(recurrentInit)
.biasInit(0.0) // TODO: this is incorrect .biasInit(0.0) // TODO: this is incorrect
.l1(this.weightL1Regularization) .l1(this.weightL1Regularization)
.l2(this.weightL2Regularization); .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null) if(nIn != null)
builder.setNIn(nIn); builder.setNIn(nIn);
@ -266,7 +265,8 @@ public class KerasLSTM extends KerasLayer {
throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one single input" + throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one single input" +
"or three (input to LSTM and two states tensors, but " + "or three (input to LSTM and two states tensors, but " +
"received " + inputType.length + "."); "received " + inputType.length + ".");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f,layerName);
} }
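
The .dataFormat(RNNFormat.NWC) call added to the builder here (and to KerasSimpleRnn below) marks imported recurrent layers as time-last rather than DL4J's NCW default; a minimal sketch with illustrative layer sizes and a hypothetical LstmFormatSketch wrapper class:

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.LSTM;

public class LstmFormatSketch {
    public static void main(String[] args) {
        // Keras recurrent activations are [minibatch, timeSteps, features], i.e. NWC
        LSTM lstm = new LSTM.Builder()
                .nIn(32)                    // input features (illustrative)
                .nOut(64)                   // LSTM units (illustrative)
                .dataFormat(RNNFormat.NWC)
                .build();
        System.out.println(lstm);
    }
}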
/** /**

View File

@ -21,7 +21,9 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.Layer;
@ -36,6 +38,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer; import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -155,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer {
.weightInitRecurrent(recurrentInit) .weightInitRecurrent(recurrentInit)
.biasInit(0.0) .biasInit(0.0)
.l1(this.weightL1Regularization) .l1(this.weightL1Regularization)
.l2(this.weightL2Regularization); .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null) if(nIn != null)
builder.setNIn(nIn); builder.setNIn(nIn);
@ -227,7 +230,8 @@ public class KerasSimpleRnn extends KerasLayer {
throw new InvalidKerasConfigurationException( throw new InvalidKerasConfigurationException(
"Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")"); "Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName);
} }
/** /**

View File

@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer {
break; break;
case "SimpleRNN": case "SimpleRNN":
kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers); kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer(); Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
this.layer = new Bidirectional(mode, rnnLayer); this.layer = new Bidirectional(mode, rnnLayer);
layer.setLayerName(layerName); layer.setLayerName(layerName);
break; break;
@ -218,7 +218,7 @@ public class KerasBidirectional extends KerasLayer {
if (inputType.length > 1) if (inputType.length > 1)
throw new InvalidKerasConfigurationException( throw new InvalidKerasConfigurationException(
"Keras Bidirectional layer accepts only one input (received " + inputType.length + ")"); "Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], ((Bidirectional)layer).getRNNDataFormat(), layerName);
} }
/** /**

View File

@ -21,6 +21,9 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
private final long[] inputShape; private final long[] inputShape;
private final long[] targetShape; private final long[] targetShape;
private boolean hasMiniBatchDimension; private boolean hasMiniBatchDimension;
private DataFormat format;
/**
* @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
*/
@Deprecated
public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
this(inputShape, targetShape, false);
}
/** /**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension * @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension * @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...] * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
*/ */
public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) {
this(inputShape, targetShape, hasMiniBatchDimension, null);
}
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
* @param dataFormat            May be null. If non-null: an RNNFormat is used to interpret 3d reshape targets and a CNN2DFormat to interpret 4d targets when inferring the output type; if null, the previous NCW/NCHW defaults apply
*/
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape, public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) { @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension,
@JsonProperty("dataFormat") DataFormat dataFormat) {
this.inputShape = inputShape; this.inputShape = inputShape;
this.targetShape = targetShape; this.targetShape = targetShape;
this.hasMiniBatchDimension = hasMiniBatchDimension; this.hasMiniBatchDimension = hasMiniBatchDimension;
this.format = dataFormat;
} }
private long[] getShape(long[] originalShape, long minibatch) { private long[] getShape(long[] originalShape, long minibatch) {
@ -140,13 +148,26 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
ret = InputType.feedForward(shape[1]); ret = InputType.feedForward(shape[1]);
break; break;
case 3: case 3:
ret = InputType.recurrent(shape[2], shape[1]); RNNFormat format = RNNFormat.NCW;
if(this.format != null && this.format instanceof RNNFormat)
format = (RNNFormat)this.format;
ret = InputType.recurrent(shape[2], shape[1], format);
break; break;
case 4: case 4:
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) { if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
ret = InputType.convolutional(shape[1], shape[2], shape[3]); ret = InputType.convolutional(shape[1], shape[2], shape[3]);
} else { } else {
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
CNN2DFormat cnnFormat = CNN2DFormat.NCHW;
if (this.format != null && this.format instanceof CNN2DFormat)
cnnFormat = (CNN2DFormat) this.format;
if (cnnFormat == CNN2DFormat.NCHW) {
ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
} else {
ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
}
} }
break; break;
default: default:
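
A minimal sketch of the format-aware output-type inference added above: with a 3d reshape target and an RNNFormat, the output is reported as NWC recurrent activations instead of the NCW default. The ReshapeOutputTypeSketch class name and the shapes are illustrative, chosen so input and target element counts match:

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;

public class ReshapeOutputTypeSketch {
    public static void main(String[] args) {
        // 200 flat features reshaped to [timeSteps=10, features=20], shapes given without minibatch dimension
        ReshapePreprocessor p = new ReshapePreprocessor(new long[]{200}, new long[]{10, 20}, false, RNNFormat.NWC);
        InputType out = p.getOutputType(InputType.feedForward(200));
        System.out.println(out);   // recurrent input type: size 20, 10 time steps, NWC
    }
}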

View File

@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;
/** /**
* Specialized CnnToFeedForwardInputPreProcessor for use with * @deprecated Exists only for backward compatibility of older pretrained models. Should not be used.
* Convolutional layers imported from Keras using the TensorFlow * Use {@link CnnToFeedForwardPreProcessor} for all new models instead.
* backend.
*
* @author dave@skymind.io
*/ */
@Slf4j @Slf4j @Deprecated
public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor { public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor {
@JsonCreator @JsonCreator @Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight, public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight,
@JsonProperty("inputWidth") long inputWidth, @JsonProperty("inputWidth") long inputWidth,
@JsonProperty("numChannels") long numChannels) { @JsonProperty("numChannels") long numChannels) {
super(inputHeight, inputWidth, numChannels); super(inputHeight, inputWidth, numChannels);
} }
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) { public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
super(inputHeight, inputWidth); super(inputHeight, inputWidth);
} }
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor() { public TensorFlowCnnToFeedForwardPreProcessor() {
super(); super();
} }

View File

@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
import org.deeplearning4j.nn.modelimport.keras.layers.core.*; import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding; import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
@ -319,6 +320,8 @@ public class KerasLayerUtils {
layer = new KerasELU(layerConfig, enforceTrainingConfig); layer = new KerasELU(layerConfig, enforceTrainingConfig);
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){ } else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig); layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
} else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){
layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
} else if (conf instanceof Keras2LayerConfiguration){ } else if (conf instanceof Keras2LayerConfiguration){
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf; Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){ if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){

View File

@ -1,50 +0,0 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.Arrays;
public class TFKerasTests extends BaseDL4JTest{
@Test
public void testModelWithTFOp1() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
}
@Test
public void testModelWithTFOp2() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
// dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed
long[] expectedShape = new long[]{12 * 2, 5};
Assert.assertArrayEquals(expectedShape, out.shape());
}
}

View File

@ -0,0 +1,147 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.apache.commons.io.FileUtils;
import org.datavec.python.keras.Model;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.common.tests.ResourceUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.List;
@RunWith(Parameterized.class)
public class TestTFKerasModelImport extends BaseDL4JTest{
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
private String modelFile;
@Override
public long getTimeoutMilliseconds(){
return 300000;
} // installing TF will take a while
@Parameterized.Parameters(name = "file={0}")
public static Object[] params() throws Exception {
List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false);
return paths.toArray(new String[0]);
}
public TestTFKerasModelImport(String modelFile){
this.modelFile = modelFile;
}
@Test
public void testModelImport() throws Exception{
testModelImportWithData(modelFile);
}
private void testModelImportWithData(String path) throws Exception{
System.out.println(path);
// TODO multi input/output
INDArray inputArray;
INDArray expectedOutputArray;
File f = Resources.asFile(path); // may be in a JAR, which HDF5 can't read from directly
File modelFile = new File(testDir.getRoot(), f.getName());
FileUtils.copyFile(f, modelFile);
synchronized (Hdf5Archive.LOCK_OBJECT){
Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath());
List<String> rootGroups = hdf5Archive.getGroups();
if (rootGroups.contains("data")){
String inputName = hdf5Archive.readAttributeAsString("input_names", "data");
String outputName = hdf5Archive.readAttributeAsString("output_names", "data");
inputArray = hdf5Archive.readDataSet(inputName, "data");
expectedOutputArray = hdf5Archive.readDataSet(outputName, "data");
}
else{
hdf5Archive.close();
return;
}
hdf5Archive.close();
}
INDArray outputArray;
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
outputArray = dl4jModel.outputSingle(inputArray);
expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT);
outputArray = outputArray.castTo(DataType.FLOAT);
if (path.contains("misc_")){
//shape relaxation
expectedOutputArray = expectedOutputArray.reshape( -1);
outputArray = outputArray.reshape(-1);
}
System.out.println(outputArray.toString());
System.out.println(expectedOutputArray.toString());
Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape());
Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3));
}
private void testModelImportWithKeras(String path) throws Exception{
Model kerasModel = new Model(path);
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays());
Assert.assertEquals(kerasModel.numOutputs(), dl4jModel.getNumOutputArrays());
INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()];
INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()];
for (int i = 0; i < kerasInputArrays.length; i ++) {
long[] shape = kerasModel.inputShapeAt(i);
for (int j = 0; j < shape.length; j++) {
if (shape[j] < 0) {
shape[j] = 1;
}
}
kerasInputArrays[i] = Nd4j.rand(shape);
}
INDArray[] kerasOut = kerasModel.predict(kerasInputArrays);
INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays);
Assert.assertEquals(kerasOut.length, dl4jOut.length);
for (int i = 0; i < kerasOut.length; i++){
INDArray kerasOutArr = kerasOut[i];
kerasOutArr = kerasOutArr.reshape(1, -1); // bit of relaxation on shape
kerasOutArr = kerasOutArr.castTo(DataType.DOUBLE);
Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST);
INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1);
System.out.println(kerasOutArr.shapeInfoToString());
System.out.println(dl4jOutArr.shapeInfoToString());
Assert.assertEquals(kerasOutArr, dl4jOutArr);
}
}
}

Some files were not shown because too many files have changed in this diff.