Merge pull request #8892 from KonduitAI/master

Merge latest development work
master
Alex Black 2020-04-28 20:38:56 +10:00 committed by GitHub
commit 1930d99908
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
475 changed files with 10073 additions and 2958 deletions

View File

@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
return 120_000L;
}
@Test
@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();

View File

@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 45000L;
return 120_000L;
}
@Test
@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));

View File

@ -18,6 +18,7 @@ package org.datavec.api.records.reader.impl;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.datavec.api.conf.Configuration;
@ -43,6 +44,7 @@ import java.util.*;
*
* @author Adam Gibson
*/
@Slf4j
public class LineRecordReader extends BaseRecordReader {
@ -58,6 +60,13 @@ public class LineRecordReader extends BaseRecordReader {
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
super.initialize(split);
if(!(inputSplit instanceof StringSplit || inputSplit instanceof InputStreamInputSplit)){
final ArrayList<URI> uris = new ArrayList<>();
final Iterator<URI> uriIterator = inputSplit.locationsIterator();
while(uriIterator.hasNext()) uris.add(uriIterator.next());
this.locations = uris.toArray(new URI[0]);
}
this.iter = getIterator(0);
this.initialized = true;
}
@ -66,7 +75,6 @@ public class LineRecordReader extends BaseRecordReader {
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
this.conf = conf;
initialize(split);
this.initialized = true;
}
@Override
@ -89,7 +97,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
if (iter.hasNext()) {
@ -120,7 +128,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return iter.hasNext();
@ -205,11 +213,6 @@ public class LineRecordReader extends BaseRecordReader {
}
}
} else {
final ArrayList<URI> uris = new ArrayList<>();
final Iterator<URI> uriIterator = inputSplit.locationsIterator();
while(uriIterator.hasNext()) uris.add(uriIterator.next());
this.locations = uris.toArray(new URI[uris.size()]);
if (locations.length > 0) {
InputStream inputStream = streamCreatorFn.apply(locations[location]);
try {

View File

@ -16,6 +16,7 @@
package org.datavec.api.split;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.util.files.UriFromPathIterator;
import org.datavec.api.writable.WritableType;
@ -34,6 +35,7 @@ import java.util.regex.Pattern;
* NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent
* the integer index.
*/
@Slf4j
public class NumberedFileInputSplit implements InputSplit {
private final String baseString;
private final int minIdx;
@ -93,7 +95,7 @@ public class NumberedFileInputSplit implements InputSplit {
try {
writeFile.createNewFile();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}

View File

@ -23,6 +23,7 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.regex.Pattern;
/**
* A simple utility method to convert a {@code Iterator<String>} to an {@code Iterator<URI>}, where each
@ -32,6 +33,7 @@ import java.util.NoSuchElementException;
*/
@AllArgsConstructor
public class UriFromPathIterator implements Iterator<URI> {
final Pattern schemaPattern = Pattern.compile("^.*?:/.*");
private final Iterator<String> paths;
@ -42,16 +44,17 @@ public class UriFromPathIterator implements Iterator<URI> {
@Override
public URI next() {
if (!hasNext()) {
throw new NoSuchElementException("No next element");
}
try {
String s = paths.next();
if(!s.matches(".*:/.*")){
if(schemaPattern.matcher(s).matches()){
return new URI(s);
} else {
//No scheme - assume file for backward compatibility
return new File(s).toURI();
} else {
return new URI(s);
}
} catch (URISyntaxException e) {

View File

@ -162,7 +162,6 @@ public class Text extends BinaryComparable implements WritableComparable<BinaryC
return -1; // not found
} catch (CharacterCodingException e) {
// can't get here
e.printStackTrace();
return -1;
}
}

View File

@ -17,6 +17,7 @@
package org.datavec.arrow.recordreader;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
@ -50,6 +51,7 @@ import static org.datavec.arrow.ArrowConverter.readFromBytes;
* @author Adam Gibson
*
*/
@Slf4j
public class ArrowRecordReader implements RecordReader {
private InputSplit split;
@ -132,7 +134,7 @@ public class ArrowRecordReader implements RecordReader {
currIdx++;
this.currentPath = url;
}catch(Exception e) {
e.printStackTrace();
log.error("",e);
}
}
@ -242,7 +244,7 @@ public class ArrowRecordReader implements RecordReader {
try {
currentBatch.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
}

View File

@ -16,6 +16,7 @@
package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
@ -56,7 +57,7 @@ import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@Slf4j
public class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@ -343,7 +344,7 @@ public class ArrowConverterTest extends BaseND4JTest {
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
byte[] arr = byteArrayOutputStream.toByteArray();

View File

@ -61,7 +61,7 @@ public class Wave implements Serializable {
initWaveWithInputStream(inputStream);
inputStream.close();
} catch (IOException e) {
e.printStackTrace();
System.out.println(e.toString());
}
}
@ -96,7 +96,7 @@ public class Wave implements Serializable {
data = new byte[inputStream.available()];
inputStream.read(data);
} catch (IOException e) {
e.printStackTrace();
System.err.println(e.toString());
}
// end load data
} else {

View File

@ -16,10 +16,12 @@
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.FileOutputStream;
import java.io.IOException;
@Slf4j
public class WaveFileManager {
private Wave wave;
@ -78,7 +80,7 @@ public class WaveFileManager {
fos.write(wave.getBytes());
fos.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -16,6 +16,8 @@
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException;
import java.io.InputStream;
@ -25,6 +27,7 @@ import java.io.InputStream;
*
* @author Jacquet Wong
*/
@Slf4j
public class WaveHeader {
public static final String RIFF_HEADER = "RIFF";
@ -109,7 +112,7 @@ public class WaveHeader {
// dis.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
return false;
}

View File

@ -17,6 +17,7 @@
package org.datavec.audio.fingerprint;
import lombok.extern.slf4j.Slf4j;
import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler;
@ -38,6 +39,7 @@ import java.util.List;
* @author jacquet
*
*/
@Slf4j
public class FingerprintManager {
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
@ -153,7 +155,7 @@ public class FingerprintManager {
fingerprint = getFingerprintFromInputStream(fis);
fis.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return fingerprint;
}
@ -170,7 +172,7 @@ public class FingerprintManager {
fingerprint = new byte[inputStream.available()];
inputStream.read(fingerprint);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return fingerprint;
}
@ -190,7 +192,7 @@ public class FingerprintManager {
fileOutputStream.write(fingerprint);
fileOutputStream.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -81,7 +81,7 @@ public abstract class BaseImageLoader implements Serializable {
String fileName = file.toString();
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|| fileName.endsWith(".zip"))
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath());
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath(), false);
} catch (IOException e) {
throw new IllegalStateException("Unable to fetch images", e);
}

View File

@ -186,7 +186,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
labels.add(line);
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
@ -300,7 +300,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
try {
FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
meanStdStored = true;
} else if (uMean == 0 && meanStdStored) {
@ -312,7 +312,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
vStd = Double.parseDouble(values[3]);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
for (int i = 0; i < result.numExamples(); i++) {
@ -356,12 +356,12 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
dataSets.add(new DataSet(asMatrix(matConversion.getSecond()), matConversion.getFirst()));
batchNumCount++;
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
break;
}
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
if(dataSets.size() == 0){

View File

@ -235,7 +235,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data);
} catch (IOException | InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
return recordReader;
}
@ -250,7 +250,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data);
} catch (IOException | InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
return recordReader;
}

View File

@ -225,7 +225,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
finishedInputStreamSplit = true;
return Arrays.<Writable>asList(ndArrayWritable);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
if (iter != null) {

View File

@ -16,6 +16,7 @@
package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
@ -53,6 +54,7 @@ import static org.junit.Assert.fail;
*
* @author saudet
*/
@Slf4j
public class TestNativeImageLoader {
static final long seed = 10;
static final Random rng = new Random(seed);
@ -123,7 +125,7 @@ public class TestNativeImageLoader {
try {
array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
fail();
}
assertEquals(5, array6.rank());
@ -156,7 +158,7 @@ public class TestNativeImageLoader {
try {
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
assertEquals(5, array8.rank());
assertEquals(pages2, array8.size(0));
@ -172,7 +174,7 @@ public class TestNativeImageLoader {
try {
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath());
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
fail();
}
assertEquals(5, array9.rank());

View File

@ -66,7 +66,7 @@ public class UimaTokenizer implements Tokenizer {
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
throw new RuntimeException(e);
}

View File

@ -17,6 +17,7 @@
package org.datavec.local.transforms.functions;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.function.Function;
@ -32,6 +33,7 @@ import java.util.List;
* sequence data into a {@code List<List<Writable>>}
* @author Alex Black
*/
@Slf4j
public class SequenceRecordReaderFunction
implements Function<Pair<String, InputStream>, List<List<Writable>>> {
protected SequenceRecordReader sequenceRecordReader;
@ -46,7 +48,7 @@ public class SequenceRecordReaderFunction
try (DataInputStream dis = (DataInputStream) value.getRight()) {
return sequenceRecordReader.sequenceRecord(uri, dis);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
throw new IllegalStateException("Something went wrong");

View File

@ -20,6 +20,7 @@ package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.cpython.global.python;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -343,6 +344,19 @@ public class PythonExecutioner {
if (path == null) {
log.info("Setting python default path");
File[] packages = numpy.cachePackages();
//// TODO: fix in javacpp
File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
File[] packages2 = new File[packages.length + 1];
for (int i = 0;i < packages.length; i++){
//System.out.println(packages[i].getAbsolutePath());
packages2[i] = packages[i];
}
packages2[packages.length] = sitePackagesWindows;
//System.out.println(sitePackagesWindows.getAbsolutePath());
packages = packages2;
//////////
Py_SetPath(packages);
} else {
log.info("Setting python path " + path);

View File

@ -0,0 +1,132 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
@Slf4j
public class PythonProcess {
private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
Process process = pb.start();
String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
process.waitFor();
return out;
}
public static void run(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
pb.inheritIO().start().waitFor();
}
public static void pipInstall(String packageName) throws PythonException{
try{
run("-m", "pip", "install", packageName);
}catch(Exception e){
throw new PythonException("Error installing package " + packageName, e);
}
}
public static void pipInstall(String packageName, String version) throws PythonException{
pipInstall(packageName + "==" + version);
}
public static void pipUninstall(String packageName) throws PythonException{
try{
run("-m", "pip", "uninstall", packageName);
}catch(Exception e){
throw new PythonException("Error uninstalling package " + packageName, e);
}
}
public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{
if (!gitRepoUrl.contains("://")){
gitRepoUrl = "git://" + gitRepoUrl;
}
try{
run("-m", "pip", "install", "git+", gitRepoUrl);
}catch(Exception e){
throw new PythonException("Error installing package from " + gitRepoUrl, e);
}
}
public static String getPackageVersion(String packageName) throws PythonException{
String out;
try{
out = runAndReturn("-m", "pip", "show", packageName);
} catch (Exception e){
throw new PythonException("Error finding version for package " + packageName, e);
}
if (!out.contains("Version: ")){
throw new PythonException("Can't find package " + packageName);
}
String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
return pkgVersion;
}
public static boolean isPackageInstalled(String packageName)throws PythonException{
try{
String out = runAndReturn("-m", "pip", "show", packageName);
return !out.isEmpty();
}catch (Exception e){
throw new PythonException("Error checking if package is installed: " +packageName, e);
}
}
public static void pipInstallFromRequirementsTxt(String path) throws PythonException{
try{
run("-m", "pip", "install","-r", path);
}catch (Exception e){
throw new PythonException("Error installing packages from " + path, e);
}
}
public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException{
try{
run(path, inplace?"develop":"install");
}catch (Exception e){
throw new PythonException("Error installing package from " + path, e);
}
}
}

View File

@ -0,0 +1,144 @@
package org.datavec.python.keras;
import org.datavec.python.Python;
import org.datavec.python.PythonException;
import org.datavec.python.PythonObject;
import org.datavec.python.PythonProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
public class Model {
private PythonObject pyModel;
private static PythonObject installAndImportTF() throws PythonException{
if (!PythonProcess.isPackageInstalled("tensorflow")){
PythonProcess.pipInstall("tensorflow");
}
return Python.importModule("tensorflow");
}
private static PythonObject getKerasModule() throws PythonException{
PythonObject tf = installAndImportTF();
PythonObject keras = tf.attr("keras");
tf.del();
return keras;
}
private static PythonObject loadModel(String s) throws PythonException{
PythonObject models = getKerasModule().attr("models");
PythonObject loadModelF = models.attr("load_model");
PythonObject model = loadModelF.call(s);
models.del();
loadModelF.del();
return model;
}
public Model(String path) throws PythonException{
pyModel = loadModel(path);
}
public INDArray[] predict(INDArray... inputs) throws PythonException{
PythonObject predictF = pyModel.attr("predict");
PythonObject inputList = new PythonObject(inputs);
PythonObject pyOut = predictF.call(inputList);
INDArray[] out;
if (Python.isinstance(pyOut, Python.listType())){
out = new INDArray[Python.len(pyOut).toInt()];
for(int i = 0; i < out.length; i++){
out[i] = pyOut.get(i).toNumpy().getNd4jArray();
}
}
else{
out = new INDArray[]{
pyOut.toNumpy().getNd4jArray()};
}
predictF.del();
inputList.del();
pyOut.del();
return out;
}
public int numInputs(){
PythonObject inputs = pyModel.attr("inputs");
PythonObject pyNumInputs = Python.len(inputs);
int ret = pyNumInputs.toInt();
inputs.del();
pyNumInputs.del();
return ret;
}
public int numOutputs(){
PythonObject outputs = pyModel.attr("outputs");
PythonObject pyNumOutputs = Python.len(outputs);
int ret = pyNumOutputs.toInt();
outputs.del();
pyNumOutputs.del();
return ret;
}
public long[][] inputShapes(){
long[][] ret = new long[numInputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = inputShapeAt(i);
}
return ret;
}
public long[][] outputShapes(){
long[][] ret = new long[numOutputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = outputShapeAt(i);
}
return ret;
}
public long[] inputShapeAt(int input){
PythonObject inputs = pyModel.attr("inputs");
PythonObject tensor = inputs.get(input);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
inputs.del();
return shape;
}
public long[] outputShapeAt(int output){
PythonObject inputs = pyModel.attr("outputs");
PythonObject tensor = inputs.get(output);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
inputs.del();
return shape;
}
}

View File

@ -74,7 +74,6 @@ public class DataVecTransformClient implements DataVecTransformService {
} catch (UnirestException e) {
log.error("Error in setCSVTransformProcess()", e);
e.printStackTrace();
}
}
@ -94,7 +93,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return TransformProcess.fromJson(s);
} catch (UnirestException e) {
log.error("Error in getCSVTransformProcess()",e);
e.printStackTrace();
}
return null;
@ -119,7 +117,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord;
} catch (UnirestException e) {
log.error("Error in transformIncremental(SingleCSVRecord)",e);
e.printStackTrace();
}
return null;
}
@ -140,8 +137,7 @@ public class DataVecTransformClient implements DataVecTransformService {
.getBody();
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e);
e.printStackTrace();
log.error("",e);
}
return null;
@ -162,7 +158,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e);
e.printStackTrace();
}
return null;
@ -181,7 +176,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1;
} catch (UnirestException e) {
log.error("Error in transformArray(BatchCSVRecord)",e);
e.printStackTrace();
}
return null;
@ -200,7 +194,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array;
} catch (UnirestException e) {
log.error("Error in transformArrayIncremental(SingleCSVRecord)",e);
e.printStackTrace();
}
return null;
@ -231,7 +224,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array;
} catch (UnirestException e) {
log.error("Error in transformSequenceArrayIncremental",e);
e.printStackTrace();
}
return null;
@ -252,7 +244,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1;
} catch (UnirestException e) {
log.error("Error in transformSequenceArray",e);
e.printStackTrace();
}
return null;
@ -274,7 +265,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transformSequence");
e.printStackTrace();
}
return null;
@ -295,7 +285,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord;
} catch (UnirestException e) {
log.error("Error in transformSequenceIncremental");
e.printStackTrace();
}
return null;
}

View File

@ -18,6 +18,7 @@ package org.datavec.spark.transform;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
@ -53,7 +54,7 @@ import static org.datavec.local.transforms.LocalTransformExecutor.executeToSeque
* @author Adan Gibson
*/
@AllArgsConstructor
@Slf4j
public class CSVSparkTransform {
@Getter
private TransformProcess transformProcess;
@ -252,7 +253,7 @@ public class CSVSparkTransform {
try {
return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return null;

View File

@ -88,7 +88,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -100,7 +100,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -112,7 +112,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest();
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -130,7 +130,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -142,7 +142,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -169,7 +169,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});

View File

@ -16,6 +16,7 @@
package org.datavec.spark;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.After;
@ -23,6 +24,7 @@ import org.junit.Before;
import java.io.Serializable;
@Slf4j
public abstract class BaseSparkTest implements Serializable {
protected static JavaSparkContext sc;
@ -40,7 +42,7 @@ public abstract class BaseSparkTest implements Serializable {
try {
Thread.sleep(100L);
} catch (InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
} else {
break;

View File

@ -68,7 +68,7 @@ public abstract class BaseDL4JTest {
* Override this method to set the default timeout for methods in the test class
*/
public long getTimeoutMilliseconds(){
return 30000;
return 60_000;
}
/**

View File

@ -43,7 +43,8 @@ import java.util.concurrent.atomic.AtomicLong;
*/
@Slf4j
public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable {
private static final String ROUTE_IS_DOWN = "Info posted to RemoteUIStatsStorageRouter but router is shut down.";
private static final String MAX_WARNINGS_REACHED = "RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.";
/**
* Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar
*/
@ -163,10 +164,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
log.warn(ROUTE_IS_DOWN);
}
if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
log.warn(MAX_WARNINGS_REACHED);
}
} else {
for (StorageMetaData m : storageMetaData) {
@ -186,10 +187,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
log.warn(ROUTE_IS_DOWN);
}
if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
log.warn(MAX_WARNINGS_REACHED);
}
} else {
for (Persistable p : staticInfo) {
@ -209,10 +210,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
log.warn(ROUTE_IS_DOWN);
}
if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
log.warn(MAX_WARNINGS_REACHED);
}
} else {
for (Persistable p : updates) {

View File

@ -67,10 +67,7 @@ public class AsyncIterator<T extends Object> implements Iterator<T> {
nextElement = buffer.take();
// same on this run
if (nextElement == terminator)
return false;
return true;
return (nextElement != terminator);
} catch (Exception e) {
log.error("Premature end of loop!");
return false;

View File

@ -43,7 +43,7 @@ public class SystemInfoPrintListener implements TrainingListener {
private boolean printOnBackwardPass;
private boolean printOnGradientCalculation;
private static final String SYSTEM_INFO = "System info on epoch end: ";
@Override
public void iterationDone(Model model, int iteration, int epoch) {
@ -65,7 +65,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -75,7 +75,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -85,7 +85,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -95,7 +95,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -104,7 +104,7 @@ public class SystemInfoPrintListener implements TrainingListener {
if(!printOnBackwardPass)
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.perf.listener;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import oshi.json.SystemInfo;
@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit;
*
* @author Adam Gibson
*/
@Slf4j
public class SystemPolling {
private ScheduledExecutorService scheduledExecutorService;
@ -66,7 +68,7 @@ public class SystemPolling {
try {
objectMapper.writeValue(hardwareFile,hardwareMetric);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
},0,pollEveryMillis, TimeUnit.MILLISECONDS);

View File

@ -27,7 +27,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import java.io.*;
import java.nio.file.Files;
import java.util.UUID;
/**

View File

@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
@ -153,11 +154,22 @@ public class TestUtils {
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
}
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) {
return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng);
}
public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){
boolean ncw = format == RNNFormat.NCW;
long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize};
char order = ncw ? 'f' : 'c';
INDArray out = Nd4j.create(DataType.FLOAT, shape, order);
for( int i=0; i<minibatch; i++ ){
for( int j=0; j<tsLength; j++ ){
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
if(ncw){
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
} else {
out.putScalar(i, j, rng.nextInt(outSize), 1.0);
}
}
}
return out;

View File

@ -493,7 +493,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
out.println();
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return new Pair<>(dArr,temp);

View File

@ -78,10 +78,10 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
assertEquals(earlyEndIter.hasNext(), false);
assertEquals(false, earlyEndIter.hasNext());
earlyEndIter.reset();
assertEquals(earlyEndIter.hasNext(), true);
assertEquals(true, earlyEndIter.hasNext());
}

View File

@ -98,7 +98,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
while (multiIter.hasNext()) {
DataSet path = multiIter.next(10);
assertNotNull(path);
assertEquals(path.numExamples(), 10, 0.0);
assertEquals(10, path.numExamples(), 0.0);
}
assertEquals(epochs, multiIter.epochs);

View File

@ -33,7 +33,7 @@ public class SamplingTest extends BaseDL4JTest {
DataSetIterator iter = new MnistDataSetIterator(10, 10);
//batch size and total
DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
assertEquals(sampling.next().numExamples(), 10);
assertEquals(10, sampling.next().numExamples());
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ConvolutionMode;
@ -34,6 +35,7 @@ import static org.junit.Assert.fail;
/**
* A set of tests to ensure that useful exceptions are thrown on invalid network configurations
*/
@Slf4j
public class TestInvalidConfigurations extends BaseDL4JTest {
public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) {
@ -78,7 +80,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testDenseNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -96,7 +98,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testDenseNout0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -109,7 +111,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -122,7 +124,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testRnnOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -135,7 +137,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testLSTMNIn0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -153,7 +155,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testLSTMNOut0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -166,7 +168,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testConvolutionalNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -185,7 +187,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testConvolutionalNOut0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -216,7 +218,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -245,7 +247,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -277,7 +279,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -318,7 +320,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -347,7 +349,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -36,6 +37,7 @@ import static org.junit.Assert.*;
/**
* A set of tests to ensure that useful exceptions are thrown on invalid input
*/
@Slf4j
public class TestInvalidInput extends BaseDL4JTest {
@Test
@ -53,7 +55,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchDense(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -73,7 +75,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -94,7 +96,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function
System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -115,7 +117,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function
System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -142,7 +144,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -169,7 +171,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -195,7 +197,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinRank2Subsampling(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -217,7 +219,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchLSTM(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -238,7 +240,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
@ -260,7 +262,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -305,7 +307,7 @@ public class TestInvalidInput extends BaseDL4JTest {
net.rnnTimeStep(Nd4j.create(5, 5, 10));
fail("Expected Exception - " + layerType);
} catch (Exception e) {
// e.printStackTrace();
log.error("",e);
String msg = e.getMessage();
assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
}

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -34,6 +35,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -51,6 +54,7 @@ import static org.junit.Assert.*;
/**
* Created by nyghtowl on 9/1/15.
*/
@RunWith(Parameterized.class)
public class CNNGradientCheckTest extends BaseDL4JTest {
private static final boolean PRINT_RESULTS = true;
private static final boolean RETURN_ON_FIRST_FAILURE = false;
@ -62,6 +66,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
private CNN2DFormat format;
public CNNGradientCheckTest(CNN2DFormat format){
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return CNN2DFormat.values();
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
@ -69,6 +84,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testGradientCNNMLN() {
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
return;
//Parameterized test, testing combinations of:
// (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
@ -144,6 +162,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testGradientCNNL1L2MLN() {
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
return;
//Parameterized test, testing combinations of:
// (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
@ -311,10 +332,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
boolean nchw = format == CNN2DFormat.NCHW;
for (String afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut);
for (int i = 0; i < 4 * minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -330,13 +353,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn;
if (PRINT_RESULTS) {
@ -377,8 +400,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] padding = {0, 0};
int size = 2;
boolean nchw = format == CNN2DFormat.NCHW;
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf =
@ -393,8 +419,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(8 * 8 * 3)
.nOut(4).build())
.setInputType(InputType.convolutionalFlat(height, width,
inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -438,10 +463,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
boolean nchw = format == CNN2DFormat.NCHW;
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -461,14 +489,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3 * 3 * 3)
.nOut(4).build())
.setInputType(InputType.convolutionalFlat(height, width,
inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn;
if (PRINT_RESULTS) {
@ -508,10 +535,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
boolean nchw = format == CNN2DFormat.NCHW;
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -533,8 +563,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2)
.nOut(4).build())
.setInputType(InputType.convolutionalFlat(height, width,
inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -558,8 +587,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testCnnLocallyConnected2D() {
int nOut = 3;
int[] minibatchSizes = {2};
int width = 5;
int height = 5;
@ -569,11 +596,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS};
int[] minibatch = {2, 1, 3};
boolean nchw = format == CNN2DFormat.NCHW;
for( int i=0; i<inputDepths.length; i++ ){
int inputDepth = inputDepths[i];
Activation afn = activations[i];
int minibatchSize = minibatch[i];
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
@ -590,7 +621,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
.build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
.setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
assertEquals(ConvolutionMode.Truncate,
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
@ -626,11 +657,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for (int inputDepth : inputDepths) {
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -649,7 +684,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
.build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
.setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
assertEquals(ConvolutionMode.Truncate,
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
@ -691,13 +726,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for( int i=0; i<minibatchSizes.length; i++ ){
int inputDepth = inputDepths[i];
int minibatchSize = minibatchSizes[i];
int height = heights[i];
int k = kernelSizes[i];
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
@ -713,7 +752,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.stride(1, 1).padding(0, 0).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
.setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
@ -748,13 +787,16 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for (int inputDepth : inputDepths) {
for (int minibatchSize : minibatchSizes) {
for (int stride : strides) {
for (int k : kernelSizes) {
for (boolean convFirst : new boolean[]{true, false}) {
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -775,7 +817,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(1, convFirst ? poolLayer : convLayer)
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -822,11 +864,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] inputDepths = {1, 3, 2};
int[][] zeroPadLayer = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}};
boolean nchw = format == CNN2DFormat.NCHW;
for( int i=0; i<minibatchSizes.length; i++ ){
int minibatchSize = minibatchSizes[i];
int inputDepth = inputDepths[i];
int[] zeroPad = zeroPadLayer[i];
INDArray input = Nd4j.rand(DataType.DOUBLE, new int[]{minibatchSize, inputDepth, height, width});
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf =
@ -840,7 +886,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
padding).nIn(3).nOut(3).build())//output: (6-2+0)/1+1 = 5
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(4).build())
.setInputType(InputType.convolutional(height, width, inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -849,8 +895,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
//Check zero padding activation shape
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer zpl =
(org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer) net.getLayer(1);
val expShape = new long[]{minibatchSize, inputDepth, height + zeroPad[0] + zeroPad[1],
width + zeroPad[2] + zeroPad[3]};
long[] expShape;
if(nchw){
expShape = new long[]{minibatchSize, inputDepth, height + zeroPad[0] + zeroPad[1],
width + zeroPad[2] + zeroPad[3]};
} else {
expShape = new long[]{minibatchSize, height + zeroPad[0] + zeroPad[1],
width + zeroPad[2] + zeroPad[3], inputDepth};
}
INDArray out = zpl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(expShape, out.shape());
@ -888,6 +940,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for (int i = 0; i < minibatchSizes.length; i++) {
int minibatchSize = minibatchSizes[i];
int k = kernelSizes[i];
@ -900,7 +954,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int w = d * width;
int h = d * height;
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int j = 0; j < minibatchSize; j++) {
labels.putScalar(new int[]{j, j % nOut}, 1.0);
@ -920,7 +977,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build();
.setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
@ -945,8 +1002,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testSeparableConv2D() {
int nOut = 2;
int[] minibatchSizes = new int[]{1, 3};
int width = 6;
int height = 6;
int inputDepth = 3;
@ -959,6 +1014,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionMode[] cms = new ConvolutionMode[]{Truncate, Truncate, Truncate, Truncate, Truncate};
int[] mb = new int[]{1, 1, 1, 3, 3};
boolean nchw = format == CNN2DFormat.NCHW;
for (int t = 0; t < ks.length; t++) {
int k = ks[t];
@ -971,7 +1028,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int w = d * width;
int h = d * height;
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -992,7 +1050,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build();
.setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
@ -1017,7 +1075,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testCnnDilated() {
int nOut = 2;
int minibatchSize = 2;
int width = 8;
int height = 8;
@ -1031,9 +1088,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] ds = new int[]{2, 2, 3, 3, 2};
ConvolutionMode[] cms = new ConvolutionMode[]{Same, Truncate, Truncate, Same, Truncate};
boolean nchw = format == CNN2DFormat.NCHW;
for (int t = 0; t < sub.length; t++) {
boolean subsampling = sub[t];
int s = stride[t];
int k = kernel[t];
@ -1044,7 +1101,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int w = d * width;
int h = d * height;
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -1076,7 +1134,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build();
.setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
@ -1114,11 +1172,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] inputDepths = {1, 2, 3, 2};
int[] minibatchSizes = {2, 1, 3, 2};
boolean nchw = format == CNN2DFormat.NCHW;
for (int i = 0; i < cropTestCases.length; i++) {
int inputDepth = inputDepths[i];
int minibatchSize = minibatchSizes[i];
int[] crop = cropTestCases[i];
INDArray input = Nd4j.rand(new int[]{minibatchSize, inputDepth, height, width});
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf =
@ -1134,7 +1195,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(height, width, inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -1143,12 +1204,18 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
//Check cropping activation shape
org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl =
(org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1);
val expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
width - crop[2] - crop[3]};
long[] expShape;
if(nchw){
expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
width - crop[2] - crop[3]};
} else {
expShape = new long[]{minibatchSize, height - crop[0] - crop[1],
width - crop[2] - crop[3], inputDepth};
}
INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(expShape, out.shape());
String msg = "minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
String msg = format + " - minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
+ Arrays.toString(crop);
if (PRINT_RESULTS) {
@ -1181,6 +1248,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Truncate, Truncate, Truncate, Truncate, Truncate};
int[] mb = new int[]{1,1,1,3,3};
boolean nchw = format == CNN2DFormat.NCHW;
for( int t=0; t<ks.length; t++ ){
int k = ks[t];
@ -1188,8 +1257,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionMode cm = cms[t];
int minibatchSize = mb[t];
INDArray input = Nd4j.rand(minibatchSize, width * height * nIn);
long[] inShape = nchw ? new long[]{minibatchSize, nIn, height, width} : new long[]{minibatchSize, height, width, nIn};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -1211,7 +1280,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, nIn)).build();
.setInputType(InputType.convolutional(height, width, nIn, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.gradientcheck;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -115,55 +116,57 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
//Basic test of global pooling w/ CNN
Nd4j.getRandom().setSeed(12345L);
int inputDepth = 3;
int inputH = 5;
int inputW = 4;
int layerDepth = 4;
int nOut = 2;
for(boolean nchw : new boolean[]{true, false}) {
int[] minibatchSizes = new int[] {1, 3};
PoolingType[] poolingTypes =
new PoolingType[] {PoolingType.AVG, PoolingType.SUM, PoolingType.MAX, PoolingType.PNORM};
int inputDepth = 3;
int inputH = 5;
int inputW = 4;
int layerDepth = 4;
int nOut = 2;
for (int miniBatchSize : minibatchSizes) {
for (PoolingType pt : poolingTypes) {
int[] minibatchSizes = new int[]{1, 3};
PoolingType[] poolingTypes =
new PoolingType[]{PoolingType.AVG, PoolingType.SUM, PoolingType.MAX, PoolingType.PNORM};
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(layerDepth)
.build())
.layer(1, new GlobalPoolingLayer.Builder().poolingType(pt).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
for (int miniBatchSize : minibatchSizes) {
for (PoolingType pt : poolingTypes) {
.setInputType(InputType.convolutional(inputH, inputW, inputDepth)).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(layerDepth)
.build())
.layer(1, new GlobalPoolingLayer.Builder().poolingType(pt).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(inputH, inputW, inputDepth, nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
Random r = new Random(12345L);
INDArray input = Nd4j.rand(new int[] {miniBatchSize, inputDepth, inputH, inputW}).subi(0.5);
Random r = new Random(12345L);
long[] inShape = nchw ? new long[]{miniBatchSize, inputDepth, inputH, inputW} : new long[]{miniBatchSize, inputH, inputW, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape).subi(0.5);
INDArray labels = Nd4j.zeros(miniBatchSize, nOut);
for (int i = 0; i < miniBatchSize; i++) {
int idx = r.nextInt(nOut);
labels.putScalar(i, idx, 1.0);
}
INDArray labels = Nd4j.zeros(miniBatchSize, nOut);
for (int i = 0; i < miniBatchSize; i++) {
int idx = r.nextInt(nOut);
labels.putScalar(i, idx, 1.0);
}
if (PRINT_RESULTS) {
System.out.println(
"testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize);
if (PRINT_RESULTS) {
System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize + " - " + (nchw ? "NCHW" : "NHWC"));
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
}
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);
}
}
}

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.*;
@ -560,75 +561,81 @@ public class GradientCheckTests extends BaseDL4JTest {
public void testEmbeddingSequenceLayer(){
Nd4j.getRandom().setSeed(12345);
for(boolean maskArray : new boolean[]{false, true}){
for(int inputRank : new int[]{2,3}) {
for(RNNFormat seqOutputFormat : RNNFormat.values()) {
for (boolean maskArray : new boolean[]{false, true}) {
for (int inputRank : new int[]{2, 3}) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.seed(12345)
.updater(new NoOp())
.weightInit(new NormalDistribution(0, 1))
.list()
.layer(new EmbeddingSequenceLayer.Builder()
.nIn(8)
.nOut(4)
.build())
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
.lossFunction(LossFunction.MSE).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.seed(12345)
.updater(new NoOp())
.weightInit(new NormalDistribution(0, 1))
.list()
.layer(new EmbeddingSequenceLayer.Builder()
.nIn(8)
.nOut(4)
.outputDataFormat(seqOutputFormat)
.build())
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
.dataFormat(seqOutputFormat)
.lossFunction(LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
INDArray label = Nd4j.rand(new int[]{3, 3, 6});
boolean ncw = seqOutputFormat == RNNFormat.NCW;
if(inputRank == 3){
//Reshape from [3,6] to [3,1,6]
in = in.reshape('c', 3, 1, 6);
}
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
INDArray label = Nd4j.rand(DataType.FLOAT, ncw ? new int[]{3, 3, 6} : new int[]{3,6,3});
INDArray fMask = null;
if (maskArray) {
fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
{1, 1, 0, 0, 0, 0},
{1, 0, 0, 0, 0, 0}});
}
String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(label).inputMask(fMask));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
//Also: if mask is present, double check that the masked steps don't impact score
if (maskArray) {
DataSet ds = new DataSet(in, label, fMask, null);
double score = net.score(ds);
if(inputRank == 2){
in.putScalar(1, 2, 0);
in.putScalar(2, 1, 0);
in.putScalar(2, 2, 0);
} else {
in.putScalar(1, 0, 2, 0);
in.putScalar(2, 0, 1, 0);
in.putScalar(2, 0, 2, 0);
if (inputRank == 3) {
//Reshape from [3,6] to [3,1,6]
in = in.reshape('c', 3, 1, 6);
}
double score2 = net.score(ds);
assertEquals(score, score2, 1e-6);
if(inputRank == 2){
in.putScalar(1, 2, 1);
in.putScalar(2, 1, 1);
in.putScalar(2, 2, 1);
} else {
in.putScalar(1, 0, 2, 1);
in.putScalar(2, 0, 1, 1);
in.putScalar(2, 0, 2, 1);
INDArray fMask = null;
if (maskArray) {
fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
{1, 1, 0, 0, 0, 0},
{1, 0, 0, 0, 0, 0}});
}
String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(label).inputMask(fMask));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
//Also: if mask is present, double check that the masked steps don't impact score
if (maskArray) {
DataSet ds = new DataSet(in, label, fMask, null);
double score = net.score(ds);
if (inputRank == 2) {
in.putScalar(1, 2, 0);
in.putScalar(2, 1, 0);
in.putScalar(2, 2, 0);
} else {
in.putScalar(1, 0, 2, 0);
in.putScalar(2, 0, 1, 0);
in.putScalar(2, 0, 2, 0);
}
double score2 = net.score(ds);
assertEquals(score, score2, 1e-6);
if (inputRank == 2) {
in.putScalar(1, 2, 1);
in.putScalar(2, 1, 1);
in.putScalar(2, 2, 1);
} else {
in.putScalar(1, 0, 2, 1);
in.putScalar(2, 0, 1, 1);
in.putScalar(2, 0, 2, 1);
}
double score3 = net.score(ds);
assertEquals(score, score3, 1e-6);
}
double score3 = net.score(ds);
assertEquals(score, score3, 1e-6);
}
}
}

View File

@ -21,9 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
@ -341,104 +339,112 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
@Test
public void testCnnDepthMerge() {
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new NormalDistribution(0, 0.1))
.updater(new NoOp()).graphBuilder().addInputs("input")
.addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addVertex("merge", new MergeVertex(), "l1", "l2")
.addLayer("outputLayer",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
.build(),
"merge")
.setOutputs("outputLayer")
.inputPreProcessor("outputLayer", new CnnToFeedForwardPreProcessor(5, 5, 4))
.build();
for(CNN2DFormat format : CNN2DFormat.values()) {
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
String msg = "testCnnDepthMerge - " + format;
Random r = new Random(12345);
INDArray input = Nd4j.rand(new int[] {5, 2, 6, 6}); //Order: examples, channels, height, width
INDArray labels = Nd4j.zeros(5, 3);
for (int i = 0; i < 5; i++)
labels.putScalar(new int[] {i, r.nextInt(3)}, 1.0);
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new NormalDistribution(0, 0.1))
.updater(new NoOp()).graphBuilder().addInputs("input")
.addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addVertex("merge", new MergeVertex(), "l1", "l2")
.addLayer("outputLayer",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
.build(),
"merge")
.setOutputs("outputLayer")
.setInputTypes(InputType.convolutional(6, 6, 2, format))
.build();
if (PRINT_RESULTS) {
System.out.println("testCnnDepthMerge()");
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, format == CNN2DFormat.NCHW ? new long[]{5,2,6,6} : new long[]{5,6,6,2});
INDArray labels = Nd4j.zeros(5, 3);
for (int i = 0; i < 5; i++)
labels.putScalar(new int[]{i, r.nextInt(3)}, 1.0);
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testCnnDepthMerge()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
@Test
public void testRNNWithMerging() {
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new UniformDistribution(0.2, 0.6))
.updater(new NoOp()).graphBuilder().addInputs("input")
.setOutputs("out")
.addLayer("lstm1",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"input")
.addLayer("lstm2",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"lstm1")
.addLayer("dense1",
new DenseLayer.Builder().nIn(3).nOut(3)
.activation(Activation.SIGMOID).build(),
"lstm1")
.addLayer("lstm3",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"dense1")
.addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
.addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
"merge")
.inputPreProcessor("dense1", new RnnToFeedForwardPreProcessor())
.inputPreProcessor("lstm3", new FeedForwardToRnnPreProcessor())
.build();
for(RNNFormat format : RNNFormat.values()) {
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
String msg = "testLSTMWithMerging - " + format;
Random r = new Random(12345);
INDArray input = Nd4j.rand(new int[] {2, 3, 4});
INDArray labels = TestUtils.randomOneHotTimeSeries(2, 3, 4);
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new UniformDistribution(0.2, 0.6))
.updater(new NoOp()).graphBuilder().addInputs("input")
.setOutputs("out")
.addLayer("lstm1",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"input")
.addLayer("lstm2",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"lstm1")
.addLayer("dense1",
new DenseLayer.Builder().nIn(3).nOut(3)
.activation(Activation.SIGMOID).build(),
"lstm1")
.addLayer("lstm3",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"dense1")
.addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
.addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
"merge")
.setInputTypes(InputType.recurrent(4, format))
.build();
if (PRINT_RESULTS) {
System.out.println("testLSTMWithMerging()");
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, format == RNNFormat.NCW ? new long[]{2, 3, 4} : new long[]{2,4,3});
INDArray labels = TestUtils.randomOneHotTimeSeries(format, 2, 3, 4, new Random(12345));
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testLSTMWithMerging()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
@Test

View File

@ -216,7 +216,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failed.add(testName + "\t" + "EXCEPTION");
continue;
}
@ -383,7 +383,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failed.add(testName + "\t" + "EXCEPTION");
continue;
}
@ -693,7 +693,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failed.add(testName + "\t" + "EXCEPTION");
continue;
}

View File

@ -338,7 +338,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(layerSize).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build(), 2))
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build()))
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn))

View File

@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -47,6 +48,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphConfigurationTest extends BaseDL4JTest {
@Test
@ -150,7 +152,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
// Use appendLayer on first layer
@ -162,7 +164,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test no network inputs
@ -174,7 +176,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test no network outputs
@ -185,7 +187,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test: invalid input
@ -197,7 +199,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test: graph with cycles
@ -215,7 +217,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test: input != inputType count mismatch
@ -241,7 +243,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalArgumentException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.conf;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Layer;
@ -46,6 +47,7 @@ import static org.junit.Assert.*;
/**
* Created by agibsonccc on 11/27/14.
*/
@Slf4j
public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
@Rule
@ -272,9 +274,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
e.printStackTrace();
log.error("",e);
} catch (Throwable e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception thrown for invalid config");
}
@ -288,9 +290,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
e.printStackTrace();
log.info(e.toString());
} catch (Throwable e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception thrown for invalid config");
}
@ -304,9 +306,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
e.printStackTrace();
log.info(e.toString());
} catch (Throwable e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception thrown for invalid config");
}
}

View File

@ -96,8 +96,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) {
assertTrue(RW0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(RW0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(RW0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, RW0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, RW0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -149,8 +149,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) {
assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -201,8 +201,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) {
assertTrue(w0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -259,10 +259,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -320,10 +320,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -378,10 +378,10 @@ public class TestConstraints extends BaseDL4JTest {
} else if(lc instanceof NonNegativeConstraint ){
assertTrue(w0.minNumber().doubleValue() >= 0.0 );
} else if(lc instanceof UnitNormConstraint ){
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(w1.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(w1.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6 );
assertEquals(1.0, w1.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(1.0, w1.norm2(1).maxNumber().doubleValue(), 1e-6 );
}
TestUtils.testModelSerialization(net);

View File

@ -156,7 +156,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
checkSerialization(glstm);
assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0);
assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0);
assertEquals(glstm.nIn, numIn);
assertEquals(glstm.nOut, numOut);
assertTrue(glstm.getActivationFn() instanceof ActivationTanH);

View File

@ -17,7 +17,9 @@
package org.deeplearning4j.nn.dtypes;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.preprocessor.*;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
@ -51,16 +53,11 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.util.IdentityLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
@ -97,7 +94,8 @@ public class DTypeTests extends BaseDL4JTest {
Pooling2D.class, //Alias for SubsamplingLayer
Convolution2D.class, //Alias for ConvolutionLayer
Pooling1D.class, //Alias for Subsampling1D
Convolution1D.class //Alias for Convolution1DLayer
Convolution1D.class, //Alias for Convolution1DLayer
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
));
@Override
@ -819,7 +817,7 @@ public class DTypeTests extends BaseDL4JTest {
.layer(new DenseLayer.Builder().nOut(5).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
.layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()))
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2))
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build())
.layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
.layer(secondLast)
@ -1078,7 +1076,7 @@ public class DTypeTests extends BaseDL4JTest {
.addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in")
.addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l")
.addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc")
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16})), "preproc2")
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16}, false)), "preproc2")
.addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3")
.setInputTypes(InputType.feedForward(5))
.setOutputs("out");
@ -1150,7 +1148,7 @@ public class DTypeTests extends BaseDL4JTest {
case 7:
b.addInputs("in")
.addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
.addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5)), "1")
.addVertex("2", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 5)), "1")
.addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
.setOutputs("out")
.setInputTypes(InputType.convolutional(28, 28, 1));

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.graph;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
@ -54,6 +55,7 @@ import java.util.Map;
import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphTestRNN extends BaseDL4JTest {
@Test
@ -618,7 +620,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
.build();
fail("Exception expected");
} catch (IllegalStateException e){
// e.printStackTrace();
log.error("",e);
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
}
}

View File

@ -1394,7 +1394,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
} catch (Exception e) {
//e.printStackTrace();
log.error("",e);
if(allowDisconnected){
fail("No exception expected");
} else {

View File

@ -416,8 +416,8 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0);
assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(0.0, outRow2.getDouble(k), 0.0);
}
}
}

View File

@ -60,7 +60,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test
public void testMergeNode() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4);
INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100);
@ -82,7 +82,7 @@ public class TestGraphNodes extends BaseDL4JTest {
public void testMergeNodeRNN() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5);
INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100);
@ -103,7 +103,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test
public void testCnnDepthMerge() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2);
INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10);

View File

@ -0,0 +1,974 @@
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.layers.convolution;
import lombok.*;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class ConvDataFormatTests extends BaseDL4JTest {
private final DataType dataType;
public ConvDataFormatTests(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return new DataType[]{DataType.FLOAT, DataType.DOUBLE};
}
@Test
public void testConv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getConv2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSubsampling2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testDepthwiseConv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSeparableConv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testDeconv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLRN() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLrnLayer(CNN2DFormat.NCHW, true, cm))
.net2(getLrnLayer(CNN2DFormat.NCHW, false, cm))
.net3(getLrnLayer(CNN2DFormat.NHWC, true, cm))
.net4(getLrnLayer(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testZeroPaddingLayer(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getZeroPaddingNet(CNN2DFormat.NCHW, true))
.net2(getZeroPaddingNet(CNN2DFormat.NCHW, false))
.net3(getZeroPaddingNet(CNN2DFormat.NHWC, true))
.net4(getZeroPaddingNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testCropping2DLayer(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getCropping2dNet(CNN2DFormat.NCHW, true))
.net2(getCropping2dNet(CNN2DFormat.NCHW, false))
.net3(getCropping2dNet(CNN2DFormat.NHWC, true))
.net4(getCropping2dNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testUpsampling2d(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getUpsamplingNet(CNN2DFormat.NCHW, true))
.net2(getUpsamplingNet(CNN2DFormat.NCHW, false))
.net3(getUpsamplingNet(CNN2DFormat.NHWC, true))
.net4(getUpsamplingNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testBatchNormNet(){
try {
for(boolean useLogStd : new boolean[]{true, false}) {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std");
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true))
.net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false))
.net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true))
.net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testCnnLossLayer() {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3);
labelsNHWC = labelsNHWC.reshape(2,6,6,3);
INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup();
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getCnnLossNet(CNN2DFormat.NCHW, true, ConvolutionMode.Same))
.net2(getCnnLossNet(CNN2DFormat.NCHW, false, ConvolutionMode.Same))
.net3(getCnnLossNet(CNN2DFormat.NHWC, true, ConvolutionMode.Same))
.net4(getCnnLossNet(CNN2DFormat.NHWC, false, ConvolutionMode.Same))
.inNCHW(inNCHW)
.labelsNCHW(labelsNCHW)
.labelsNHWC(labelsNHWC)
.testLayerIdx(1)
.nhwcOutput(true)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSpaceToDepthNet(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSpaceToBatchNet(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16);
INDArray labels = TestUtils.randomOneHot(8, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLocallyConnected() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm))
.net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm))
.net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm))
.net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGlobalPooling() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocalResponseNormalization.Builder()
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocalResponseNormalization.Builder()
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(),
format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Cropping2D.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Cropping2D.Builder(2,2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Upsampling2D.Builder(2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Upsampling2D.Builder(2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.stride(2,2)
.build(), format, cm, null);
} else {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.stride(2,2)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new BatchNormalization.Builder()
.useLogStd(logStdev)
.dataFormat(format)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new BatchNormalization.Builder()
.useLogStd(logStdev)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
.blocks(2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
.blocks(2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} else {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
}
}
private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.dataType(this.dataType)
.seed(12345)
.convolutionMode(cm)
.list()
.layer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build())
.layer(layer)
.layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
.setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));
if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
//Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
//DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
}
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
.convolutionMode(cm)
.list()
.layer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build());
if(setOnLayerAlso){
builder.layer(new CnnLossLayer.Builder().format(format).activation(Activation.SOFTMAX).build());
} else {
builder.layer(new CnnLossLayer.Builder().activation(Activation.SOFTMAX).build());
}
builder.setInputType(InputType.convolutional(12, 12, 3, format));
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
@AllArgsConstructor
@Data
@NoArgsConstructor
@Builder
private static class TestCase {
private String msg;
private MultiLayerNetwork net1;
private MultiLayerNetwork net2;
private MultiLayerNetwork net3;
private MultiLayerNetwork net4;
private INDArray inNCHW;
private INDArray labelsNCHW;
private INDArray labelsNHWC;
private int testLayerIdx;
private boolean nhwcOutput;
}
public static void testHelper(TestCase tc) {
tc.net2.params().assign(tc.net1.params());
tc.net3.params().assign(tc.net1.params());
tc.net4.params().assign(tc.net1.params());
//Test forward pass:
INDArray inNCHW = tc.inNCHW;
INDArray inNHWC = tc.inNCHW.permute(0, 2, 3, 1).dup();
INDArray l0_1 = tc.net1.feedForward(inNCHW).get(tc.testLayerIdx + 1);
INDArray l0_2 = tc.net2.feedForward(inNCHW).get(tc.testLayerIdx + 1);
INDArray l0_3 = tc.net3.feedForward(inNHWC).get(tc.testLayerIdx + 1);
INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1);
assertEquals(tc.msg, l0_1, l0_2);
if(l0_1.rank() == 4) {
assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
} else {
assertEquals(tc.msg, l0_1, l0_3);
assertEquals(tc.msg, l0_1, l0_4);
}
INDArray out1 = tc.net1.output(inNCHW);
INDArray out2 = tc.net2.output(inNCHW);
INDArray out3 = tc.net3.output(inNHWC);
INDArray out4 = tc.net4.output(inNHWC);
assertEquals(tc.msg, out1, out2);
if(!tc.nhwcOutput) {
assertEquals(tc.msg, out1, out3);
assertEquals(tc.msg, out1, out4);
} else {
assertEquals(tc.msg, out1, out3.permute(0,3,1,2)); //NHWC to NCHW
assertEquals(tc.msg, out1, out4.permute(0,3,1,2));
}
//Test backprop
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCHW, tc.labelsNCHW, null, null);
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCHW, tc.labelsNCHW, null, null);
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNHWC, tc.labelsNHWC, null, null);
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNHWC, tc.labelsNHWC, null, null);
//Inpput gradients
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0,3,1,2)); //Input gradients for NHWC input are also in NHWC format
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0,3,1,2));
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
tc.net1.fit(inNCHW, tc.labelsNCHW);
tc.net2.fit(inNCHW, tc.labelsNCHW);
tc.net3.fit(inNHWC, tc.labelsNHWC);
tc.net4.fit(inNHWC, tc.labelsNHWC);
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
//Test serialization
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
out1 = tc.net1.output(inNCHW);
assertEquals(tc.msg, out1, net1a.output(inNCHW));
assertEquals(tc.msg, out1, net2a.output(inNCHW));
if(!tc.nhwcOutput) {
assertEquals(tc.msg, out1, net3a.output(inNHWC));
assertEquals(tc.msg, out1, net4a.output(inNHWC));
} else {
assertEquals(tc.msg, out1, net3a.output(inNHWC).permute(0,3,1,2)); //NHWC to NCHW
assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2));
}
}
private static List<String> differentGrads(Gradient g1, Gradient g2){
List<String> differs = new ArrayList<>();
Map<String,INDArray> m1 = g1.gradientForVariable();
Map<String,INDArray> m2 = g2.gradientForVariable();
for(String s : m1.keySet()){
INDArray a1 = m1.get(s);
INDArray a2 = m2.get(s);
if(!a1.equals(a2)){
differs.add(s);
}
}
return differs;
}
//Converts NHWC to NCHW activations
@EqualsAndHashCode
private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
}
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
}
@Override
public InputPreProcessor clone() {
return this;
}
@Override
public InputType getOutputType(InputType inputType) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
}
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
return null;
}
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.convolution;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@ -45,6 +46,7 @@ import static org.junit.Assert.*;
/**
* Created by Alex on 15/11/2016.
*/
@Slf4j
public class TestConvolutionModes extends BaseDL4JTest {
@Test
@ -106,12 +108,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
}
} catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}
continue; //Expected exception
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}
@ -184,12 +186,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
}
} catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}
continue; //Expected exception
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}

View File

@ -24,10 +24,7 @@ import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
@ -45,6 +42,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -61,12 +60,22 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
import static org.junit.Assert.assertEquals;
@Slf4j
@RunWith(Parameterized.class)
public class BidirectionalTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public BidirectionalTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void compareImplementations(){
for(WorkspaceMode wsm : WorkspaceMode.values()) {
@ -82,9 +91,9 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build())
.build();
@ -95,9 +104,9 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build())
.build();
@ -116,15 +125,24 @@ public class BidirectionalTest extends BaseDL4JTest {
net2.setParams(net1.params()); //Assuming exact same layout here...
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
INDArray in;
if (rnnDataFormat == NCW){
in = Nd4j.rand(new int[]{3, 10, 5});
}else{
in = Nd4j.rand(new int[]{3, 5, 10});
}
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
INDArray labels;
if (rnnDataFormat == NCW){
labels = Nd4j.rand(new int[]{3, 10, 5});
}else{
labels = Nd4j.rand(new int[]{3, 5, 10});
}
net1.setInput(in);
net1.setLabels(labels);
@ -276,17 +294,22 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.nIn(10).nOut(10).build())
.nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
INDArray in;
INDArray labels;
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 5} : new long[]{3, 5, 10};
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
net1.fit(in, labels);
@ -300,8 +323,8 @@ public class BidirectionalTest extends BaseDL4JTest {
MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true);
in = Nd4j.rand(new int[]{3, 10, 5});
labels = Nd4j.rand(new int[]{3, 10, 5});
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
@ -338,18 +361,18 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam())
.graphBuilder()
.addInputs("in")
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in")
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0")
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0")
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build(), "1")
.setOutputs("2")
.build();
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
long[] inshape = (rnnDataFormat == NCW)? new long[]{3, 10, 5}: new long[]{3, 5, 10};
INDArray in = Nd4j.rand(inshape);
INDArray labels = Nd4j.rand(inshape);
net1.fit(new DataSet(in, labels));
@ -363,8 +386,8 @@ public class BidirectionalTest extends BaseDL4JTest {
ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true);
in = Nd4j.rand(new int[]{3, 10, 5});
labels = Nd4j.rand(new int[]{3, 10, 5});
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in);
@ -394,8 +417,8 @@ public class BidirectionalTest extends BaseDL4JTest {
Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD,
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
INDArray in = Nd4j.rand(inshape);
for (Bidirectional.Mode m : modes) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -406,7 +429,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
@ -418,7 +441,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.updater(new Adam())
.list()
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build())
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone());
@ -434,11 +457,10 @@ public class BidirectionalTest extends BaseDL4JTest {
net3.setParam("0_RW", net1.getParam("0_bRW"));
net3.setParam("0_b", net1.getParam("0_bb"));
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
INDArray outExp;
switch (m) {
@ -452,7 +474,7 @@ public class BidirectionalTest extends BaseDL4JTest {
outExp = out2.add(out3).muli(0.5);
break;
case CONCAT:
outExp = Nd4j.concat(1, out2, out3);
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
break;
default:
throw new RuntimeException();
@ -464,25 +486,25 @@ public class BidirectionalTest extends BaseDL4JTest {
//Check gradients:
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
INDArray eps = Nd4j.rand(inshape);
INDArray eps1;
if (m == Bidirectional.Mode.CONCAT) {
eps1 = Nd4j.concat(1, eps, eps);
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
} else {
eps1 = eps;
}
net1.setInput(in);
net2.setInput(in);
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat));
net1.feedForward(true, false);
net2.feedForward(true, false);
net3.feedForward(true, false);
Pair<Gradient, INDArray> p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT), LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat), LayerWorkspaceMgr.noWorkspaces());
Gradient g1 = p1.getFirst();
Gradient g2 = p2.getFirst();
Gradient g3 = p3.getFirst();
@ -520,7 +542,9 @@ public class BidirectionalTest extends BaseDL4JTest {
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
INDArray in = Nd4j.rand(inshape);
for (Bidirectional.Mode m : modes) {
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -532,7 +556,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam())
.graphBuilder()
.addInputs("in")
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()), "in")
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
.setOutputs("0")
.build();
@ -546,7 +570,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam())
.graphBuilder()
.addInputs("in")
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in")
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in")
.setOutputs("0")
.build();
@ -566,9 +590,20 @@ public class BidirectionalTest extends BaseDL4JTest {
INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.outputSingle(
TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)),
LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
INDArray out3;
INDArray inReverse;
if (rnnDataFormat == RNNFormat.NWC){
inReverse = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
out3 = net3.outputSingle(inReverse);
out3 = TimeSeriesUtils.reverseTimeSeries(out3.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
}
else{
inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
out3 = net3.outputSingle(inReverse);
out3 = TimeSeriesUtils.reverseTimeSeries(out3, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
}
INDArray outExp;
switch (m) {
@ -582,7 +617,9 @@ public class BidirectionalTest extends BaseDL4JTest {
outExp = out2.add(out3).muli(0.5);
break;
case CONCAT:
outExp = Nd4j.concat(1, out2, out3);
System.out.println(out2.shapeInfoToString());
System.out.println(out3.shapeInfoToString());
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
break;
default:
throw new RuntimeException();
@ -594,22 +631,26 @@ public class BidirectionalTest extends BaseDL4JTest {
//Check gradients:
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
INDArray eps = Nd4j.rand(inshape);
INDArray eps1;
if (m == Bidirectional.Mode.CONCAT) {
eps1 = Nd4j.concat(1, eps, eps);
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
} else {
eps1 = eps;
}
INDArray epsReversed = (rnnDataFormat == NCW)?
TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT):
TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)
.permute(0, 2, 1);
net1.outputSingle(true, false, in);
net2.outputSingle(true, false, in);
net3.outputSingle(true, false, TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
net3.outputSingle(true, false, inReverse);
Gradient g1 = net1.backpropGradient(eps1);
Gradient g2 = net2.backpropGradient(eps);
Gradient g3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
Gradient g3 = net3.backpropGradient(epsReversed);
for (boolean updates : new boolean[]{false, true}) {
if (updates) {

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@ -31,6 +32,8 @@ import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,10 +45,18 @@ import org.nd4j.linalg.primitives.Pair;
import static org.junit.Assert.*;
@RunWith(Parameterized.class)
public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
private double score = 0.0;
private RNNFormat rnnDataFormat;
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testBidirectionalLSTMGravesForwardBasic() {
//Very basic test of forward prop. of LSTM layer with a time series.
@ -55,7 +66,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(nHiddenUnits).activation(Activation.TANH).build())
.nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build())
.build();
val numParams = conf.getLayer().initializer().numParams(conf);
@ -65,22 +76,41 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
//Data: has shape [miniBatchSize,nIn,timeSeriesLength];
//Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength];
if (rnnDataFormat == RNNFormat.NCW){
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1});
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1});
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1);
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1});
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1);
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1});
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12);
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12});
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12);
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12});
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15);
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
}
else{
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, 1, nIn);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, 1, nHiddenUnits});
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, 1, nIn);
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations2.shape(), new long[] {10, 1, nHiddenUnits});
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, 12, nIn);
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations3.shape(), new long[] {1, 12, nHiddenUnits});
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, 15, nIn);
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations4.shape(), new long[] {10, 15, nHiddenUnits});
}
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15);
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
}
@Test
@ -94,14 +124,15 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
}
private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
int timeSeriesLength) {
INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength);
INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength):
Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(lstmNHiddenUnits)
.nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
.build();
@ -114,7 +145,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces());
assertNotNull(lstm.input());
INDArray epsilon = Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength);
INDArray epsilon =(rnnDataFormat == RNNFormat.NCW)? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength):
Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits);
Pair<Gradient, INDArray> out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
Gradient outGradient = out.getFirst();
@ -147,7 +179,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3});
assertNotNull(nextEpsilon);
assertArrayEquals(nextEpsilon.shape(), new long[] {miniBatchSize, nIn, timeSeriesLength});
if (rnnDataFormat == RNNFormat.NCW) {
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, nIn, timeSeriesLength});
}else{
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, timeSeriesLength, nIn });
}
//Check update:
for (String s : outGradient.gradientForVariable().keySet()) {
@ -226,7 +262,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(layerSize)
.nOut(layerSize).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build())
.build();
@ -237,7 +273,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.instantiate(confBidirectional, null, 0, params, true, params.dataType());
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());
@ -265,13 +302,13 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration confBidirectional =
new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
.nIn(nIn).nOut(layerSize)
.nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(-0.1, 0.1))
.activation(Activation.TANH).updater(new NoOp()).build())
.build();
final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.weightInit(WeightInit.ZERO).activation(Activation.TANH).build())
.build();
@ -290,9 +327,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards)));
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
final INDArray sigb = sig.dup();
reverseColumnsInPlace(sigb.slice(0));
if (rnnDataFormat == RNNFormat.NCW) {
reverseColumnsInPlace(sigb.slice(0));
}
else{
reverseColumnsInPlace(sigb.slice(0).permute(1, 0));
}
final INDArray recurrentWeightsF = bidirectionalLSTM
.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS);
@ -345,10 +389,14 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f);
final INDArray randSig = Nd4j.rand(new int[] {1, layerSize, timeSeriesLength});
final INDArray randSigBackwards = randSig.dup();
reverseColumnsInPlace(randSigBackwards.slice(0));
final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}):
Nd4j.rand(new int[] {1, timeSeriesLength, layerSize});
INDArray randSigBackwards = randSig.dup();
if (rnnDataFormat == RNNFormat.NCW){
reverseColumnsInPlace(randSigBackwards.slice(0));
}else{
reverseColumnsInPlace(randSigBackwards.slice(0).permute(1, 0));
}
final Pair<Gradient, INDArray> backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
final Pair<Gradient, INDArray> backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
@ -399,10 +447,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0);
final INDArray activation3Reverse = activation3.dup();
reverseColumnsInPlace(activation3Reverse);
if (rnnDataFormat == RNNFormat.NCW){
reverseColumnsInPlace(activation3Reverse);
}
else{
reverseColumnsInPlace(activation3Reverse.permute(1, 0));
}
assertEquals(activation3Reverse, activation1);
assertArrayEquals(activation3Reverse.shape(), activation1.shape());
assertEquals(activation3Reverse, activation1);
//test backprop now
final INDArray refBackGradientReccurrent =
@ -434,7 +488,12 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray refEpsilon = backprop1.getSecond().dup();
final INDArray backEpsilon = backprop3.getSecond().dup();
reverseColumnsInPlace(refEpsilon.slice(0));
if (rnnDataFormat == RNNFormat.NCW) {
reverseColumnsInPlace(refEpsilon.slice(0));
}
else{
reverseColumnsInPlace(refEpsilon.slice(0).permute(1, 0));
}
assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6);
}
@ -477,10 +536,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.seed(12345).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2)
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat)
.build())
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2)
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat)
.activation(Activation.TANH).build())
.build();
@ -492,7 +551,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
INDArray in = Nd4j.rand(new int[] {3, 2, 5});
INDArray labels = Nd4j.rand(new int[] {3, 2, 5});
if (rnnDataFormat == RNNFormat.NWC){
in = in.permute(0, 2, 1);
labels = labels.permute(0, 2, 1);
}
net.fit(in, labels);
}
}

View File

@ -21,11 +21,14 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -36,9 +39,17 @@ import java.util.Collections;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class MaskZeroLayerTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public MaskZeroLayerTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void activate() {
@ -57,7 +68,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
.activation(Activation.IDENTITY)
.gateActivationFunction(Activation.IDENTITY)
.nIn(2)
.nOut(1)
.nOut(1).dataFormat(rnnDataFormat)
.build();
NeuralNetConfiguration conf = new NeuralNetConfiguration();
conf.setLayer(underlying);
@ -72,20 +83,25 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue);
INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3});
if (rnnDataFormat == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
//WHEN
INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces());
if (rnnDataFormat == RNNFormat.NWC){
out = out.permute(0, 2,1);
}
//THEN output should only be incremented for the non-zero timesteps
INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
assertEquals(firstExampleOutput.getDouble(0), 0.0, 1e-6);
assertEquals(firstExampleOutput.getDouble(1), 1.0, 1e-6);
assertEquals(firstExampleOutput.getDouble(2), 2.0, 1e-6);
assertEquals(0.0, firstExampleOutput.getDouble(0), 1e-6);
assertEquals(1.0, firstExampleOutput.getDouble(1), 1e-6);
assertEquals(2.0, firstExampleOutput.getDouble(2), 1e-6);
assertEquals(secondExampleOutput.getDouble(0), 0.0, 1e-6);
assertEquals(secondExampleOutput.getDouble(1), 0.0, 1e-6);
assertEquals(secondExampleOutput.getDouble(2), 1.0, 1e-6);
assertEquals(0.0, secondExampleOutput.getDouble(0), 1e-6);
assertEquals(0.0, secondExampleOutput.getDouble(1), 1e-6);
assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
}
@ -94,7 +110,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder()
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).build()).build())
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

View File

@ -0,0 +1,394 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.layers.recurrent;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
@AllArgsConstructor
public class RnnDataFormatTests extends BaseDL4JTest {
private boolean helpers;
private boolean lastTimeStep;
private boolean maskZeros;
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}")
public static List params(){
List<Object[]> ret = new ArrayList<>();
for (boolean helpers: new boolean[]{true, false})
for (boolean lastTimeStep: new boolean[]{true, false})
for (boolean maskZero: new boolean[]{true, false})
ret.add(new Object[]{helpers, lastTimeStep, maskZero});
return ret;
}
@Test
public void testSimpleRnn() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSimpleRnnNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getSimpleRnnNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getSimpleRnnNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getSimpleRnnNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGraveLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGravesLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getGravesLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getGravesLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getGravesLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGraveBiLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGravesBidirectionalLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getGravesBidirectionalLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getGravesBidirectionalLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getGravesBidirectionalLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
private MultiLayerNetwork getGravesBidirectionalLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new GravesLSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new GravesLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new LSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new LSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getSimpleRnnNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new SimpleRnn.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new SimpleRnn.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getNetWithLayer(Layer layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
if (maskZeros){
layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build();
}
if(lastTimeStep){
layer = new LastTimeStep(layer);
}
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(new LSTM.Builder()
.nIn(3)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build())
.layer(layer)
.layer(
(lastTimeStep)?new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build():
new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).dataFormat(format).build()
)
.setInputType(InputType.recurrent(3, 12, format));
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
@AllArgsConstructor
@Data
@NoArgsConstructor
@Builder
private static class TestCase {
private String msg;
private MultiLayerNetwork net1;
private MultiLayerNetwork net2;
private MultiLayerNetwork net3;
private MultiLayerNetwork net4;
private INDArray inNCW;
private INDArray labelsNCW;
private INDArray labelsNWC;
private int testLayerIdx;
private boolean nwcOutput;
public static void testHelper(TestCase tc) {
tc.net2.params().assign(tc.net1.params());
tc.net3.params().assign(tc.net1.params());
tc.net4.params().assign(tc.net1.params());
INDArray inNCW = tc.inNCW;
INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup();
INDArray l0_1 = tc.net1.feedForward(inNCW).get(tc.testLayerIdx + 1);
INDArray l0_2 = tc.net2.feedForward(inNCW).get(tc.testLayerIdx + 1);
INDArray l0_3 = tc.net3.feedForward(inNWC).get(tc.testLayerIdx + 1);
INDArray l0_4 = tc.net4.feedForward(inNWC).get(tc.testLayerIdx + 1);
boolean rank3Out = tc.labelsNCW.rank() == 3;
assertEquals(tc.msg, l0_1, l0_2);
if (rank3Out){
assertEquals(tc.msg, l0_1, l0_3.permute(0, 2, 1));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 2, 1));
}
else{
assertEquals(tc.msg, l0_1, l0_3);
assertEquals(tc.msg, l0_1, l0_4);
}
INDArray out1 = tc.net1.output(inNCW);
INDArray out2 = tc.net2.output(inNCW);
INDArray out3 = tc.net3.output(inNWC);
INDArray out4 = tc.net4.output(inNWC);
assertEquals(tc.msg, out1, out2);
if (rank3Out){
assertEquals(tc.msg, out1, out3.permute(0, 2, 1)); //NWC to NCW
assertEquals(tc.msg, out1, out4.permute(0, 2, 1));
}
else{
assertEquals(tc.msg, out1, out3); //NWC to NCW
assertEquals(tc.msg, out1, out4);
}
//Test backprop
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCW, tc.labelsNCW, null, null);
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCW, tc.labelsNCW, null, null);
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNWC, tc.labelsNWC, null, null);
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNWC, tc.labelsNWC, null, null);
//Inpput gradients
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0, 2, 1)); //Input gradients for NWC input are also in NWC format
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0, 2, 1));
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
tc.net1.fit(inNCW, tc.labelsNCW);
tc.net2.fit(inNCW, tc.labelsNCW);
tc.net3.fit(inNWC, tc.labelsNWC);
tc.net4.fit(inNWC, tc.labelsNWC);
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
//Test serialization
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
out1 = tc.net1.output(inNCW);
assertEquals(tc.msg, out1, net1a.output(inNCW));
assertEquals(tc.msg, out1, net2a.output(inNCW));
if (rank3Out) {
assertEquals(tc.msg, out1, net3a.output(inNWC).permute(0, 2, 1)); //NWC to NCW
assertEquals(tc.msg, out1, net4a.output(inNWC).permute(0, 2, 1));
}
else{
assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW
assertEquals(tc.msg, out1, net4a.output(inNWC));
}
}
}
private static List<String> differentGrads(Gradient g1, Gradient g2){
List<String> differs = new ArrayList<>();
Map<String,INDArray> m1 = g1.gradientForVariable();
Map<String,INDArray> m2 = g2.gradientForVariable();
for(String s : m1.keySet()){
INDArray a1 = m1.get(s);
INDArray a2 = m2.get(s);
if(!a1.equals(a2)){
differs.add(s);
}
}
return differs;
}
}

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
@ -29,6 +30,8 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
@ -42,14 +45,25 @@ import static org.nd4j.linalg.activations.Activation.IDENTITY;
import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
@RunWith(Parameterized.class)
public class TestLastTimeStepLayer extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters(name="{0}")
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testLastTimeStepVertex() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
.nIn(5).nOut(6).build()), "in")
.nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in")
.setOutputs("lastTS")
.build();
@ -59,9 +73,22 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
//First: test without input mask array
Nd4j.getRandom().setSeed(12345);
Layer l = graph.getLayer("lastTS");
INDArray in = Nd4j.rand(new int[]{3, 5, 6});
INDArray in;
if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(3, 5, 6);
}
else{
in = Nd4j.rand(3, 6, 5);
}
INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces());
INDArray expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
INDArray expOut;
if (rnnDataFormat == RNNFormat.NCW){
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
}
else{
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.point(5), NDArrayIndex.all());
}
//Forward pass:
@ -76,9 +103,17 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
graph.setLayerMaskArrays(new INDArray[]{inMask}, null);
expOut = Nd4j.zeros(3, 6);
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));
if (rnnDataFormat == RNNFormat.NCW){
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));
}
else{
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all()));
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.point(3), NDArrayIndex.all()));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.point(4), NDArrayIndex.all()));
}
outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces());
assertEquals(expOut, outFwd);
@ -97,9 +132,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
.seed(1234)
.graphBuilder()
.addInputs("in")
.setInputTypes(InputType.recurrent(1))
.setInputTypes(InputType.recurrent(1, rnnDataFormat))
.addLayer("RNN", new LastTimeStep(new LSTM.Builder()
.nOut(10)
.nOut(10).dataFormat(rnnDataFormat)
.build()), "in")
.addLayer("dense", new DenseLayer.Builder()
.nOut(10)
@ -120,7 +155,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
INDArray fm2 = Nd4j.zeros(1,24);
INDArray fm3 = Nd4j.zeros(1,24);
fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1);
if (rnnDataFormat == RNNFormat.NWC){
f = f.permute(0, 2, 1);
}
INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1});
try {
cg.output(false, new INDArray[]{f}, new INDArray[]{fm2});

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.dropout.TestDropout;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
@ -31,6 +32,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,13 +44,24 @@ import org.nd4j.linalg.primitives.Pair;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
@RunWith(Parameterized.class)
public class TestRnnLayers extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestRnnLayers(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testTimeStepIs3Dimensional() {
@ -58,8 +72,8 @@ public class TestRnnLayers extends BaseDL4JTest {
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).build())
.layer(new LSTM.Builder().nIn(3).nOut(5).build())
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new LSTM.Builder().nIn(3).nOut(5).dataFormat(rnnDataFormat).build())
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build())
.build();
@ -70,9 +84,9 @@ public class TestRnnLayers extends BaseDL4JTest {
org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn =
(org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0);
INDArray rnnInput3d = Nd4j.create(10, 12, 1);
INDArray rnnInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10,12, 1):Nd4j.create(10, 1, 12);
INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces());
assertTrue(Arrays.equals(simpleOut.shape(), new long[] {10, 3, 1}));
assertTrue(Arrays.equals(simpleOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 3, 1}:new long[]{10, 1, 3}));
INDArray rnnInput2d = Nd4j.create(10, 12);
try {
@ -84,9 +98,9 @@ public class TestRnnLayers extends BaseDL4JTest {
org.deeplearning4j.nn.layers.recurrent.LSTM lstm =
(org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1);
INDArray lstmInput3d = Nd4j.create(10, 3, 1);
INDArray lstmInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10, 3, 1):Nd4j.create(10, 1, 3);
INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces());
assertTrue(Arrays.equals(lstmOut.shape(), new long[] {10, 5, 1}));
assertTrue(Arrays.equals(lstmOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 5, 1}:new long[]{10, 1, 5}));
INDArray lstmInput2d = Nd4j.create(10, 3);
try {
@ -112,19 +126,19 @@ public class TestRnnLayers extends BaseDL4JTest {
TestDropout.CustomDropout cd = new TestDropout.CustomDropout();
switch (s){
case "graves":
layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break;
case "lstm":
layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break;
case "simple":
layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break;
default:
throw new RuntimeException(s);
@ -134,21 +148,21 @@ public class TestRnnLayers extends BaseDL4JTest {
.seed(12345)
.list()
.layer(layer)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(layerD)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(layerD2)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -178,7 +192,6 @@ public class TestRnnLayers extends BaseDL4JTest {
assertNotEquals(s, out2, out2D);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);
net.fit(f.dup(), l);
netD.fit(f.dup(), l);
assertNotEquals(s, net.params(), netD.params());
@ -205,14 +218,14 @@ public class TestRnnLayers extends BaseDL4JTest {
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
.list()
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
switch (i){
case 0:
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build());
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
break;
case 1:
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build());
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).dataFormat(rnnDataFormat).build());
break;
default:
throw new RuntimeException();
@ -223,14 +236,14 @@ public class TestRnnLayers extends BaseDL4JTest {
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new Random(12345));
try{
net.fit(in,l);
} catch (Throwable t){
String msg = t.getMessage();
if(msg == null)
t.printStackTrace();
System.out.println(i);
assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
}

View File

@ -20,10 +20,13 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -36,8 +39,18 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
@RunWith(Parameterized.class)
public class TestSimpleRnn extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestSimpleRnn(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testSimpleRnn(){
Nd4j.getRandom().setSeed(12345);
@ -46,15 +59,21 @@ public class TestSimpleRnn extends BaseDL4JTest {
int nIn = 5;
int layerSize = 6;
int tsLength = 7;
INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
// in.get(all(), all(), interval(1,tsLength)).assign(0);
INDArray in;
if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
}
else{
in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.activation(Activation.TANH)
.list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).build())
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -68,7 +87,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
INDArray outLast = null;
for( int i=0; i<tsLength; i++ ){
INDArray inCurrent = in.get(all(), all(), point(i));
INDArray inCurrent;
if (rnnDataFormat == RNNFormat.NCW){
inCurrent = in.get(all(), all(), point(i));
}
else{
inCurrent = in.get(all(), point(i), all());
}
INDArray outExpCurrent = inCurrent.mmul(w);
if(outLast != null){
@ -79,7 +104,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
Transforms.tanh(outExpCurrent, false);
INDArray outActCurrent = out.get(all(), all(), point(i));
INDArray outActCurrent;
if (rnnDataFormat == RNNFormat.NCW){
outActCurrent = out.get(all(), all(), point(i));
}
else{
outActCurrent = out.get(all(), point(i), all());
}
assertEquals(String.valueOf(i), outExpCurrent, outActCurrent);
outLast = outExpCurrent;
@ -100,7 +131,7 @@ public class TestSimpleRnn extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.activation(Activation.TANH)
.list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize)
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.biasInit(100)
.build())
.build();

View File

@ -4,14 +4,21 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -22,8 +29,18 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class TestTimeDistributed extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestTimeDistributed(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testTimeDistributed(){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
@ -34,11 +51,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
.seed(12345)
.updater(new Adam(0.1))
.list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
.layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build())
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3))
.setInputType(InputType.recurrent(3, rnnDataFormat))
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
@ -47,11 +64,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
.seed(12345)
.updater(new Adam(0.1))
.list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), 2))
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
.layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), rnnDataFormat))
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3))
.setInputType(InputType.recurrent(3, rnnDataFormat))
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
@ -62,13 +79,21 @@ public class TestTimeDistributed extends BaseDL4JTest {
for( int mb : new int[]{1, 5}) {
for(char inLabelOrder : new char[]{'c', 'f'}) {
INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder);
if (rnnDataFormat == RNNFormat.NWC){
in = in.permute(0, 2, 1);
}
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
INDArray labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
INDArray labels ;
if (rnnDataFormat == RNNFormat.NCW) {
labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
}else{
labels = TestUtils.randomOneHotTimeSeries(mb, 5, 3).dup(inLabelOrder);
}
DataSet ds = new DataSet(in, labels);
net1.fit(ds);
@ -85,4 +110,73 @@ public class TestTimeDistributed extends BaseDL4JTest {
}
}
}
@Test
public void testTimeDistributedDense(){
for( int rnnType=0; rnnType<3; rnnType++ ) {
for( int ffType=0; ffType<3; ffType++ ) {
Layer l0, l2;
switch (rnnType) {
case 0:
l0 = new LSTM.Builder().nOut(5).build();
l2 = new LSTM.Builder().nOut(5).build();
break;
case 1:
l0 = new SimpleRnn.Builder().nOut(5).build();
l2 = new SimpleRnn.Builder().nOut(5).build();
break;
case 2:
l0 = new Bidirectional(new LSTM.Builder().nOut(5).build());
l2 = new Bidirectional(new LSTM.Builder().nOut(5).build());
break;
default:
throw new RuntimeException("Not implemented: " + rnnType);
}
Layer l1;
switch (ffType){
case 0:
l1 = new DenseLayer.Builder().nOut(5).build();
break;
case 1:
l1 = new VariationalAutoencoder.Builder().nOut(5).encoderLayerSizes(5).decoderLayerSizes(5).build();
break;
case 2:
l1 = new AutoEncoder.Builder().nOut(5).build();
break;
default:
throw new RuntimeException("Not implemented: " + ffType);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.list()
.layer(l0)
.layer(l1)
.layer(l2)
.setInputType(InputType.recurrent(5, 9, rnnDataFormat))
.build();
BaseRecurrentLayer l0a;
BaseRecurrentLayer l2a;
if (rnnType < 2) {
l0a = (BaseRecurrentLayer) l0;
l2a = (BaseRecurrentLayer) l2;
} else {
l0a = (BaseRecurrentLayer) ((Bidirectional) l0).getFwd();
l2a = (BaseRecurrentLayer) ((Bidirectional) l2).getFwd();
}
assertEquals(rnnDataFormat, l0a.getRnnDataFormat());
assertEquals(rnnDataFormat, l2a.getRnnDataFormat());
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, rnnDataFormat == RNNFormat.NCW ? new long[]{2, 5, 9} : new long[]{2, 9, 5} );
net.output(in);
}
}
}
}

View File

@ -771,7 +771,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
.build();
fail("Exception expected");
} catch (IllegalStateException e){
// e.printStackTrace();
log.info(e.toString());
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
}
}

View File

@ -408,8 +408,8 @@ public class TestVariableLengthTS extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0);
assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(0.0, outRow2.getDouble(k), 0.0);
}
}
}

View File

@ -336,7 +336,7 @@ public class TestUpdaters extends BaseDL4JTest {
actualM[i] = Math.round(actualM[i] * 1e2) / 1e2;
}
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(actualM, expectedM), true);
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(expectedM, actualM), true);
}

View File

@ -159,14 +159,14 @@ public class RegressionTest050 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,14 +162,14 @@ public class RegressionTest060 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,7 +162,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same);
assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());

View File

@ -176,14 +176,14 @@ public class RegressionTest080 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same);
assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Same);
assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertTrue(l2.getActivationFn() instanceof ActivationSigmoid);

View File

@ -178,8 +178,6 @@ public abstract class BaseCudnnHelper {
}
}
protected static final int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
protected final DataType nd4jDataType;
protected final int dataType;
protected final int dataTypeSize;

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -22,6 +23,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import com.jakewharton.byteunits.BinaryByteUnit;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo;
@ -86,7 +88,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
}
private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct();
biasTensorDesc = new cudnnTensorStruct(), deltaTensorDesc = new cudnnTensorStruct();
private cudnnFilterStruct filterDesc = new cudnnFilterStruct();
private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct();
private cudnnActivationStruct activationDesc = new cudnnActivationStruct();
@ -138,7 +140,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel,
int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo,
ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
// correctly on NHWC data, even after updating all descriptors, tensor format, etc.
//Therefore: all computation here is done in NCHW format only
//As of a future (next?) release we'll likely switch to C++ for cuDNN support
boolean origNHWC = false;
if(format == CNN2DFormat.NHWC){
input = input.permute(0,3,1,2); //NHWC to NCHW
delta = delta.permute(0,3,1,2);
origNHWC = true;
}
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
int code;
val miniBatch = input.size(0);
@ -147,7 +163,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
val kH = weights.size(2);
val kW = weights.size(3);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
@ -176,7 +192,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]);
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW);
checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
@ -238,16 +254,16 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
}
} else {
code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
0, algo1);
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
0, algo1);
checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
0, algo2);
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
0, algo2);
checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
}
@ -263,7 +279,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView,
biasGradView, delta, epsNext);
biasGradView, delta, epsNext);
Pointer srcData = allocator.getPointer(input, context);
Pointer filterData = allocator.getPointer(weights, context);
Pointer filterGradData = allocator.getPointer(weightGradView, context);
@ -279,14 +295,14 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0],
sizeInBytes);
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0],
sizeInBytes);
checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
long sizeInBytes1 = sizeInBytes.get(0);
code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0],
sizeInBytes);
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0],
sizeInBytes);
checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
@ -313,21 +329,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta,
cudnnContext.biasTensorDesc, biasGradData);
cudnnContext.biasTensorDesc, biasGradData);
checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace,
workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData);
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace,
workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData);
checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData,
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace,
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace,
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView,
delta, epsNext);
delta, epsNext);
Gradient retGradient = new DefaultGradient();
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView);
@ -344,12 +360,30 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0)));
}
if(origNHWC){
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
}
return new Pair<>(retGradient, epsNext);
}
@Override
public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format,
LayerWorkspaceMgr workspaceMgr) {
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working
// correctly on NHWC data, even after updating all descriptors, tensor format, etc.
//Therefore: all computation here is done in NCHW format only
//As of a future (next?) release we'll likely switch to C++ for cuDNN support
boolean origNHWC = false;
if(format == CNN2DFormat.NHWC){
input = input.permute(0,3,1,2); //NHWC to NCHW
origNHWC = true;
}
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
int code;
val miniBatch = input.size(0);
@ -358,7 +392,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
val kH = weights.size(2);
val kW = weights.size(3);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
@ -378,7 +412,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0],
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
dilation[1], CUDNN_CROSS_CORRELATION, dataType);
checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
@ -460,8 +494,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0],
sizeInBytes);
cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0],
sizeInBytes);
checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
@ -482,8 +516,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
}
code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData,
cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace,
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace,
workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData);
checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
@ -491,7 +525,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha,
cudnnContext.dstTensorDesc, dstData);
cudnnContext.dstTensorDesc, dstData);
checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation);
allocator.registerAction(context, z, input, weights, bias);
@ -499,6 +533,10 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
if(origNHWC){
z = z.permute(0,2,3,1); //NCHW to NHWC
}
return z;
}
@ -552,29 +590,29 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
break;
case "sigmoid":
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID,
CUDNN_PROPAGATE_NAN, 0));
CUDNN_PROPAGATE_NAN, 0));
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
case "relu":
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU,
CUDNN_PROPAGATE_NAN, 0));
CUDNN_PROPAGATE_NAN, 0));
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
case "tanh":
checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH,
CUDNN_PROPAGATE_NAN, 0));
CUDNN_PROPAGATE_NAN, 0));
checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
case "softmax":
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
case "logsoftmax":
checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData));
break;
default:
activation = null;
@ -593,7 +631,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
* @return
*/
public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation,
ConvolutionMode convolutionMode, PoolingType poolingType){
ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){
INDArray origInput = input;
//Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have has issues if strides
@ -602,16 +640,19 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
input = input.dup('c');
}
boolean nchw = format == CNN2DFormat.NCHW;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val inH = input.size(2);
val inW = input.size(3);
val inH = input.size(hIdx);
val inW = input.size(wIdx);
boolean manualPadBottom = false;
boolean manualPadRight = false;
int[] outSize;
if (convolutionMode == ConvolutionMode.Same) {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
if(!Arrays.equals(padding, padBottomRight)){
@ -626,9 +667,17 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
manualPadRight = (padding[1] != padBottomRight[1]);
//NCHW format
val newShape = new long[]{input.size(0), input.size(1),
input.size(2) + (manualPadBottom ? 1 : 0),
input.size(3) + (manualPadRight ? 1 : 0)};
long[] newShape;
if(nchw){
newShape = new long[]{input.size(0), input.size(1),
input.size(2) + (manualPadBottom ? 1 : 0),
input.size(3) + (manualPadRight ? 1 : 0)};
} else {
newShape = new long[]{input.size(0),
input.size(1) + (manualPadBottom ? 1 : 0),
input.size(2) + (manualPadRight ? 1 : 0),
input.size(3)};
}
INDArray newInput;
if(poolingType == null || poolingType != PoolingType.MAX){
newInput = Nd4j.create(input.dataType(), newShape);
@ -638,15 +687,22 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
// if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value
newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType());
}
newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)),
interval(0, input.size(3))}, input);
if(nchw){
newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)),
interval(0, input.size(3))}, input);
} else {
newInput.put(new INDArrayIndex[]{all(), interval(0,input.size(1)),
interval(0, input.size(2)), all()}, input);
}
input = newInput;
//Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we
// now have the same amount of padding required for top/bottom, and left/right - which we'll let
// CuDNN handle
}
} else {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation); //Also performs validation
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
}
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
@ -670,4 +726,4 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
return Collections.emptyMap();
}
}
}

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.subsampling;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
@ -114,23 +115,29 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides,
int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode,
int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
if(dilation[0] != 1 || dilation[1] != 1){
//CuDNN doesn't support dilated subsampling
return null;
}
boolean nchw = format == CNN2DFormat.NCHW;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
//We require the output as one of the arguments for backprop here
//TODO we could add cache mode support here somehow...
INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, workspaceMgr);
INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr);
val miniBatch = input.size(0);
val depth = input.size(1);
val depth = input.size(chIdx);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
val inH = input.size(hIdx);
val inW = input.size(wIdx);
val srcStride = input.stride();
int[] outSize = args.getOutSize();
int outH = outSize[0];
@ -160,23 +167,26 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
epsilon = epsilon.dup('c');
}
input = input.dup();
val deltaStride = epsilon.stride();
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]));
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW,
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]));
(int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
kernel[1], pad[0], pad[1], strides[0], strides[1]));
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c');
long[] outEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c');
val dstStride = outEpsilon.stride();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]));
(int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
@ -198,9 +208,16 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
//Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
// we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
if(args.isManualPadBottom() || args.isManualPadRight()) {
outEpsilon = outEpsilon.get(all(), all(),
interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0)));
if(nchw){
outEpsilon = outEpsilon.get(all(), all(),
interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0)));
} else {
outEpsilon = outEpsilon.get(all(),
interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)),
all());
}
}
return new Pair<>(retGradient, outEpsilon);
@ -209,19 +226,24 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
@Override
public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad,
PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
if(dilation[0] != 1 || dilation[1] != 1){
//CuDNN doesn't support dilated subsampling
return null;
}
val miniBatch = input.size(0);
val inDepth = input.size(1);
boolean nchw = format == CNN2DFormat.NCHW;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType);
val miniBatch = input.size(0);
val inDepth = input.size(nchw ? 1 : 3);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
val inH = input.size(nchw ? 2 : 1);
val inW = input.size(nchw ? 3 : 2);
val srcStride = input.stride();
val outSize = args.getOutSize();
int outH = outSize[0];
@ -246,13 +268,14 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
kernel[1], pad[0], pad[1], strides[0], strides[1]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]));
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {(int) miniBatch, (int) inDepth, outH, outW}, 'c');
long[] outShape = nchw ? new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth};
INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c');
val dstStride = reduced.stride();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW,
(int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]));
(int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(input, reduced);

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.normalization;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
@ -124,12 +125,21 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
this.eps = eps;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val miniBatch = (int) input.size(0);
val depth = (int) input.size(1);
val inH = (int) input.size(2);
val inW = (int) input.size(3);
val depth = (int) input.size(chIdx);
val inH = (int) input.size(hIdx);
val inW = (int) input.size(wIdx);
final boolean isHalf = (input.dataType() == DataType.HALF);
INDArray gammaOrig = null;
@ -164,16 +174,17 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]));
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]));
(int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c');
long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c');
val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();
@ -215,9 +226,15 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
@Override
public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
this.eps = eps;
final boolean isHalf = (x.dataType() == DataType.HALF);
final boolean isHalf = (x.dataType() == DataType.FLOAT16);
INDArray origGamma = gamma;
INDArray origBeta = beta;
INDArray origMean = mean;
@ -238,21 +255,22 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled"
val miniBatch = (int) x.size(0);
val inDepth = (int) x.size(1);
val inH = (int) x.size(2);
val inW = (int) x.size(3);
val inDepth = (int) x.size(chIdx);
val inH = (int) x.size(hIdx);
val inW = (int) x.size(wIdx);
val srcStride = ArrayUtil.toInts(x.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx]));
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c');
long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth};
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c');
val dstStride = ArrayUtil.toInts(activations.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0],
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();

View File

@ -16,74 +16,131 @@
package org.deeplearning4j;
import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer;
import org.deeplearning4j.nn.layers.normalization.BatchNormalization;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.*;
import java.lang.reflect.Field;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class TestUtils {
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
MultiLayerNetwork restored;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
assertEquals(net.params(), restored.params());
return restored;
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
//Also check the MultiLayerConfiguration is serializable (required by Spark etc)
MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
serializeDeserializeJava(conf);
return restored;
}
public static ComputationGraph testModelSerialization(ComputationGraph net){
ComputationGraph restored;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true);
restored = ModelSerializer.restoreComputationGraph(bais, true);
assertEquals(net.getConfiguration(), restored.getConfiguration());
assertEquals(net.params(), restored.params());
return restored;
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
//Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
ComputationGraphConfiguration conf = net.getConfiguration();
serializeDeserializeJava(conf);
return restored;
}
public static INDArray randomOneHot(int examples, int nOut){
private static <T> T serializeDeserializeJava(T object){
byte[] bytes;
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
oos.writeObject(object);
oos.close();
bytes = baos.toByteArray();
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
T out;
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){
out = (T)ois.readObject();
} catch (IOException | ClassNotFoundException e){
throw new RuntimeException(e);
}
assertEquals(object, out);
return out;
}
public static INDArray randomOneHot(long examples, long nOut){
return randomOneHot(examples, nOut, new Random(12345));
}
public static INDArray randomOneHot(int examples, int nOut, long rngSeed){
public static INDArray randomOneHot(DataType dataType, long examples, long nOut){
return randomOneHot(dataType, examples, nOut, new Random(12345));
}
public static INDArray randomOneHot(long examples, long nOut, long rngSeed){
return randomOneHot(examples, nOut, new Random(rngSeed));
}
public static INDArray randomOneHot(int examples, int nOut, Random rng){
INDArray arr = Nd4j.create(examples, nOut);
public static INDArray randomOneHot(long examples, long nOut, Random rng) {
return randomOneHot(Nd4j.defaultFloatingPointType(), examples,nOut, rng);
}
public static INDArray randomOneHot(DataType dataType, long examples, long nOut, Random rng){
INDArray arr = Nd4j.create(dataType, examples, nOut);
for( int i=0; i<examples; i++ ){
arr.putScalar(i, rng.nextInt(nOut), 1.0);
arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
}
return arr;
}
@ -115,4 +172,143 @@ public class TestUtils {
Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p));
return ret;
}
public static void writeStreamToFile(File out, InputStream is) throws IOException {
byte[] b = IOUtils.toByteArray(is);
try (OutputStream os = new BufferedOutputStream(new FileOutputStream(out))) {
os.write(b);
}
}
public static L1Regularization getL1Reg(List<Regularization> l){
for(Regularization r : l){
if(r instanceof L1Regularization){
return (L1Regularization) r;
}
}
return null;
}
public static L2Regularization getL2Reg(BaseLayer baseLayer){
return getL2Reg(baseLayer.getRegularization());
}
public static L2Regularization getL2Reg(List<Regularization> l){
for(Regularization r : l){
if(r instanceof L2Regularization){
return (L2Regularization) r;
}
}
return null;
}
public static WeightDecay getWeightDecayReg(BaseLayer bl){
return getWeightDecayReg(bl.getRegularization());
}
public static WeightDecay getWeightDecayReg(List<Regularization> l){
for(Regularization r : l){
if(r instanceof WeightDecay){
return (WeightDecay) r;
}
}
return null;
}
public static double getL1(BaseLayer layer) {
List<Regularization> l = layer.getRegularization();
return getL1(l);
}
public static double getL1(List<Regularization> l){
L1Regularization l1Reg = null;
for(Regularization reg : l){
if(reg instanceof L1Regularization)
l1Reg = (L1Regularization) reg;
}
assertNotNull(l1Reg);
return l1Reg.getL1().valueAt(0,0);
}
public static double getL2(BaseLayer layer) {
List<Regularization> l = layer.getRegularization();
return getL2(l);
}
public static double getL2(List<Regularization> l){
L2Regularization l2Reg = null;
for(Regularization reg : l){
if(reg instanceof L2Regularization)
l2Reg = (L2Regularization) reg;
}
assertNotNull(l2Reg);
return l2Reg.getL2().valueAt(0,0);
}
public static double getL1(AbstractSameDiffLayer layer){
return getL1(layer.getRegularization());
}
public static double getL2(AbstractSameDiffLayer layer){
return getL2(layer.getRegularization());
}
public static double getWeightDecay(BaseLayer layer) {
return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0);
}
public static void removeHelper(Layer layer) throws Exception {
removeHelpers(new Layer[]{layer});
}
public static void removeHelpers(Layer[] layers) throws Exception {
for(Layer l : layers){
if(l instanceof ConvolutionLayer){
Field f1 = ConvolutionLayer.class.getDeclaredField("helper");
f1.setAccessible(true);
f1.set(l, null);
} else if(l instanceof SubsamplingLayer){
Field f2 = SubsamplingLayer.class.getDeclaredField("helper");
f2.setAccessible(true);
f2.set(l, null);
} else if(l instanceof BatchNormalization) {
Field f3 = BatchNormalization.class.getDeclaredField("helper");
f3.setAccessible(true);
f3.set(l, null);
} else if(l instanceof LSTM){
Field f4 = LSTM.class.getDeclaredField("helper");
f4.setAccessible(true);
f4.set(l, null);
} else if(l instanceof LocalResponseNormalization){
Field f5 = LocalResponseNormalization.class.getDeclaredField("helper");
f5.setAccessible(true);
f5.set(l, null);
}
if(l.getHelper() != null){
throw new IllegalStateException("Did not remove helper for layer: " + l.getClass().getSimpleName());
}
}
}
public static void assertHelperPresent(Layer layer){
}
public static void assertHelpersPresent(Layer[] layers) throws Exception {
for(Layer l : layers){
//Don't use instanceof here - there are sub conv subclasses
if(l.getClass() == ConvolutionLayer.class || l instanceof SubsamplingLayer || l instanceof BatchNormalization || l instanceof LSTM){
Preconditions.checkNotNull(l.getHelper(), l.conf().getLayer().getLayerName());
}
}
}
public static void assertHelpersAbsent(Layer[] layers) throws Exception {
for(Layer l : layers){
Preconditions.checkState(l.getHelper() == null, l.conf().getLayer().getLayerName());
}
}
}

View File

@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest {
ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
model = model.convertDataType(DataType.DOUBLE);
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize});
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC
CuDNNTestUtils.assertHelpersPresent(model.getLayers());
Map<String,INDArray> withCudnn = model.feedForward(in, false);

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.datasets.fetchers;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.base.EmnistFetcher;
@ -36,6 +37,7 @@ import java.util.Random;
* @author Alex Black
*
*/
@Slf4j
public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher {
protected EmnistFetcher fetcher;
@ -64,7 +66,7 @@ public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetche
try {
man = new MnistManager(images, labels, totalExamples);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
FileUtils.deleteDirectory(new File(EMNIST_ROOT));
new EmnistFetcher(dataSet).downloadAndUntar();
man = new MnistManager(images, labels, totalExamples);

View File

@ -687,9 +687,7 @@ public class BarnesHutTsne implements Model {
* @throws IOException
*/
public void saveAsFile(List<String> labels, String path) throws IOException {
BufferedWriter write = null;
try {
write = new BufferedWriter(new FileWriter(new File(path)));
try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
for (int i = 0; i < Y.rows(); i++) {
if (i >= labels.size())
break;
@ -711,17 +709,11 @@ public class BarnesHutTsne implements Model {
}
write.flush();
write.close();
} finally {
if (write != null)
write.close();
}
}
public void saveAsFile(String path) throws IOException {
BufferedWriter write = null;
try {
write = new BufferedWriter(new FileWriter(new File(path)));
try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
for (int i = 0; i < Y.rows(); i++) {
StringBuilder sb = new StringBuilder();
INDArray wordVector = Y.getRow(i);
@ -734,10 +726,6 @@ public class BarnesHutTsne implements Model {
write.write(sb.toString());
}
write.flush();
write.close();
} finally {
if (write != null)
write.close();
}
}
/**

View File

@ -113,6 +113,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-python</artifactId>
<version>${datavec.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -60,7 +60,7 @@ public class Hdf5Archive implements Closeable {
/* This is necessary for the call to the BytePointer constructor below. */
Loader.load(org.bytedeco.hdf5.global.hdf5.class);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -19,6 +19,9 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -121,27 +124,29 @@ public class KerasInput extends KerasLayer {
InputType myInputType;
switch (this.inputShape.length) {
case 1:
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
break;
case 2:
if(this.dimOrder != null) {
System.out.println("Dim order: " + this.dimOrder);
System.out.println("Input shape: " + ArrayUtils.toString(this.inputShape));
switch (this.dimOrder) {
case TENSORFLOW: //NWC == channels_last
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break;
case THEANO: //NCW == channels_first
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
break;
case NONE:
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break;
default:
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
}
} else {
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
}
break;
@ -150,17 +155,17 @@ public class KerasInput extends KerasLayer {
case TENSORFLOW:
/* TensorFlow convolutional input: # rows, # cols, # channels */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
this.inputShape[2]);
this.inputShape[2], CNN2DFormat.NHWC);
break;
case THEANO:
/* Theano convolutional input: # channels, # rows, # cols */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]);
this.inputShape[0], CNN2DFormat.NCHW);
break;
default:
this.dimOrder = DimOrder.THEANO;
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]);
this.inputShape[0], CNN2DFormat.NCHW);
log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
"versions may come without specified backend, in which case we assume the model was " +
"built with theano." );

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
@ -65,6 +66,9 @@ public class TFOpLayer extends Layer {
long[] shape = inputType.getShape(true);
TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
long[] outputShape = tempLayer.getOutputShape(shape);
if (outputShape.length == 3){
return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC);
}
return InputType.inferInputType(Nd4j.create(outputShape));
}

View File

@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
}
private INDArray runGraph(INDArray input){
if (input.rank() == 3){
// TODO make this a preprocessor
input = input.permute(0, 2, 1);
}
Map<String, INDArray> inputMap = new HashMap<>();
inputMap.put(inputNames.get(0), input);
INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
if (out.rank() == 3){
out = out.permute(0, 2, 1); // TODO post-processing?
}
return out;
}

View File

@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
@ -94,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution {
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
@ -103,7 +103,7 @@ public class KerasConvolution1D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
.hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]).rnnDataFormat(dimOrder == DimOrder.TENSORFLOW? RNNFormat.NWC: RNNFormat.NCW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
if (hasBias)
builder.biasInit(0.0);
@ -160,8 +160,8 @@ public class KerasConvolution1D extends KerasConvolution {
public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras LSTM layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
"Keras Conv1D layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], RNNFormat.NCW,layerName);
}

View File

@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
@ -27,6 +28,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.IWeightInit;
import oshi.jna.platform.windows.PowrProf;
import java.util.Map;
@ -93,6 +95,7 @@ public class KerasConvolution2D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
System.out.println("----" + dimOrder);
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
@ -101,7 +104,8 @@ public class KerasConvolution2D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
.hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 2, conf));
.stride(getStrideFromConfig(layerConfig, 2, conf))
.dataFormat((dimOrder==DimOrder.TENSORFLOW)? CNN2DFormat.NHWC:CNN2DFormat.NCHW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion);
if (hasBias)
builder.biasInit(0.0);

View File

@ -360,8 +360,19 @@ public class KerasConvolutionUtils {
}
} else if (dimension == 1) {
int paddingInt = (int) innerConfig.get(layerField);
padding = new int[]{paddingInt, paddingInt};
Object paddingObj = innerConfig.get(layerField);
if (paddingObj instanceof List){
List<Integer> paddingList = (List)paddingObj;
padding = new int[]{
paddingList.get(0),
paddingList.get(1)
};
}
else{
int paddingInt = (int) innerConfig.get(layerField);
padding = new int[]{paddingInt, paddingInt};
}
} else {
throw new UnsupportedKerasConfigurationException(
"Keras padding layer not supported");

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import java.util.Map;
@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer {
switch (this.getDimOrder()) {
case NONE:
case THEANO:
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels());
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW);
break;
case TENSORFLOW:
preprocessor = new TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(),
it.getChannels());
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC);
break;
default:
throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder());
@ -111,7 +110,7 @@ public class KerasFlatten extends KerasLayer {
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null);
}
return preprocessor;
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer {
super(layerConfig, enforceTrainingConfig);
this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf))
.dataFormat(RNNFormat.NWC)
.name(this.layerName).build();
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW);
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
if (inputShape[0] != targetShape[0])
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC);
}
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
} else {
if (inputShape[0] != targetShape[0])
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
}
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
if (targetShape.length == 3) {
targetShape = targetShapeForDimOrder(inputShape, targetShape);
}
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
}
return preprocessor;
}

View File

@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer {
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization)
.outputDataFormat(RNNFormat.NWC)
.hasBias(false);
if (embeddingConstraint != null)
builder.constrainWeights(embeddingConstraint);

View File

@ -22,11 +22,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
@ -37,6 +35,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -187,7 +186,7 @@ public class KerasLSTM extends KerasLayer {
.weightInitRecurrent(recurrentInit)
.biasInit(0.0) // TODO: this is incorrect
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);
@ -266,7 +265,8 @@ public class KerasLSTM extends KerasLayer {
throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one single input" +
"or three (input to LSTM and two states tensors, but " +
"received " + inputType.length + ".");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f,layerName);
}
/**

View File

@ -21,7 +21,9 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer;
@ -36,6 +38,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
@ -155,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer {
.weightInitRecurrent(recurrentInit)
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);
@ -227,7 +230,8 @@ public class KerasSimpleRnn extends KerasLayer {
throw new InvalidKerasConfigurationException(
"Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName);
}
/**

View File

@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer {
break;
case "SimpleRNN":
kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
this.layer = new Bidirectional(mode, rnnLayer);
layer.setLayerName(layerName);
break;
@ -218,7 +218,7 @@ public class KerasBidirectional extends KerasLayer {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], ((Bidirectional)layer).getRNNDataFormat(), layerName);
}
/**

View File

@ -21,6 +21,9 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
private final long[] inputShape;
private final long[] targetShape;
private boolean hasMiniBatchDimension;
/**
* @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
*/
@Deprecated
public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
this(inputShape, targetShape, false);
}
private DataFormat format;
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
*/
public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) {
this(inputShape, targetShape, hasMiniBatchDimension, null);
}
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
* @param dataFormat May be null. If non-null:
*/
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension,
@JsonProperty("dataFormat") DataFormat dataFormat) {
this.inputShape = inputShape;
this.targetShape = targetShape;
this.hasMiniBatchDimension = hasMiniBatchDimension;
this.format = dataFormat;
}
private long[] getShape(long[] originalShape, long minibatch) {
@ -140,13 +148,26 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
ret = InputType.feedForward(shape[1]);
break;
case 3:
ret = InputType.recurrent(shape[2], shape[1]);
RNNFormat format = RNNFormat.NCW;
if(this.format != null && this.format instanceof RNNFormat)
format = (RNNFormat)this.format;
ret = InputType.recurrent(shape[2], shape[1], format);
break;
case 4:
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
} else {
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
CNN2DFormat cnnFormat = CNN2DFormat.NCHW;
if (this.format != null && this.format instanceof CNN2DFormat)
cnnFormat = (CNN2DFormat) this.format;
if (cnnFormat == CNN2DFormat.NCHW) {
ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
} else {
ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
}
}
break;
default:

View File

@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/**
* Specialized CnnToFeedForwardInputPreProcessor for use with
* Convolutional layers imported from Keras using the TensorFlow
* backend.
*
* @author dave@skymind.io
* @deprecated Exists only for backward compatibility of older pretrained models. Should not be used.
* Use {@link CnnToFeedForwardPreProcessor} for all new models instead.
*/
@Slf4j
@Slf4j @Deprecated
public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor {
@JsonCreator
@JsonCreator @Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight,
@JsonProperty("inputWidth") long inputWidth,
@JsonProperty("numChannels") long numChannels) {
super(inputHeight, inputWidth, numChannels);
}
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
super(inputHeight, inputWidth);
}
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor() {
super();
}
@ -81,4 +80,4 @@ public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreP
public TensorFlowCnnToFeedForwardPreProcessor clone() {
return (TensorFlowCnnToFeedForwardPreProcessor) super.clone();
}
}
}

View File

@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
@ -319,6 +320,8 @@ public class KerasLayerUtils {
layer = new KerasELU(layerConfig, enforceTrainingConfig);
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
} else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){
layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
} else if (conf instanceof Keras2LayerConfiguration){
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){

View File

@ -1,50 +0,0 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.Arrays;
public class TFKerasTests extends BaseDL4JTest{
@Test
public void testModelWithTFOp1() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
}
@Test
public void testModelWithTFOp2() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
// dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed
long[] expectedShape = new long[]{12 * 2, 5};
Assert.assertArrayEquals(expectedShape, out.shape());
}
}

View File

@ -0,0 +1,147 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.apache.commons.io.FileUtils;
import org.datavec.python.keras.Model;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.common.tests.ResourceUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.List;
@RunWith(Parameterized.class)
public class TestTFKerasModelImport extends BaseDL4JTest{
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
private String modelFile;
@Override
public long getTimeoutMilliseconds(){
return 300000;
} // installing TF will take a while
@Parameterized.Parameters(name = "file={0}")
public static Object[] params() throws Exception {
List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false);
return paths.toArray(new String[0]);
}
public TestTFKerasModelImport(String modelFile){
this.modelFile = modelFile;
}
@Test
public void testModelImport() throws Exception{
testModelImportWithData(modelFile);
}
private void testModelImportWithData(String path) throws Exception{
System.out.println(path);
// TODO multi input/output
INDArray inputArray;
INDArray expectedOutputArray;
File f = Resources.asFile(path); //May in in JAR that HDF5 can't read from
File modelFile = new File(testDir.getRoot(), f.getName());
FileUtils.copyFile(f, modelFile);
synchronized (Hdf5Archive.LOCK_OBJECT){
Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath());
List<String> rootGroups = hdf5Archive.getGroups();
if (rootGroups.contains("data")){
String inputName = hdf5Archive.readAttributeAsString("input_names", "data");
String outputName = hdf5Archive.readAttributeAsString("output_names", "data");
inputArray = hdf5Archive.readDataSet(inputName, "data");
expectedOutputArray = hdf5Archive.readDataSet(outputName, "data");
}
else{
hdf5Archive.close();
return;
}
hdf5Archive.close();
}
INDArray outputArray;
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
outputArray = dl4jModel.outputSingle(inputArray);
expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT);
outputArray = outputArray.castTo(DataType.FLOAT);
if (path.contains("misc_")){
//shape relaxation
expectedOutputArray = expectedOutputArray.reshape( -1);
outputArray = outputArray.reshape(-1);
}
System.out.println(outputArray.toString());
System.out.println(expectedOutputArray.toString());
Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape());
Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3));
}
private void testModelImportWithKeras(String path) throws Exception{
Model kerasModel = new Model(path);
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays());
Assert.assertEquals(kerasModel.numOutputs(), dl4jModel.getNumOutputArrays());
INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()];
INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()];
for (int i = 0; i < kerasInputArrays.length; i ++) {
long[] shape = kerasModel.inputShapeAt(i);
for (int j = 0; j < shape.length; j++) {
if (shape[j] < 0) {
shape[j] = 1;
}
}
kerasInputArrays[i] = Nd4j.rand(shape);
}
INDArray[] kerasOut = kerasModel.predict(kerasInputArrays);
INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays);
Assert.assertEquals(kerasOut.length, dl4jOut.length);
for (int i = 0; i < kerasOut.length; i++){
INDArray kerasOutArr = kerasOut[i];
kerasOutArr = kerasOutArr.reshape(1, -1);// bit of relaxation on shape
kerasOutArr= kerasOutArr.castTo(DataType.DOUBLE);
Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST);
INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1);
System.out.println(kerasOutArr.shapeInfoToString());
System.out.println(dl4jOutArr.shapeInfoToString());
Assert.assertEquals(kerasOutArr, dl4jOutArr);
}
}
}

Some files were not shown because too many files have changed in this diff Show More