Merge pull request #8892 from KonduitAI/master

Merge latest development work
Alex Black 2020-04-28 20:38:56 +10:00 committed by GitHub
commit 1930d99908
475 changed files with 10073 additions and 2958 deletions

View File

@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
return 120_000L;
}
@Test
@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();

View File

@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 45000L;
return 120_000L;
}
@Test
@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator()));

View File

@ -18,6 +18,7 @@ package org.datavec.api.records.reader.impl;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.datavec.api.conf.Configuration;
@ -43,6 +44,7 @@ import java.util.*;
*
* @author Adam Gibson
*/
@Slf4j
public class LineRecordReader extends BaseRecordReader {
@ -58,6 +60,13 @@ public class LineRecordReader extends BaseRecordReader {
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
super.initialize(split);
if(!(inputSplit instanceof StringSplit || inputSplit instanceof InputStreamInputSplit)){
final ArrayList<URI> uris = new ArrayList<>();
final Iterator<URI> uriIterator = inputSplit.locationsIterator();
while(uriIterator.hasNext()) uris.add(uriIterator.next());
this.locations = uris.toArray(new URI[0]);
}
this.iter = getIterator(0);
this.initialized = true;
}
@ -66,7 +75,6 @@ public class LineRecordReader extends BaseRecordReader {
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
this.conf = conf;
initialize(split);
this.initialized = true;
}
@Override
@ -89,7 +97,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
if (iter.hasNext()) {
@ -120,7 +128,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return iter.hasNext();
@ -205,11 +213,6 @@ public class LineRecordReader extends BaseRecordReader {
}
}
} else {
final ArrayList<URI> uris = new ArrayList<>();
final Iterator<URI> uriIterator = inputSplit.locationsIterator();
while(uriIterator.hasNext()) uris.add(uriIterator.next());
this.locations = uris.toArray(new URI[uris.size()]);
if (locations.length > 0) {
InputStream inputStream = streamCreatorFn.apply(locations[location]);
try {

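For reference, a minimal usage sketch of the reader refactored above, assuming a local text file (the file name lines.txt is illustrative only):

// Read a text file line-by-line with LineRecordReader; locations are now
// resolved in initialize(InputSplit), as in the change above.
import org.datavec.api.records.reader.impl.LineRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;

import java.io.File;
import java.util.List;

public class LineRecordReaderExample {
    public static void main(String[] args) throws Exception {
        LineRecordReader reader = new LineRecordReader();
        reader.initialize(new FileSplit(new File("lines.txt")));
        while (reader.hasNext()) {
            List<Writable> line = reader.next(); // one record per line of text
            System.out.println(line);
        }
        reader.close();
    }
}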
View File

@ -16,6 +16,7 @@
package org.datavec.api.split;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.util.files.UriFromPathIterator;
import org.datavec.api.writable.WritableType;
@ -34,6 +35,7 @@ import java.util.regex.Pattern;
* NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent
* the integer index.
*/
@Slf4j
public class NumberedFileInputSplit implements InputSplit {
private final String baseString;
private final int minIdx;
@ -93,7 +95,7 @@ public class NumberedFileInputSplit implements InputSplit {
try {
writeFile.createNewFile();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}

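For reference, a minimal sketch of how this split is constructed (pattern and index range are illustrative):

// The base pattern must contain "%d"; it is expanded via String.format()
// for every index from minIdx to maxIdx inclusive.
import org.datavec.api.split.NumberedFileInputSplit;

import java.net.URI;
import java.util.Iterator;

public class NumberedSplitExample {
    public static void main(String[] args) {
        NumberedFileInputSplit split = new NumberedFileInputSplit("myfile_%d.csv", 0, 9);
        Iterator<URI> locations = split.locationsIterator();
        while (locations.hasNext()) {
            System.out.println(locations.next());
        }
    }
}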
View File

@ -23,6 +23,7 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.regex.Pattern;
/**
 * A simple utility class to convert an {@code Iterator<String>} to an {@code Iterator<URI>}, where each
@ -32,6 +33,7 @@ import java.util.NoSuchElementException;
*/
@AllArgsConstructor
public class UriFromPathIterator implements Iterator<URI> {
final Pattern schemaPattern = Pattern.compile("^.*?:/.*");
private final Iterator<String> paths;
@ -42,16 +44,17 @@ public class UriFromPathIterator implements Iterator<URI> {
@Override
public URI next() {
if (!hasNext()) {
throw new NoSuchElementException("No next element");
}
try {
String s = paths.next();
if(!s.matches(".*:/.*")){
if(schemaPattern.matcher(s).matches()){
return new URI(s);
} else {
//No scheme - assume file for backward compatibility
return new File(s).toURI();
} else {
return new URI(s);
}
} catch (URISyntaxException e) {

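The intent of the precompiled schemaPattern is to detect whether a path already carries a URI scheme; a standalone sketch of that behavior (inputs are illustrative):

// Paths with a scheme ("file:/...", "hdfs://host/...") go to new URI(s) directly;
// bare paths fall back to File.toURI() for backward compatibility.
import java.io.File;
import java.net.URI;
import java.util.regex.Pattern;

public class UriFromPathExample {
    static final Pattern SCHEMA = Pattern.compile("^.*?:/.*");

    static URI toUri(String s) throws Exception {
        return SCHEMA.matcher(s).matches() ? new URI(s) : new File(s).toURI();
    }

    public static void main(String[] args) throws Exception {
        System.out.println(toUri("hdfs://namenode/data/part-0.csv")); // scheme preserved
        System.out.println(toUri("/tmp/data/part-0.csv"));            // becomes file:/tmp/...
    }
}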
View File

@ -162,7 +162,6 @@ public class Text extends BinaryComparable implements WritableComparable<BinaryC
return -1; // not found
} catch (CharacterCodingException e) {
// can't get here
e.printStackTrace();
return -1;
}
}

View File

@ -17,6 +17,7 @@
package org.datavec.arrow.recordreader;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
@ -50,6 +51,7 @@ import static org.datavec.arrow.ArrowConverter.readFromBytes;
* @author Adam Gibson
*
*/
@Slf4j
public class ArrowRecordReader implements RecordReader {
private InputSplit split;
@ -132,7 +134,7 @@ public class ArrowRecordReader implements RecordReader {
currIdx++;
this.currentPath = url;
}catch(Exception e) {
e.printStackTrace();
log.error("",e);
}
}
@ -242,7 +244,7 @@ public class ArrowRecordReader implements RecordReader {
try {
currentBatch.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
}

View File

@ -16,6 +16,7 @@
package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
@ -56,7 +57,7 @@ import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@Slf4j
public class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@ -343,7 +344,7 @@ public class ArrowConverterTest extends BaseND4JTest {
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
byte[] arr = byteArrayOutputStream.toByteArray();

View File

@ -61,7 +61,7 @@ public class Wave implements Serializable {
initWaveWithInputStream(inputStream);
inputStream.close();
} catch (IOException e) {
e.printStackTrace();
System.out.println(e.toString());
}
}
@ -96,7 +96,7 @@ public class Wave implements Serializable {
data = new byte[inputStream.available()];
inputStream.read(data);
} catch (IOException e) {
e.printStackTrace();
System.err.println(e.toString());
}
// end load data
} else {

View File

@ -16,10 +16,12 @@
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.FileOutputStream;
import java.io.IOException;
@Slf4j
public class WaveFileManager {
private Wave wave;
@ -78,7 +80,7 @@ public class WaveFileManager {
fos.write(wave.getBytes());
fos.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -16,6 +16,8 @@
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException;
import java.io.InputStream;
@ -25,6 +27,7 @@ import java.io.InputStream;
*
* @author Jacquet Wong
*/
@Slf4j
public class WaveHeader {
public static final String RIFF_HEADER = "RIFF";
@ -109,7 +112,7 @@ public class WaveHeader {
// dis.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
return false;
}

View File

@ -17,6 +17,7 @@
package org.datavec.audio.fingerprint;
import lombok.extern.slf4j.Slf4j;
import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler;
@ -38,6 +39,7 @@ import java.util.List;
* @author jacquet
*
*/
@Slf4j
public class FingerprintManager {
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
@ -153,7 +155,7 @@ public class FingerprintManager {
fingerprint = getFingerprintFromInputStream(fis);
fis.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return fingerprint;
}
@ -170,7 +172,7 @@ public class FingerprintManager {
fingerprint = new byte[inputStream.available()];
inputStream.read(fingerprint);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return fingerprint;
}
@ -190,7 +192,7 @@ public class FingerprintManager {
fileOutputStream.write(fingerprint);
fileOutputStream.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -81,7 +81,7 @@ public abstract class BaseImageLoader implements Serializable {
String fileName = file.toString();
if (fileName.endsWith(".tgz") || fileName.endsWith(".tar.gz") || fileName.endsWith(".gz")
|| fileName.endsWith(".zip"))
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath());
ArchiveUtils.unzipFileTo(file.getAbsolutePath(), fullDir.getAbsolutePath(), false);
} catch (IOException e) {
throw new IllegalStateException("Unable to fetch images", e);
}

View File

@ -186,7 +186,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
labels.add(line);
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
@ -300,7 +300,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
try {
FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
meanStdStored = true;
} else if (uMean == 0 && meanStdStored) {
@ -312,7 +312,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
vStd = Double.parseDouble(values[3]);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
for (int i = 0; i < result.numExamples(); i++) {
@ -356,12 +356,12 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
dataSets.add(new DataSet(asMatrix(matConversion.getSecond()), matConversion.getFirst()));
batchNumCount++;
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
break;
}
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
if(dataSets.size() == 0){

View File

@ -235,7 +235,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data);
} catch (IOException | InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
return recordReader;
}
@ -250,7 +250,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data);
} catch (IOException | InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
return recordReader;
}

View File

@ -225,7 +225,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
finishedInputStreamSplit = true;
return Arrays.<Writable>asList(ndArrayWritable);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
if (iter != null) {

View File

@ -16,6 +16,7 @@
package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
@ -53,6 +54,7 @@ import static org.junit.Assert.fail;
*
* @author saudet
*/
@Slf4j
public class TestNativeImageLoader {
static final long seed = 10;
static final Random rng = new Random(seed);
@ -123,7 +125,7 @@ public class TestNativeImageLoader {
try {
array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
fail();
}
assertEquals(5, array6.rank());
@ -156,7 +158,7 @@ public class TestNativeImageLoader {
try {
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
assertEquals(5, array8.rank());
assertEquals(pages2, array8.size(0));
@ -172,7 +174,7 @@ public class TestNativeImageLoader {
try {
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath());
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
fail();
}
assertEquals(5, array9.rank());

View File

@ -66,7 +66,7 @@ public class UimaTokenizer implements Tokenizer {
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
throw new RuntimeException(e);
}

View File

@ -17,6 +17,7 @@
package org.datavec.local.transforms.functions;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.function.Function;
@ -32,6 +33,7 @@ import java.util.List;
* sequence data into a {@code List<List<Writable>>}
* @author Alex Black
*/
@Slf4j
public class SequenceRecordReaderFunction
implements Function<Pair<String, InputStream>, List<List<Writable>>> {
protected SequenceRecordReader sequenceRecordReader;
@ -46,7 +48,7 @@ public class SequenceRecordReaderFunction
try (DataInputStream dis = (DataInputStream) value.getRight()) {
return sequenceRecordReader.sequenceRecord(uri, dis);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
throw new IllegalStateException("Something went wrong");

View File

@ -20,6 +20,7 @@ package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.cpython.global.python;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -343,6 +344,19 @@ public class PythonExecutioner {
if (path == null) {
log.info("Setting python default path");
File[] packages = numpy.cachePackages();
//// TODO: fix in javacpp
File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
File[] packages2 = new File[packages.length + 1];
for (int i = 0;i < packages.length; i++){
//System.out.println(packages[i].getAbsolutePath());
packages2[i] = packages[i];
}
packages2[packages.length] = sitePackagesWindows;
//System.out.println(sitePackagesWindows.getAbsolutePath());
packages = packages2;
//////////
Py_SetPath(packages);
} else {
log.info("Setting python path " + path);

View File

@ -0,0 +1,132 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
@Slf4j
public class PythonProcess {
private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
Process process = pb.start();
String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
process.waitFor();
return out;
}
public static void run(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
pb.inheritIO().start().waitFor();
}
public static void pipInstall(String packageName) throws PythonException{
try{
run("-m", "pip", "install", packageName);
}catch(Exception e){
throw new PythonException("Error installing package " + packageName, e);
}
}
public static void pipInstall(String packageName, String version) throws PythonException{
pipInstall(packageName + "==" + version);
}
public static void pipUninstall(String packageName) throws PythonException{
try{
run("-m", "pip", "uninstall", packageName);
}catch(Exception e){
throw new PythonException("Error uninstalling package " + packageName, e);
}
}
public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{
if (!gitRepoUrl.contains("://")){
gitRepoUrl = "git://" + gitRepoUrl;
}
try{
run("-m", "pip", "install", "git+", gitRepoUrl);
}catch(Exception e){
throw new PythonException("Error installing package from " + gitRepoUrl, e);
}
}
public static String getPackageVersion(String packageName) throws PythonException{
String out;
try{
out = runAndReturn("-m", "pip", "show", packageName);
} catch (Exception e){
throw new PythonException("Error finding version for package " + packageName, e);
}
if (!out.contains("Version: ")){
throw new PythonException("Can't find package " + packageName);
}
String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
return pkgVersion;
}
public static boolean isPackageInstalled(String packageName)throws PythonException{
try{
String out = runAndReturn("-m", "pip", "show", packageName);
return !out.isEmpty();
}catch (Exception e){
throw new PythonException("Error checking if package is installed: " +packageName, e);
}
}
public static void pipInstallFromRequirementsTxt(String path) throws PythonException{
try{
run("-m", "pip", "install","-r", path);
}catch (Exception e){
throw new PythonException("Error installing packages from " + path, e);
}
}
public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException{
try{
run(path, inplace?"develop":"install");
}catch (Exception e){
throw new PythonException("Error installing package from " + path, e);
}
}
}

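For reference, a short usage sketch of the new helper; package names are illustrative, and the calls assume the bundled CPython can reach PyPI:

// Query and install Python packages through the embedded interpreter's pip.
import org.datavec.python.PythonProcess;

public class PythonProcessExample {
    public static void main(String[] args) throws Exception {
        if (!PythonProcess.isPackageInstalled("numpy")) {
            PythonProcess.pipInstall("numpy");              // latest available version
        }
        PythonProcess.pipInstall("requests", "2.23.0");     // pinned version (illustrative)
        System.out.println(PythonProcess.getPackageVersion("numpy"));
    }
}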
View File

@ -0,0 +1,144 @@
package org.datavec.python.keras;
import org.datavec.python.Python;
import org.datavec.python.PythonException;
import org.datavec.python.PythonObject;
import org.datavec.python.PythonProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
public class Model {
private PythonObject pyModel;
private static PythonObject installAndImportTF() throws PythonException{
if (!PythonProcess.isPackageInstalled("tensorflow")){
PythonProcess.pipInstall("tensorflow");
}
return Python.importModule("tensorflow");
}
private static PythonObject getKerasModule() throws PythonException{
PythonObject tf = installAndImportTF();
PythonObject keras = tf.attr("keras");
tf.del();
return keras;
}
private static PythonObject loadModel(String s) throws PythonException{
PythonObject models = getKerasModule().attr("models");
PythonObject loadModelF = models.attr("load_model");
PythonObject model = loadModelF.call(s);
models.del();
loadModelF.del();
return model;
}
public Model(String path) throws PythonException{
pyModel = loadModel(path);
}
public INDArray[] predict(INDArray... inputs) throws PythonException{
PythonObject predictF = pyModel.attr("predict");
PythonObject inputList = new PythonObject(inputs);
PythonObject pyOut = predictF.call(inputList);
INDArray[] out;
if (Python.isinstance(pyOut, Python.listType())){
out = new INDArray[Python.len(pyOut).toInt()];
for(int i = 0; i < out.length; i++){
out[i] = pyOut.get(i).toNumpy().getNd4jArray();
}
}
else{
out = new INDArray[]{
pyOut.toNumpy().getNd4jArray()};
}
predictF.del();
inputList.del();
pyOut.del();
return out;
}
public int numInputs(){
PythonObject inputs = pyModel.attr("inputs");
PythonObject pyNumInputs = Python.len(inputs);
int ret = pyNumInputs.toInt();
inputs.del();
pyNumInputs.del();
return ret;
}
public int numOutputs(){
PythonObject outputs = pyModel.attr("outputs");
PythonObject pyNumOutputs = Python.len(outputs);
int ret = pyNumOutputs.toInt();
outputs.del();
pyNumOutputs.del();
return ret;
}
public long[][] inputShapes(){
long[][] ret = new long[numInputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = inputShapeAt(i);
}
return ret;
}
public long[][] outputShapes(){
long[][] ret = new long[numOutputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = outputShapeAt(i);
}
return ret;
}
public long[] inputShapeAt(int input){
PythonObject inputs = pyModel.attr("inputs");
PythonObject tensor = inputs.get(input);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
inputs.del();
return shape;
}
public long[] outputShapeAt(int output){
PythonObject inputs = pyModel.attr("outputs");
PythonObject tensor = inputs.get(output);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
inputs.del();
return shape;
}
}

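A minimal sketch of the wrapper in use; the model path and input shape below are illustrative, and TensorFlow is installed on demand via PythonProcess:

// Load a saved Keras model and run inference on an ND4J array.
import org.datavec.python.keras.Model;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class KerasModelExample {
    public static void main(String[] args) throws Exception {
        Model model = new Model("/path/to/model.h5"); // hypothetical path
        System.out.println("inputs=" + model.numInputs() + ", outputs=" + model.numOutputs());

        INDArray x = Nd4j.rand(1, 10);                // illustrative input of shape [1, 10]
        INDArray[] y = model.predict(x);
        System.out.println(y[0]);
    }
}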
View File

@ -74,7 +74,6 @@ public class DataVecTransformClient implements DataVecTransformService {
} catch (UnirestException e) {
log.error("Error in setCSVTransformProcess()", e);
e.printStackTrace();
}
}
@ -94,7 +93,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return TransformProcess.fromJson(s);
} catch (UnirestException e) {
log.error("Error in getCSVTransformProcess()",e);
e.printStackTrace();
}
return null;
@ -119,7 +117,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord;
} catch (UnirestException e) {
log.error("Error in transformIncremental(SingleCSVRecord)",e);
e.printStackTrace();
}
return null;
}
@ -140,8 +137,7 @@ public class DataVecTransformClient implements DataVecTransformService {
.getBody();
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e);
e.printStackTrace();
log.error("",e);
}
return null;
@ -162,7 +158,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e);
e.printStackTrace();
}
return null;
@ -181,7 +176,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1;
} catch (UnirestException e) {
log.error("Error in transformArray(BatchCSVRecord)",e);
e.printStackTrace();
}
return null;
@ -200,7 +194,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array;
} catch (UnirestException e) {
log.error("Error in transformArrayIncremental(SingleCSVRecord)",e);
e.printStackTrace();
}
return null;
@ -231,7 +224,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array;
} catch (UnirestException e) {
log.error("Error in transformSequenceArrayIncremental",e);
e.printStackTrace();
}
return null;
@ -252,7 +244,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1;
} catch (UnirestException e) {
log.error("Error in transformSequenceArray",e);
e.printStackTrace();
}
return null;
@ -274,7 +265,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transformSequence");
e.printStackTrace();
}
return null;
@ -295,7 +285,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord;
} catch (UnirestException e) {
log.error("Error in transformSequenceIncremental");
e.printStackTrace();
}
return null;
}

View File

@ -18,6 +18,7 @@ package org.datavec.spark.transform;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
@ -53,7 +54,7 @@ import static org.datavec.local.transforms.LocalTransformExecutor.executeToSeque
* @author Adan Gibson
*/
@AllArgsConstructor
@Slf4j
public class CSVSparkTransform {
@Getter
private TransformProcess transformProcess;
@ -252,7 +253,7 @@ public class CSVSparkTransform {
try {
return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return null;

View File

@ -88,7 +88,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -100,7 +100,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -112,7 +112,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest();
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -130,7 +130,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -142,7 +142,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -169,7 +169,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});

View File

@ -16,6 +16,7 @@
package org.datavec.spark;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.After;
@ -23,6 +24,7 @@ import org.junit.Before;
import java.io.Serializable;
@Slf4j
public abstract class BaseSparkTest implements Serializable {
protected static JavaSparkContext sc;
@ -40,7 +42,7 @@ public abstract class BaseSparkTest implements Serializable {
try {
Thread.sleep(100L);
} catch (InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
} else {
break;

View File

@ -68,7 +68,7 @@ public abstract class BaseDL4JTest {
* Override this method to set the default timeout for methods in the test class
*/
public long getTimeoutMilliseconds(){
return 30000;
return 60_000;
}
/**

View File

@ -43,7 +43,8 @@ import java.util.concurrent.atomic.AtomicLong;
*/
@Slf4j
public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable {
private static final String ROUTE_IS_DOWN = "Info posted to RemoteUIStatsStorageRouter but router is shut down.";
private static final String MAX_WARNINGS_REACHED = "RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.";
/**
* Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar
*/
@ -163,10 +164,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
log.warn(ROUTE_IS_DOWN);
}
if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
log.warn(MAX_WARNINGS_REACHED);
}
} else {
for (StorageMetaData m : storageMetaData) {
@ -186,10 +187,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
log.warn(ROUTE_IS_DOWN);
}
if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
log.warn(MAX_WARNINGS_REACHED);
}
} else {
for (Persistable p : staticInfo) {
@ -209,10 +210,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
log.warn(ROUTE_IS_DOWN);
}
if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
log.warn(MAX_WARNINGS_REACHED);
}
} else {
for (Persistable p : updates) {

View File

@ -67,10 +67,7 @@ public class AsyncIterator<T extends Object> implements Iterator<T> {
nextElement = buffer.take();
// same on this run
if (nextElement == terminator)
return false;
return true;
return (nextElement != terminator);
} catch (Exception e) {
log.error("Premature end of loop!");
return false;

View File

@ -43,7 +43,7 @@ public class SystemInfoPrintListener implements TrainingListener {
private boolean printOnBackwardPass;
private boolean printOnGradientCalculation;
private static final String SYSTEM_INFO = "System info on epoch end: ";
@Override
public void iterationDone(Model model, int iteration, int epoch) {
@ -65,7 +65,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -75,7 +75,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -85,7 +85,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -95,7 +95,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -104,7 +104,7 @@ public class SystemInfoPrintListener implements TrainingListener {
if(!printOnBackwardPass)
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.perf.listener;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import oshi.json.SystemInfo;
@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit;
*
* @author Adam Gibson
*/
@Slf4j
public class SystemPolling {
private ScheduledExecutorService scheduledExecutorService;
@ -66,7 +68,7 @@ public class SystemPolling {
try {
objectMapper.writeValue(hardwareFile,hardwareMetric);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
},0,pollEveryMillis, TimeUnit.MILLISECONDS);

View File

@ -27,7 +27,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import java.io.*;
import java.nio.file.Files;
import java.util.UUID;
/**

View File

@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
@ -153,11 +154,22 @@ public class TestUtils {
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
}
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) {
return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng);
}
public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){
boolean ncw = format == RNNFormat.NCW;
long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize};
char order = ncw ? 'f' : 'c';
INDArray out = Nd4j.create(DataType.FLOAT, shape, order);
for( int i=0; i<minibatch; i++ ){
for( int j=0; j<tsLength; j++ ){
if(ncw){
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
} else {
out.putScalar(i, j, rng.nextInt(outSize), 1.0);
}
}
}
return out;

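The new overload distinguishes the two RNN data layouts; a quick sketch of the resulting shapes (sizes are illustrative, and TestUtils is the test helper shown above):

// NCW: [minibatch, featureSize, timeSeriesLength]  -  the DL4J default
// NWC: [minibatch, timeSeriesLength, featureSize]
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Arrays;
import java.util.Random;

public class OneHotTimeSeriesShapes {
    public static void main(String[] args) {
        Random rng = new Random(12345);
        INDArray ncw = TestUtils.randomOneHotTimeSeries(RNNFormat.NCW, 4, 3, 7, rng);
        INDArray nwc = TestUtils.randomOneHotTimeSeries(RNNFormat.NWC, 4, 3, 7, rng);
        System.out.println(Arrays.toString(ncw.shape())); // [4, 3, 7]
        System.out.println(Arrays.toString(nwc.shape())); // [4, 7, 3]
    }
}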
View File

@ -493,7 +493,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
out.println();
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return new Pair<>(dArr,temp);

View File

@ -78,10 +78,10 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
assertEquals(earlyEndIter.hasNext(), false);
assertEquals(false, earlyEndIter.hasNext());
earlyEndIter.reset();
assertEquals(earlyEndIter.hasNext(), true);
assertEquals(true, earlyEndIter.hasNext());
}

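The swapped arguments above follow JUnit's assertEquals(expected, actual) convention, so failure output reads the right way round; a tiny illustration:

// With the arguments reversed, a failing check reports the actual value as "expected",
// which makes the test output misleading.
import static org.junit.Assert.assertEquals;

public class AssertOrderExample {
    public static void main(String[] args) {
        boolean hasNext = true;
        assertEquals(false, hasNext); // fails as: expected:<false> but was:<true>
    }
}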
View File

@ -98,7 +98,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
while (multiIter.hasNext()) {
DataSet path = multiIter.next(10);
assertNotNull(path);
assertEquals(path.numExamples(), 10, 0.0);
assertEquals(10, path.numExamples(), 0.0);
}
assertEquals(epochs, multiIter.epochs);

View File

@ -33,7 +33,7 @@ public class SamplingTest extends BaseDL4JTest {
DataSetIterator iter = new MnistDataSetIterator(10, 10);
//batch size and total
DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
assertEquals(sampling.next().numExamples(), 10);
assertEquals(10, sampling.next().numExamples());
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ConvolutionMode;
@ -34,6 +35,7 @@ import static org.junit.Assert.fail;
/**
* A set of tests to ensure that useful exceptions are thrown on invalid network configurations
*/
@Slf4j
public class TestInvalidConfigurations extends BaseDL4JTest {
public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) {
@ -78,7 +80,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testDenseNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -96,7 +98,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testDenseNout0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -109,7 +111,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -122,7 +124,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testRnnOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -135,7 +137,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testLSTMNIn0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -153,7 +155,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testLSTMNOut0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -166,7 +168,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testConvolutionalNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -185,7 +187,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testConvolutionalNOut0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -216,7 +218,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -245,7 +247,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -277,7 +279,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -318,7 +320,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -347,7 +349,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -36,6 +37,7 @@ import static org.junit.Assert.*;
/**
* A set of tests to ensure that useful exceptions are thrown on invalid input
*/
@Slf4j
public class TestInvalidInput extends BaseDL4JTest {
@Test
@ -53,7 +55,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchDense(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -73,7 +75,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -94,7 +96,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function
System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -115,7 +117,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function
System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -142,7 +144,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -169,7 +171,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -195,7 +197,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinRank2Subsampling(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -217,7 +219,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchLSTM(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -238,7 +240,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
@ -260,7 +262,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -305,7 +307,7 @@ public class TestInvalidInput extends BaseDL4JTest {
net.rnnTimeStep(Nd4j.create(5, 5, 10));
fail("Expected Exception - " + layerType);
} catch (Exception e) {
// e.printStackTrace();
log.error("",e);
String msg = e.getMessage();
assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
}

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -34,6 +35,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -51,6 +54,7 @@ import static org.junit.Assert.*;
/**
* Created by nyghtowl on 9/1/15.
*/
@RunWith(Parameterized.class)
public class CNNGradientCheckTest extends BaseDL4JTest {
private static final boolean PRINT_RESULTS = true;
private static final boolean RETURN_ON_FIRST_FAILURE = false;
@ -62,6 +66,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
private CNN2DFormat format;
public CNNGradientCheckTest(CNN2DFormat format){
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return CNN2DFormat.values();
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
@ -69,6 +84,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testGradientCNNMLN() {
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
return;
//Parameterized test, testing combinations of:
// (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
@ -144,6 +162,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testGradientCNNL1L2MLN() {
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
return;
//Parameterized test, testing combinations of:
// (a) activation function
// (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation')
@ -311,10 +332,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
boolean nchw = format == CNN2DFormat.NCHW;
for (String afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut);
for (int i = 0; i < 4 * minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -330,13 +353,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn;
if (PRINT_RESULTS) {
@ -377,8 +400,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] padding = {0, 0};
int size = 2;
boolean nchw = format == CNN2DFormat.NCHW;
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf =
@ -393,8 +419,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(8 * 8 * 3)
.nOut(4).build())
.setInputType(InputType.convolutionalFlat(height, width,
inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -438,10 +463,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
boolean nchw = format == CNN2DFormat.NCHW;
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -461,14 +489,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(3 * 3 * 3)
.nOut(4).build())
.setInputType(InputType.convolutionalFlat(height, width,
inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn;
if (PRINT_RESULTS) {
@ -508,10 +535,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX,
SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM};
boolean nchw = format == CNN2DFormat.NCHW;
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -533,8 +563,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2)
.nOut(4).build())
.setInputType(InputType.convolutionalFlat(height, width,
inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -558,8 +587,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testCnnLocallyConnected2D() {
int nOut = 3;
int[] minibatchSizes = {2};
int width = 5;
int height = 5;
@ -569,11 +596,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS};
int[] minibatch = {2, 1, 3};
boolean nchw = format == CNN2DFormat.NCHW;
for( int i=0; i<inputDepths.length; i++ ){
int inputDepth = inputDepths[i];
Activation afn = activations[i];
int minibatchSize = minibatch[i];
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
@ -590,7 +621,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
.build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
.setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
assertEquals(ConvolutionMode.Truncate,
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
@ -626,11 +657,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for (int inputDepth : inputDepths) {
for (Activation afn : activations) {
for (SubsamplingLayer.PoolingType poolingType : poolingTypes) {
for (int minibatchSize : minibatchSizes) {
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -649,7 +684,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
.build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
.setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
assertEquals(ConvolutionMode.Truncate,
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
@ -691,13 +726,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for( int i=0; i<minibatchSizes.length; i++ ){
int inputDepth = inputDepths[i];
int minibatchSize = minibatchSizes[i];
int height = heights[i];
int k = kernelSizes[i];
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
@ -713,7 +752,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.stride(1, 1).padding(0, 0).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
.setInputType(InputType.convolutional(height, width, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
@ -748,13 +787,16 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for (int inputDepth : inputDepths) {
for (int minibatchSize : minibatchSizes) {
for (int stride : strides) {
for (int k : kernelSizes) {
for (boolean convFirst : new boolean[]{true, false}) {
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -775,7 +817,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(1, convFirst ? poolLayer : convLayer)
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -822,11 +864,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] inputDepths = {1, 3, 2};
int[][] zeroPadLayer = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}};
boolean nchw = format == CNN2DFormat.NCHW;
for( int i=0; i<minibatchSizes.length; i++ ){
int minibatchSize = minibatchSizes[i];
int inputDepth = inputDepths[i];
int[] zeroPad = zeroPadLayer[i];
INDArray input = Nd4j.rand(DataType.DOUBLE, new int[]{minibatchSize, inputDepth, height, width});
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf =
@ -840,7 +886,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
padding).nIn(3).nOut(3).build())//output: (6-2+0)/1+1 = 5
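//For ConvolutionMode.Truncate, output size = floor((in - k + 2*p)/s) + 1, hence (6-2+0)/1+1 = 5 above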
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(4).build())
.setInputType(InputType.convolutional(height, width, inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -849,8 +895,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
//Check zero padding activation shape
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer zpl =
(org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer) net.getLayer(1);
val expShape = new long[]{minibatchSize, inputDepth, height + zeroPad[0] + zeroPad[1],
long[] expShape;
if(nchw){
expShape = new long[]{minibatchSize, inputDepth, height + zeroPad[0] + zeroPad[1],
width + zeroPad[2] + zeroPad[3]};
} else {
expShape = new long[]{minibatchSize, height + zeroPad[0] + zeroPad[1],
width + zeroPad[2] + zeroPad[3], inputDepth};
}
INDArray out = zpl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(expShape, out.shape());
@ -888,6 +940,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345);
boolean nchw = format == CNN2DFormat.NCHW;
for (int i = 0; i < minibatchSizes.length; i++) {
int minibatchSize = minibatchSizes[i];
int k = kernelSizes[i];
@ -900,7 +954,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int w = d * width;
int h = d * height;
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int j = 0; j < minibatchSize; j++) {
labels.putScalar(new int[]{j, j % nOut}, 1.0);
@ -920,7 +977,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build();
.setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
@ -945,8 +1002,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testSeparableConv2D() {
int nOut = 2;
int[] minibatchSizes = new int[]{1, 3};
int width = 6;
int height = 6;
int inputDepth = 3;
@ -959,6 +1014,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionMode[] cms = new ConvolutionMode[]{Truncate, Truncate, Truncate, Truncate, Truncate};
int[] mb = new int[]{1, 1, 1, 3, 3};
boolean nchw = format == CNN2DFormat.NCHW;
for (int t = 0; t < ks.length; t++) {
int k = ks[t];
@ -971,7 +1028,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int w = d * width;
int h = d * height;
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -992,7 +1050,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build();
.setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
@ -1017,7 +1075,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
@Test
public void testCnnDilated() {
int nOut = 2;
int minibatchSize = 2;
int width = 8;
int height = 8;
@ -1031,9 +1088,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] ds = new int[]{2, 2, 3, 3, 2};
ConvolutionMode[] cms = new ConvolutionMode[]{Same, Truncate, Truncate, Same, Truncate};
boolean nchw = format == CNN2DFormat.NCHW;
for (int t = 0; t < sub.length; t++) {
boolean subsampling = sub[t];
int s = stride[t];
int k = kernel[t];
@ -1044,7 +1101,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int w = d * width;
int h = d * height;
INDArray input = Nd4j.rand(minibatchSize, w * h * inputDepth);
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, h, w} : new long[]{minibatchSize, h, w, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -1076,7 +1134,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(h, w, inputDepth)).build();
.setInputType(InputType.convolutional(h, w, inputDepth, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
@ -1114,11 +1172,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] inputDepths = {1, 2, 3, 2};
int[] minibatchSizes = {2, 1, 3, 2};
boolean nchw = format == CNN2DFormat.NCHW;
for (int i = 0; i < cropTestCases.length; i++) {
int inputDepth = inputDepths[i];
int minibatchSize = minibatchSizes[i];
int[] crop = cropTestCases[i];
INDArray input = Nd4j.rand(new int[]{minibatchSize, inputDepth, height, width});
long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf =
@ -1134,7 +1195,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(height, width, inputDepth))
.setInputType(InputType.convolutional(height, width, inputDepth, format))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -1143,12 +1204,18 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
//Check cropping activation shape
org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl =
(org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1);
val expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
long[] expShape;
if(nchw){
expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
width - crop[2] - crop[3]};
} else {
expShape = new long[]{minibatchSize, height - crop[0] - crop[1],
width - crop[2] - crop[3], inputDepth};
}
INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(expShape, out.shape());
String msg = "minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
String msg = format + " - minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
+ Arrays.toString(crop);
if (PRINT_RESULTS) {
@ -1181,6 +1248,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
Truncate, Truncate, Truncate, Truncate, Truncate};
int[] mb = new int[]{1,1,1,3,3};
boolean nchw = format == CNN2DFormat.NCHW;
for( int t=0; t<ks.length; t++ ){
int k = ks[t];
@ -1188,8 +1257,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionMode cm = cms[t];
int minibatchSize = mb[t];
INDArray input = Nd4j.rand(minibatchSize, width * height * nIn);
long[] inShape = nchw ? new long[]{minibatchSize, nIn, height, width} : new long[]{minibatchSize, height, width, nIn};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
@ -1211,7 +1280,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutionalFlat(height, width, nIn)).build();
.setInputType(InputType.convolutional(height, width, nIn, format)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.gradientcheck;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -115,15 +116,17 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
//Basic test of global pooling w/ CNN
Nd4j.getRandom().setSeed(12345L);
for(boolean nchw : new boolean[]{true, false}) {
int inputDepth = 3;
int inputH = 5;
int inputW = 4;
int layerDepth = 4;
int nOut = 2;
int[] minibatchSizes = new int[] {1, 3};
int[] minibatchSizes = new int[]{1, 3};
PoolingType[] poolingTypes =
new PoolingType[] {PoolingType.AVG, PoolingType.SUM, PoolingType.MAX, PoolingType.PNORM};
new PoolingType[]{PoolingType.AVG, PoolingType.SUM, PoolingType.MAX, PoolingType.PNORM};
for (int miniBatchSize : minibatchSizes) {
for (PoolingType pt : poolingTypes) {
@ -137,14 +140,14 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
.layer(1, new GlobalPoolingLayer.Builder().poolingType(pt).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(nOut).build())
.setInputType(InputType.convolutional(inputH, inputW, inputDepth)).build();
.setInputType(InputType.convolutional(inputH, inputW, inputDepth, nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
Random r = new Random(12345L);
INDArray input = Nd4j.rand(new int[] {miniBatchSize, inputDepth, inputH, inputW}).subi(0.5);
long[] inShape = nchw ? new long[]{miniBatchSize, inputDepth, inputH, inputW} : new long[]{miniBatchSize, inputH, inputW, inputDepth};
INDArray input = Nd4j.rand(DataType.DOUBLE, inShape).subi(0.5);
INDArray labels = Nd4j.zeros(miniBatchSize, nOut);
for (int i = 0; i < miniBatchSize; i++) {
@ -153,8 +156,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
}
if (PRINT_RESULTS) {
System.out.println(
"testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize);
System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize + " - " + (nchw ? "NCHW" : "NHWC"));
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
}
@ -167,6 +169,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
}
}
}
}
@Test
public void testLSTMWithMasking() {

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.*;
@ -560,8 +561,9 @@ public class GradientCheckTests extends BaseDL4JTest {
public void testEmbeddingSequenceLayer(){
Nd4j.getRandom().setSeed(12345);
for(boolean maskArray : new boolean[]{false, true}){
for(int inputRank : new int[]{2,3}) {
for(RNNFormat seqOutputFormat : RNNFormat.values()) {
for (boolean maskArray : new boolean[]{false, true}) {
for (int inputRank : new int[]{2, 3}) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
@ -572,18 +574,22 @@ public class GradientCheckTests extends BaseDL4JTest {
.layer(new EmbeddingSequenceLayer.Builder()
.nIn(8)
.nOut(4)
.outputDataFormat(seqOutputFormat)
.build())
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
.dataFormat(seqOutputFormat)
.lossFunction(LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
INDArray label = Nd4j.rand(new int[]{3, 3, 6});
boolean ncw = seqOutputFormat == RNNFormat.NCW;
if(inputRank == 3){
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
INDArray label = Nd4j.rand(DataType.FLOAT, ncw ? new int[]{3, 3, 6} : new int[]{3,6,3});
if (inputRank == 3) {
//Reshape from [3,6] to [3,1,6]
in = in.reshape('c', 3, 1, 6);
}
@ -607,7 +613,7 @@ public class GradientCheckTests extends BaseDL4JTest {
if (maskArray) {
DataSet ds = new DataSet(in, label, fMask, null);
double score = net.score(ds);
if(inputRank == 2){
if (inputRank == 2) {
in.putScalar(1, 2, 0);
in.putScalar(2, 1, 0);
in.putScalar(2, 2, 0);
@ -618,7 +624,7 @@ public class GradientCheckTests extends BaseDL4JTest {
}
double score2 = net.score(ds);
assertEquals(score, score2, 1e-6);
if(inputRank == 2){
if (inputRank == 2) {
in.putScalar(1, 2, 1);
in.putScalar(2, 1, 1);
in.putScalar(2, 2, 1);
@ -633,6 +639,7 @@ public class GradientCheckTests extends BaseDL4JTest {
}
}
}
}
@Test

View File

@ -21,9 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
@ -341,6 +339,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
@Test
public void testCnnDepthMerge() {
for(CNN2DFormat format : CNN2DFormat.values()) {
String msg = "testCnnDepthMerge - " + format;
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
@ -358,20 +360,20 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
.build(),
"merge")
.setOutputs("outputLayer")
.inputPreProcessor("outputLayer", new CnnToFeedForwardPreProcessor(5, 5, 4))
.setInputTypes(InputType.convolutional(6, 6, 2, format))
.build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(new int[] {5, 2, 6, 6}); //Order: examples, channels, height, width
INDArray input = Nd4j.rand(DataType.DOUBLE, format == CNN2DFormat.NCHW ? new long[]{5,2,6,6} : new long[]{5,6,6,2});
INDArray labels = Nd4j.zeros(5, 3);
for (int i = 0; i < 5; i++)
labels.putScalar(new int[] {i, r.nextInt(3)}, 1.0);
labels.putScalar(new int[]{i, r.nextInt(3)}, 1.0);
if (PRINT_RESULTS) {
System.out.println("testCnnDepthMerge()");
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
@ -379,14 +381,18 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testCnnDepthMerge()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
}
@Test
public void testRNNWithMerging() {
for(RNNFormat format : RNNFormat.values()) {
String msg = "testLSTMWithMerging - " + format;
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder().seed(12345)
@ -416,19 +422,18 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
"merge")
.inputPreProcessor("dense1", new RnnToFeedForwardPreProcessor())
.inputPreProcessor("lstm3", new FeedForwardToRnnPreProcessor())
.setInputTypes(InputType.recurrent(4, format))
.build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(new int[] {2, 3, 4});
INDArray labels = TestUtils.randomOneHotTimeSeries(2, 3, 4);
INDArray input = Nd4j.rand(DataType.DOUBLE, format == RNNFormat.NCW ? new long[]{2, 3, 4} : new long[]{2,4,3});
INDArray labels = TestUtils.randomOneHotTimeSeries(format, 2, 3, 4, new Random(12345));
if (PRINT_RESULTS) {
System.out.println("testLSTMWithMerging()");
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
@ -436,9 +441,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testLSTMWithMerging()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
}
@Test

View File

@ -216,7 +216,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failed.add(testName + "\t" + "EXCEPTION");
continue;
}
@ -383,7 +383,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failed.add(testName + "\t" + "EXCEPTION");
continue;
}
@ -693,7 +693,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failed.add(testName + "\t" + "EXCEPTION");
continue;
}

View File

@ -338,7 +338,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.list()
.layer(new LSTM.Builder().nOut(layerSize).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build(), 2))
.layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build()))
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(nIn))

View File

@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -47,6 +48,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphConfigurationTest extends BaseDL4JTest {
@Test
@ -150,7 +152,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
// Use appendLayer on first layer
@ -162,7 +164,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test no network inputs
@ -174,7 +176,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test no network outputs
@ -185,7 +187,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test: invalid input
@ -197,7 +199,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test: graph with cycles
@ -215,7 +217,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test: input != inputType count mismatch
@ -241,7 +243,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalArgumentException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.conf;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Layer;
@ -46,6 +47,7 @@ import static org.junit.Assert.*;
/**
* Created by agibsonccc on 11/27/14.
*/
@Slf4j
public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
@Rule
@ -272,9 +274,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
e.printStackTrace();
log.error("",e);
} catch (Throwable e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception thrown for invalid config");
}
@ -288,9 +290,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
e.printStackTrace();
log.info(e.toString());
} catch (Throwable e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception thrown for invalid config");
}
@ -304,9 +306,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
e.printStackTrace();
log.info(e.toString());
} catch (Throwable e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception thrown for invalid config");
}
}

View File

@ -96,8 +96,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) {
assertTrue(RW0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(RW0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(RW0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, RW0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, RW0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -149,8 +149,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) {
assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -201,8 +201,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) {
assertTrue(w0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -259,10 +259,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -320,10 +320,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -378,10 +378,10 @@ public class TestConstraints extends BaseDL4JTest {
} else if(lc instanceof NonNegativeConstraint ){
assertTrue(w0.minNumber().doubleValue() >= 0.0 );
} else if(lc instanceof UnitNormConstraint ){
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(w1.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(w1.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6 );
assertEquals(1.0, w1.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(1.0, w1.norm2(1).maxNumber().doubleValue(), 1e-6 );
}
TestUtils.testModelSerialization(net);

View File

@ -156,7 +156,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
checkSerialization(glstm);
assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0);
assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0);
assertEquals(glstm.nIn, numIn);
assertEquals(glstm.nOut, numOut);
assertTrue(glstm.getActivationFn() instanceof ActivationTanH);

View File

@ -17,7 +17,9 @@
package org.deeplearning4j.nn.dtypes;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.preprocessor.*;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
@ -51,16 +53,11 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.util.IdentityLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
@ -97,7 +94,8 @@ public class DTypeTests extends BaseDL4JTest {
Pooling2D.class, //Alias for SubsamplingLayer
Convolution2D.class, //Alias for ConvolutionLayer
Pooling1D.class, //Alias for Subsampling1D
Convolution1D.class //Alias for Convolution1DLayer
Convolution1D.class, //Alias for Convolution1DLayer
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
));
@Override
@ -819,7 +817,7 @@ public class DTypeTests extends BaseDL4JTest {
.layer(new DenseLayer.Builder().nOut(5).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
.layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()))
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2))
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build())
.layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
.layer(secondLast)
@ -1078,7 +1076,7 @@ public class DTypeTests extends BaseDL4JTest {
.addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in")
.addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l")
.addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc")
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16})), "preproc2")
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16}, false)), "preproc2")
.addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3")
.setInputTypes(InputType.feedForward(5))
.setOutputs("out");
@ -1150,7 +1148,7 @@ public class DTypeTests extends BaseDL4JTest {
case 7:
b.addInputs("in")
.addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
.addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5)), "1")
.addVertex("2", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 5)), "1")
.addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
.setOutputs("out")
.setInputTypes(InputType.convolutional(28, 28, 1));

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.graph;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
@ -54,6 +55,7 @@ import java.util.Map;
import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphTestRNN extends BaseDL4JTest {
@Test
@ -618,7 +620,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
.build();
fail("Exception expected");
} catch (IllegalStateException e){
// e.printStackTrace();
log.error("",e);
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
}
}

View File

@ -1394,7 +1394,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
} catch (Exception e) {
//e.printStackTrace();
log.error("",e);
if(allowDisconnected){
fail("No exception expected");
} else {

View File

@ -416,8 +416,8 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0);
assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(0.0, outRow2.getDouble(k), 0.0);
}
}
}

View File

@ -60,7 +60,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test
public void testMergeNode() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4);
INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100);
@ -82,7 +82,7 @@ public class TestGraphNodes extends BaseDL4JTest {
public void testMergeNodeRNN() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5);
INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100);
@ -103,7 +103,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test
public void testCnnDepthMerge() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2);
INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10);

View File

@ -0,0 +1,974 @@
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.layers.convolution;
import lombok.*;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class ConvDataFormatTests extends BaseDL4JTest {
private final DataType dataType;
public ConvDataFormatTests(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return new DataType[]{DataType.FLOAT, DataType.DOUBLE};
}
@Test
public void testConv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getConv2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSubsampling2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testDepthwiseConv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSeparableConv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testDeconv2d() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLRN() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLrnLayer(CNN2DFormat.NCHW, true, cm))
.net2(getLrnLayer(CNN2DFormat.NCHW, false, cm))
.net3(getLrnLayer(CNN2DFormat.NHWC, true, cm))
.net4(getLrnLayer(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testZeroPaddingLayer(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getZeroPaddingNet(CNN2DFormat.NCHW, true))
.net2(getZeroPaddingNet(CNN2DFormat.NCHW, false))
.net3(getZeroPaddingNet(CNN2DFormat.NHWC, true))
.net4(getZeroPaddingNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testCropping2DLayer(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getCropping2dNet(CNN2DFormat.NCHW, true))
.net2(getCropping2dNet(CNN2DFormat.NCHW, false))
.net3(getCropping2dNet(CNN2DFormat.NHWC, true))
.net4(getCropping2dNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testUpsampling2d(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getUpsamplingNet(CNN2DFormat.NCHW, true))
.net2(getUpsamplingNet(CNN2DFormat.NCHW, false))
.net3(getUpsamplingNet(CNN2DFormat.NHWC, true))
.net4(getUpsamplingNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testBatchNormNet(){
try {
for(boolean useLogStd : new boolean[]{true, false}) {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std");
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true))
.net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false))
.net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true))
.net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testCnnLossLayer() {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3);
labelsNHWC = labelsNHWC.reshape(2,6,6,3);
INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup();
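//CnnLossLayer labels are per-pixel one-hot: [mb, h, w, c] in NHWC; permute(0,3,1,2) gives the [mb, c, h, w] NCHW equivalent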
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getCnnLossNet(CNN2DFormat.NCHW, true, ConvolutionMode.Same))
.net2(getCnnLossNet(CNN2DFormat.NCHW, false, ConvolutionMode.Same))
.net3(getCnnLossNet(CNN2DFormat.NHWC, true, ConvolutionMode.Same))
.net4(getCnnLossNet(CNN2DFormat.NHWC, false, ConvolutionMode.Same))
.inNCHW(inNCHW)
.labelsNCHW(labelsNCHW)
.labelsNHWC(labelsNHWC)
.testLayerIdx(1)
.nhwcOutput(true)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSpaceToDepthNet(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testSpaceToBatchNet(){
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16);
INDArray labels = TestUtils.randomOneHot(8, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLocallyConnected() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm))
.net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm))
.net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm))
.net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGlobalPooling() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocalResponseNormalization.Builder()
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocalResponseNormalization.Builder()
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(),
format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Cropping2D.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Cropping2D.Builder(2,2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Upsampling2D.Builder(2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Upsampling2D.Builder(2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.stride(2,2)
.build(), format, cm, null);
} else {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.stride(2,2)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new BatchNormalization.Builder()
.useLogStd(logStdev)
.dataFormat(format)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new BatchNormalization.Builder()
.useLogStd(logStdev)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
.blocks(2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
.blocks(2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} else {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
}
}
private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.nOut(3)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.dataType(this.dataType)
.seed(12345)
.convolutionMode(cm)
.list()
.layer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build())
.layer(layer)
.layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
.setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));
if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
//Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
//DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
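//e.g. a [mb, 3, 12, 12] NCHW activation flattens channel-major (c,h,w) while NHWC flattens channel-last (h,w,c),
//so the NHWC net is permuted back to NCHW before CnnToFeedForwardPreProcessor, keeping the dense layer input order identical across formats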
builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
}
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
.convolutionMode(cm)
.list()
.layer(new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build());
if(setOnLayerAlso){
builder.layer(new CnnLossLayer.Builder().format(format).activation(Activation.SOFTMAX).build());
} else {
builder.layer(new CnnLossLayer.Builder().activation(Activation.SOFTMAX).build());
}
builder.setInputType(InputType.convolutional(12, 12, 3, format));
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
@AllArgsConstructor
@Data
@NoArgsConstructor
@Builder
private static class TestCase {
private String msg;
private MultiLayerNetwork net1;
private MultiLayerNetwork net2;
private MultiLayerNetwork net3;
private MultiLayerNetwork net4;
private INDArray inNCHW;
private INDArray labelsNCHW;
private INDArray labelsNHWC;
private int testLayerIdx;
private boolean nhwcOutput;
}
public static void testHelper(TestCase tc) {
tc.net2.params().assign(tc.net1.params());
tc.net3.params().assign(tc.net1.params());
tc.net4.params().assign(tc.net1.params());
//Test forward pass:
INDArray inNCHW = tc.inNCHW;
INDArray inNHWC = tc.inNCHW.permute(0, 2, 3, 1).dup();
INDArray l0_1 = tc.net1.feedForward(inNCHW).get(tc.testLayerIdx + 1);
INDArray l0_2 = tc.net2.feedForward(inNCHW).get(tc.testLayerIdx + 1);
INDArray l0_3 = tc.net3.feedForward(inNHWC).get(tc.testLayerIdx + 1);
INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1);
assertEquals(tc.msg, l0_1, l0_2);
if(l0_1.rank() == 4) {
assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
} else {
assertEquals(tc.msg, l0_1, l0_3);
assertEquals(tc.msg, l0_1, l0_4);
}
INDArray out1 = tc.net1.output(inNCHW);
INDArray out2 = tc.net2.output(inNCHW);
INDArray out3 = tc.net3.output(inNHWC);
INDArray out4 = tc.net4.output(inNHWC);
assertEquals(tc.msg, out1, out2);
if(!tc.nhwcOutput) {
assertEquals(tc.msg, out1, out3);
assertEquals(tc.msg, out1, out4);
} else {
assertEquals(tc.msg, out1, out3.permute(0,3,1,2)); //NHWC to NCHW
assertEquals(tc.msg, out1, out4.permute(0,3,1,2));
}
//Test backprop
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCHW, tc.labelsNCHW, null, null);
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCHW, tc.labelsNCHW, null, null);
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNHWC, tc.labelsNHWC, null, null);
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNHWC, tc.labelsNHWC, null, null);
//Input gradients
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0,3,1,2)); //Input gradients for NHWC input are also in NHWC format
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0,3,1,2));
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
tc.net1.fit(inNCHW, tc.labelsNCHW);
tc.net2.fit(inNCHW, tc.labelsNCHW);
tc.net3.fit(inNHWC, tc.labelsNHWC);
tc.net4.fit(inNHWC, tc.labelsNHWC);
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
//Test serialization
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
out1 = tc.net1.output(inNCHW);
assertEquals(tc.msg, out1, net1a.output(inNCHW));
assertEquals(tc.msg, out1, net2a.output(inNCHW));
if(!tc.nhwcOutput) {
assertEquals(tc.msg, out1, net3a.output(inNHWC));
assertEquals(tc.msg, out1, net4a.output(inNHWC));
} else {
assertEquals(tc.msg, out1, net3a.output(inNHWC).permute(0,3,1,2)); //NHWC to NCHW
assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2));
}
}
private static List<String> differentGrads(Gradient g1, Gradient g2){
List<String> differs = new ArrayList<>();
Map<String,INDArray> m1 = g1.gradientForVariable();
Map<String,INDArray> m2 = g2.gradientForVariable();
for(String s : m1.keySet()){
INDArray a1 = m1.get(s);
INDArray a2 = m2.get(s);
if(!a1.equals(a2)){
differs.add(s);
}
}
return differs;
}
//Converts NHWC to NCHW activations
@EqualsAndHashCode
private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
}
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
}
@Override
public InputPreProcessor clone() {
return this;
}
@Override
public InputType getOutputType(InputType inputType) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
}
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
return null;
}
}
}
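For reference, a minimal standalone ND4J sketch (illustrative only; hypothetical class name, arbitrary shapes) of the NCHW/NHWC axis permutations these assertions rely on:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class FormatPermuteSketch {
    public static void main(String[] args) {
        //NCHW activations: [minibatch, channels, height, width]
        INDArray nchw = Nd4j.rand(new int[]{2, 3, 12, 12});
        //permute(0, 2, 3, 1) reorders the axes to NHWC: [minibatch, height, width, channels]
        INDArray nhwc = nchw.permute(0, 2, 3, 1).dup();
        //permute(0, 3, 1, 2) is the inverse, mapping NHWC back to NCHW
        INDArray roundTrip = nhwc.permute(0, 3, 1, 2);
        System.out.println(nchw.equals(roundTrip));    //prints true
    }
}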

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.convolution;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@ -45,6 +46,7 @@ import static org.junit.Assert.*;
/**
* Created by Alex on 15/11/2016.
*/
@Slf4j
public class TestConvolutionModes extends BaseDL4JTest {
@Test
@ -106,12 +108,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
}
} catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}
continue; //Expected exception
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}
@ -184,12 +186,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
}
} catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}
continue; //Expected exception
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}

View File

@ -24,10 +24,7 @@ import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
@ -45,6 +42,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -61,12 +60,22 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
import static org.junit.Assert.assertEquals;
@Slf4j
@RunWith(Parameterized.class)
public class BidirectionalTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public BidirectionalTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void compareImplementations(){
for(WorkspaceMode wsm : WorkspaceMode.values()) {
@ -82,9 +91,9 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build())
.build();
@ -95,9 +104,9 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build())
.build();
@ -116,15 +125,24 @@ public class BidirectionalTest extends BaseDL4JTest {
net2.setParams(net1.params()); //Assuming exact same layout here...
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
INDArray in;
if (rnnDataFormat == NCW){
in = Nd4j.rand(new int[]{3, 10, 5});
}else{
in = Nd4j.rand(new int[]{3, 5, 10});
}
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
INDArray labels;
if (rnnDataFormat == NCW){
labels = Nd4j.rand(new int[]{3, 10, 5});
}else{
labels = Nd4j.rand(new int[]{3, 5, 10});
}
net1.setInput(in);
net1.setLabels(labels);
@ -276,17 +294,22 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.nIn(10).nOut(10).build())
.nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
net1.init();
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
INDArray in;
INDArray labels;
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 5} : new long[]{3, 5, 10};
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
net1.fit(in, labels);
@ -300,8 +323,8 @@ public class BidirectionalTest extends BaseDL4JTest {
MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true);
in = Nd4j.rand(new int[]{3, 10, 5});
labels = Nd4j.rand(new int[]{3, 10, 5});
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
@ -338,18 +361,18 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam())
.graphBuilder()
.addInputs("in")
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in")
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0")
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
.layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
.layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0")
.layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
.nIn(10).nOut(10).build(), "1")
.setOutputs("2")
.build();
ComputationGraph net1 = new ComputationGraph(conf1);
net1.init();
INDArray in = Nd4j.rand(new int[]{3, 10, 5});
INDArray labels = Nd4j.rand(new int[]{3, 10, 5});
long[] inshape = (rnnDataFormat == NCW)? new long[]{3, 10, 5}: new long[]{3, 5, 10};
INDArray in = Nd4j.rand(inshape);
INDArray labels = Nd4j.rand(inshape);
net1.fit(new DataSet(in, labels));
@ -363,8 +386,8 @@ public class BidirectionalTest extends BaseDL4JTest {
ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true);
in = Nd4j.rand(new int[]{3, 10, 5});
labels = Nd4j.rand(new int[]{3, 10, 5});
in = Nd4j.rand(inshape);
labels = Nd4j.rand(inshape);
INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in);
@ -394,8 +417,8 @@ public class BidirectionalTest extends BaseDL4JTest {
Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD,
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
INDArray in = Nd4j.rand(inshape);
for (Bidirectional.Mode m : modes) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -406,7 +429,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm)
.updater(new Adam())
.list()
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()))
.layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()))
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
@ -418,7 +441,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.updater(new Adam())
.list()
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build())
.layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone());
@ -434,11 +457,10 @@ public class BidirectionalTest extends BaseDL4JTest {
net3.setParam("0_RW", net1.getParam("0_bRW"));
net3.setParam("0_b", net1.getParam("0_bb"));
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat);
INDArray outExp;
switch (m) {
@ -452,7 +474,7 @@ public class BidirectionalTest extends BaseDL4JTest {
outExp = out2.add(out3).muli(0.5);
break;
case CONCAT:
outExp = Nd4j.concat(1, out2, out3);
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
break;
default:
throw new RuntimeException();
@ -464,25 +486,25 @@ public class BidirectionalTest extends BaseDL4JTest {
//Check gradients:
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
INDArray eps = Nd4j.rand(inshape);
INDArray eps1;
if (m == Bidirectional.Mode.CONCAT) {
eps1 = Nd4j.concat(1, eps, eps);
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
} else {
eps1 = eps;
}
net1.setInput(in);
net2.setInput(in);
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat));
net1.feedForward(true, false);
net2.feedForward(true, false);
net3.feedForward(true, false);
Pair<Gradient, INDArray> p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT), LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat), LayerWorkspaceMgr.noWorkspaces());
Gradient g1 = p1.getFirst();
Gradient g2 = p2.getFirst();
Gradient g3 = p3.getFirst();
@ -520,7 +542,9 @@ public class BidirectionalTest extends BaseDL4JTest {
Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL};
INDArray in = Nd4j.rand(new int[]{3, 10, 6});
long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10};
INDArray in = Nd4j.rand(inshape);
for (Bidirectional.Mode m : modes) {
ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -532,7 +556,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam())
.graphBuilder()
.addInputs("in")
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()), "in")
.layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
.setOutputs("0")
.build();
@ -546,7 +570,7 @@ public class BidirectionalTest extends BaseDL4JTest {
.updater(new Adam())
.graphBuilder()
.addInputs("in")
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in")
.layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in")
.setOutputs("0")
.build();
@ -566,9 +590,20 @@ public class BidirectionalTest extends BaseDL4JTest {
INDArray out1 = net1.outputSingle(in);
INDArray out2 = net2.outputSingle(in);
INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.outputSingle(
TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)),
LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
INDArray out3;
INDArray inReverse;
if (rnnDataFormat == RNNFormat.NWC){
inReverse = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
out3 = net3.outputSingle(inReverse);
out3 = TimeSeriesUtils.reverseTimeSeries(out3.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1);
}
else{
inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
out3 = net3.outputSingle(inReverse);
out3 = TimeSeriesUtils.reverseTimeSeries(out3, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT);
}
INDArray outExp;
switch (m) {
@ -582,7 +617,9 @@ public class BidirectionalTest extends BaseDL4JTest {
outExp = out2.add(out3).muli(0.5);
break;
case CONCAT:
outExp = Nd4j.concat(1, out2, out3);
System.out.println(out2.shapeInfoToString());
System.out.println(out3.shapeInfoToString());
outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3);
break;
default:
throw new RuntimeException();
@ -594,22 +631,26 @@ public class BidirectionalTest extends BaseDL4JTest {
//Check gradients:
if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) {
INDArray eps = Nd4j.rand(new int[]{3, 10, 6});
INDArray eps = Nd4j.rand(inshape);
INDArray eps1;
if (m == Bidirectional.Mode.CONCAT) {
eps1 = Nd4j.concat(1, eps, eps);
eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps);
} else {
eps1 = eps;
}
INDArray epsReversed = (rnnDataFormat == NCW)?
TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT):
TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)
.permute(0, 2, 1);
net1.outputSingle(true, false, in);
net2.outputSingle(true, false, in);
net3.outputSingle(true, false, TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
net3.outputSingle(true, false, inReverse);
Gradient g1 = net1.backpropGradient(eps1);
Gradient g2 = net2.backpropGradient(eps);
Gradient g3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));
Gradient g3 = net3.backpropGradient(epsReversed);
for (boolean updates : new boolean[]{false, true}) {
if (updates) {

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@ -31,6 +32,8 @@ import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,10 +45,18 @@ import org.nd4j.linalg.primitives.Pair;
import static org.junit.Assert.*;
@RunWith(Parameterized.class)
public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
private double score = 0.0;
private RNNFormat rnnDataFormat;
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testBidirectionalLSTMGravesForwardBasic() {
//Very basic test of forward prop. of LSTM layer with a time series.
@ -55,7 +66,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(nHiddenUnits).activation(Activation.TANH).build())
.nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build())
.build();
val numParams = conf.getLayer().initializer().numParams(conf);
@ -65,7 +76,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
//Data: has shape [miniBatchSize,nIn,timeSeriesLength];
//Output/activations has shape [miniBatchSize,nHiddenUnits,timeSeriesLength];
if (rnnDataFormat == RNNFormat.NCW){
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1});
@ -82,6 +93,25 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15});
}
else{
final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, 1, nIn);
final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations1.shape(), new long[] {1, 1, nHiddenUnits});
final INDArray dataMultiExampleLength1 = Nd4j.ones(10, 1, nIn);
final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations2.shape(), new long[] {10, 1, nHiddenUnits});
final INDArray dataSingleExampleLength12 = Nd4j.ones(1, 12, nIn);
final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations3.shape(), new long[] {1, 12, nHiddenUnits});
final INDArray dataMultiExampleLength15 = Nd4j.ones(10, 15, nIn);
final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces());
assertArrayEquals(activations4.shape(), new long[] {10, 15, nHiddenUnits});
}
}
@Test
public void testBidirectionalLSTMGravesBackwardBasic() {
@ -94,14 +124,15 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
}
private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize,
int timeSeriesLength) {
INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength);
INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength):
Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(lstmNHiddenUnits)
.nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build())
.build();
@ -114,7 +145,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces());
assertNotNull(lstm.input());
INDArray epsilon = Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength);
INDArray epsilon =(rnnDataFormat == RNNFormat.NCW)? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength):
Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits);
Pair<Gradient, INDArray> out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces());
Gradient outGradient = out.getFirst();
@ -147,7 +179,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3});
assertNotNull(nextEpsilon);
assertArrayEquals(nextEpsilon.shape(), new long[] {miniBatchSize, nIn, timeSeriesLength});
if (rnnDataFormat == RNNFormat.NCW) {
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, nIn, timeSeriesLength});
}else{
assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, timeSeriesLength, nIn });
}
//Check update:
for (String s : outGradient.gradientForVariable().keySet()) {
@ -226,7 +262,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn)
.nOut(layerSize)
.nOut(layerSize).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build())
.build();
@ -237,7 +273,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.instantiate(confBidirectional, null, 0, params, true, params.dataType());
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces());
@ -265,13 +302,13 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final NeuralNetConfiguration confBidirectional =
new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
.nIn(nIn).nOut(layerSize)
.nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.dist(new UniformDistribution(-0.1, 0.1))
.activation(Activation.TANH).updater(new NoOp()).build())
.build();
final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder()
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize)
.layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.weightInit(WeightInit.ZERO).activation(Activation.TANH).build())
.build();
@ -290,9 +327,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards)));
final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}):
Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn});
final INDArray sigb = sig.dup();
if (rnnDataFormat == RNNFormat.NCW) {
reverseColumnsInPlace(sigb.slice(0));
}
else{
reverseColumnsInPlace(sigb.slice(0).permute(1, 0));
}
final INDArray recurrentWeightsF = bidirectionalLSTM
.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS);
@ -345,10 +389,14 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f);
final INDArray randSig = Nd4j.rand(new int[] {1, layerSize, timeSeriesLength});
final INDArray randSigBackwards = randSig.dup();
final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}):
Nd4j.rand(new int[] {1, timeSeriesLength, layerSize});
INDArray randSigBackwards = randSig.dup();
if (rnnDataFormat == RNNFormat.NCW){
reverseColumnsInPlace(randSigBackwards.slice(0));
}else{
reverseColumnsInPlace(randSigBackwards.slice(0).permute(1, 0));
}
final Pair<Gradient, INDArray> backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
final Pair<Gradient, INDArray> backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces());
@ -399,10 +447,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0);
final INDArray activation3Reverse = activation3.dup();
if (rnnDataFormat == RNNFormat.NCW){
reverseColumnsInPlace(activation3Reverse);
}
else{
reverseColumnsInPlace(activation3Reverse.permute(1, 0));
}
assertEquals(activation3Reverse, activation1);
assertArrayEquals(activation3Reverse.shape(), activation1.shape());
assertEquals(activation3Reverse, activation1);
//test backprop now
final INDArray refBackGradientReccurrent =
@ -434,7 +488,12 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
final INDArray refEpsilon = backprop1.getSecond().dup();
final INDArray backEpsilon = backprop3.getSecond().dup();
if (rnnDataFormat == RNNFormat.NCW) {
reverseColumnsInPlace(refEpsilon.slice(0));
}
else{
reverseColumnsInPlace(refEpsilon.slice(0).permute(1, 0));
}
assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6);
}
@ -477,10 +536,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.seed(12345).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder()
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2)
.gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat)
.build())
.layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2)
.lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat)
.activation(Activation.TANH).build())
.build();
@ -492,7 +551,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest {
INDArray in = Nd4j.rand(new int[] {3, 2, 5});
INDArray labels = Nd4j.rand(new int[] {3, 2, 5});
if (rnnDataFormat == RNNFormat.NWC){
in = in.permute(0, 2, 1);
labels = labels.permute(0, 2, 1);
}
net.fit(in, labels);
}
}
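Similarly, a minimal sketch (illustrative only; hypothetical class name) of the NCW/NWC time-series layouts toggled by rnnDataFormat in the tests above:

import java.util.Arrays;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RnnFormatSketch {
    public static void main(String[] args) {
        //NCW time series: [minibatch, features, timeSeriesLength]
        INDArray ncw = Nd4j.rand(new int[]{3, 10, 5});
        //permute(0, 2, 1) swaps the last two axes, giving NWC: [minibatch, timeSeriesLength, features]
        INDArray nwc = ncw.permute(0, 2, 1).dup();
        System.out.println(Arrays.toString(ncw.shape()));    //[3, 10, 5]
        System.out.println(Arrays.toString(nwc.shape()));    //[3, 5, 10]
    }
}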

View File

@ -21,11 +21,14 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -36,9 +39,17 @@ import java.util.Collections;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class MaskZeroLayerTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public MaskZeroLayerTest(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void activate() {
@ -57,7 +68,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
.activation(Activation.IDENTITY)
.gateActivationFunction(Activation.IDENTITY)
.nIn(2)
.nOut(1)
.nOut(1).dataFormat(rnnDataFormat)
.build();
NeuralNetConfiguration conf = new NeuralNetConfiguration();
conf.setLayer(underlying);
@ -72,20 +83,25 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue);
INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3});
if (rnnDataFormat == RNNFormat.NWC){
input = input.permute(0, 2, 1);
}
//WHEN
INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces());
if (rnnDataFormat == RNNFormat.NWC){
out = out.permute(0, 2,1);
}
//THEN output should only be incremented for the non-zero timesteps
INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
assertEquals(firstExampleOutput.getDouble(0), 0.0, 1e-6);
assertEquals(firstExampleOutput.getDouble(1), 1.0, 1e-6);
assertEquals(firstExampleOutput.getDouble(2), 2.0, 1e-6);
assertEquals(0.0, firstExampleOutput.getDouble(0), 1e-6);
assertEquals(1.0, firstExampleOutput.getDouble(1), 1e-6);
assertEquals(2.0, firstExampleOutput.getDouble(2), 1e-6);
assertEquals(secondExampleOutput.getDouble(0), 0.0, 1e-6);
assertEquals(secondExampleOutput.getDouble(1), 0.0, 1e-6);
assertEquals(secondExampleOutput.getDouble(2), 1.0, 1e-6);
assertEquals(0.0, secondExampleOutput.getDouble(0), 1e-6);
assertEquals(0.0, secondExampleOutput.getDouble(1), 1e-6);
assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
}
@ -94,7 +110,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder()
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).build()).build())
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

View File

@ -0,0 +1,394 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.layers.recurrent;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
@AllArgsConstructor
public class RnnDataFormatTests extends BaseDL4JTest {
private boolean helpers;
private boolean lastTimeStep;
private boolean maskZeros;
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}")
public static List params(){
List<Object[]> ret = new ArrayList<>();
for (boolean helpers: new boolean[]{true, false})
for (boolean lastTimeStep: new boolean[]{true, false})
for (boolean maskZero: new boolean[]{true, false})
ret.add(new Object[]{helpers, lastTimeStep, maskZero});
return ret;
}
@Test
public void testSimpleRnn() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSimpleRnnNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getSimpleRnnNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getSimpleRnnNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getSimpleRnnNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGraveLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGravesLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getGravesLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getGravesLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getGravesLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
@Test
public void testGraveBiLSTM() {
try {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros;
System.out.println(" --- " + msg + " ---");
INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12);
INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGravesBidirectionalLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros))
.net2(getGravesBidirectionalLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros))
.net3(getGravesBidirectionalLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros))
.net4(getGravesBidirectionalLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros))
.inNCW(inNCW)
.labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1))
.labelsNWC(labelsNWC)
.testLayerIdx(1)
.build();
TestCase.testHelper(tc);
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
private MultiLayerNetwork getGravesBidirectionalLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new GravesLSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new GravesLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new LSTM.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new LSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getSimpleRnnNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
if (setOnLayerAlso) {
return getNetWithLayer(new SimpleRnn.Builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else {
return getNetWithLayer(new SimpleRnn.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
}
}
private MultiLayerNetwork getNetWithLayer(Layer layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
if (maskZeros){
layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build();
}
if(lastTimeStep){
layer = new LastTimeStep(layer);
}
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(new LSTM.Builder()
.nIn(3)
.activation(Activation.TANH)
.dataFormat(format)
.nOut(3)
.helperAllowFallback(false)
.build())
.layer(layer)
.layer(
(lastTimeStep)?new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build():
new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).dataFormat(format).build()
)
.setInputType(InputType.recurrent(3, 12, format));
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
@AllArgsConstructor
@Data
@NoArgsConstructor
@Builder
private static class TestCase {
private String msg;
private MultiLayerNetwork net1;
private MultiLayerNetwork net2;
private MultiLayerNetwork net3;
private MultiLayerNetwork net4;
private INDArray inNCW;
private INDArray labelsNCW;
private INDArray labelsNWC;
private int testLayerIdx;
private boolean nwcOutput;
public static void testHelper(TestCase tc) {
tc.net2.params().assign(tc.net1.params());
tc.net3.params().assign(tc.net1.params());
tc.net4.params().assign(tc.net1.params());
INDArray inNCW = tc.inNCW;
INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup();
INDArray l0_1 = tc.net1.feedForward(inNCW).get(tc.testLayerIdx + 1);
INDArray l0_2 = tc.net2.feedForward(inNCW).get(tc.testLayerIdx + 1);
INDArray l0_3 = tc.net3.feedForward(inNWC).get(tc.testLayerIdx + 1);
INDArray l0_4 = tc.net4.feedForward(inNWC).get(tc.testLayerIdx + 1);
boolean rank3Out = tc.labelsNCW.rank() == 3;
assertEquals(tc.msg, l0_1, l0_2);
if (rank3Out){
assertEquals(tc.msg, l0_1, l0_3.permute(0, 2, 1));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 2, 1));
}
else{
assertEquals(tc.msg, l0_1, l0_3);
assertEquals(tc.msg, l0_1, l0_4);
}
INDArray out1 = tc.net1.output(inNCW);
INDArray out2 = tc.net2.output(inNCW);
INDArray out3 = tc.net3.output(inNWC);
INDArray out4 = tc.net4.output(inNWC);
assertEquals(tc.msg, out1, out2);
if (rank3Out){
assertEquals(tc.msg, out1, out3.permute(0, 2, 1)); //NWC to NCW
assertEquals(tc.msg, out1, out4.permute(0, 2, 1));
}
else{
assertEquals(tc.msg, out1, out3); //2d output - no permute needed
assertEquals(tc.msg, out1, out4);
}
//Test backprop
Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCW, tc.labelsNCW, null, null);
Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCW, tc.labelsNCW, null, null);
Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNWC, tc.labelsNWC, null, null);
Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNWC, tc.labelsNWC, null, null);
//Input gradients
assertEquals(tc.msg, p1.getSecond(), p2.getSecond());
assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0, 2, 1)); //Input gradients for NWC input are also in NWC format
assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0, 2, 1));
List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst());
List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst());
List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst());
assertEquals(tc.msg + " " + diff12, 0, diff12.size());
assertEquals(tc.msg + " " + diff13, 0, diff13.size());
assertEquals(tc.msg + " " + diff14, 0, diff14.size());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable());
assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable());
tc.net1.fit(inNCW, tc.labelsNCW);
tc.net2.fit(inNCW, tc.labelsNCW);
tc.net3.fit(inNWC, tc.labelsNWC);
tc.net4.fit(inNWC, tc.labelsNWC);
assertEquals(tc.msg, tc.net1.params(), tc.net2.params());
assertEquals(tc.msg, tc.net1.params(), tc.net3.params());
assertEquals(tc.msg, tc.net1.params(), tc.net4.params());
//Test serialization
MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1);
MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2);
MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3);
MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4);
out1 = tc.net1.output(inNCW);
assertEquals(tc.msg, out1, net1a.output(inNCW));
assertEquals(tc.msg, out1, net2a.output(inNCW));
if (rank3Out) {
assertEquals(tc.msg, out1, net3a.output(inNWC).permute(0, 2, 1)); //NWC to NCW
assertEquals(tc.msg, out1, net4a.output(inNWC).permute(0, 2, 1));
}
else{
assertEquals(tc.msg, out1, net3a.output(inNWC)); //2d output - no permute needed
assertEquals(tc.msg, out1, net4a.output(inNWC));
}
}
}
private static List<String> differentGrads(Gradient g1, Gradient g2){
List<String> differs = new ArrayList<>();
Map<String,INDArray> m1 = g1.gradientForVariable();
Map<String,INDArray> m2 = g2.gradientForVariable();
for(String s : m1.keySet()){
INDArray a1 = m1.get(s);
INDArray a2 = m2.get(s);
if(!a1.equals(a2)){
differs.add(s);
}
}
return differs;
}
}
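The JUnit 4 parameterization pattern shared by these test classes, as a minimal sketch (illustrative only; hypothetical class name, assuming JUnit 4 and RNNFormat on the classpath):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(Parameterized.class)
public class RnnFormatParamSketch {

    private final RNNFormat rnnDataFormat;

    //JUnit 4 instantiates the test class once per element returned by params()
    public RnnFormatParamSketch(RNNFormat rnnDataFormat) {
        this.rnnDataFormat = rnnDataFormat;
    }

    @Parameterized.Parameters(name = "{0}")
    public static Object[] params() {
        return RNNFormat.values();    //every @Test runs for both NCW and NWC
    }

    @Test
    public void printFormat() {
        System.out.println("Running with data format: " + rnnDataFormat);
    }
}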

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
@ -29,6 +30,8 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
@ -42,14 +45,25 @@ import static org.nd4j.linalg.activations.Activation.IDENTITY;
import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
@RunWith(Parameterized.class)
public class TestLastTimeStepLayer extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters(name="{0}")
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testLastTimeStepVertex() {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
.nIn(5).nOut(6).build()), "in")
.nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in")
.setOutputs("lastTS")
.build();
@ -59,9 +73,22 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
//First: test without input mask array
Nd4j.getRandom().setSeed(12345);
Layer l = graph.getLayer("lastTS");
INDArray in = Nd4j.rand(new int[]{3, 5, 6});
INDArray in;
if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(3, 5, 6);
}
else{
in = Nd4j.rand(3, 6, 5);
}
INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces());
INDArray expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
INDArray expOut;
if (rnnDataFormat == RNNFormat.NCW){
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));
}
else{
expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.point(5), NDArrayIndex.all());
}
//Forward pass:
@ -76,9 +103,17 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
graph.setLayerMaskArrays(new INDArray[]{inMask}, null);
expOut = Nd4j.zeros(3, 6);
if (rnnDataFormat == RNNFormat.NCW){
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));
}
else{
expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all()));
expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.point(3), NDArrayIndex.all()));
expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.point(4), NDArrayIndex.all()));
}
outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces());
assertEquals(expOut, outFwd);
@ -97,9 +132,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
.seed(1234)
.graphBuilder()
.addInputs("in")
.setInputTypes(InputType.recurrent(1))
.setInputTypes(InputType.recurrent(1, rnnDataFormat))
.addLayer("RNN", new LastTimeStep(new LSTM.Builder()
.nOut(10)
.nOut(10).dataFormat(rnnDataFormat)
.build()), "in")
.addLayer("dense", new DenseLayer.Builder()
.nOut(10)
@ -120,7 +155,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
INDArray fm2 = Nd4j.zeros(1,24);
INDArray fm3 = Nd4j.zeros(1,24);
fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1);
if (rnnDataFormat == RNNFormat.NWC){
f = f.permute(0, 2, 1);
}
INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1});
try {
cg.output(false, new INDArray[]{f}, new INDArray[]{fm2});

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.dropout.TestDropout;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
@ -31,6 +32,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,13 +44,24 @@ import org.nd4j.linalg.primitives.Pair;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
@RunWith(Parameterized.class)
public class TestRnnLayers extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestRnnLayers(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testTimeStepIs3Dimensional() {
@ -58,8 +72,8 @@ public class TestRnnLayers extends BaseDL4JTest {
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).build())
.layer(new LSTM.Builder().nIn(3).nOut(5).build())
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new LSTM.Builder().nIn(3).nOut(5).dataFormat(rnnDataFormat).build())
.layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build())
.build();
@ -70,9 +84,9 @@ public class TestRnnLayers extends BaseDL4JTest {
org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn =
(org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0);
INDArray rnnInput3d = Nd4j.create(10, 12, 1);
INDArray rnnInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10,12, 1):Nd4j.create(10, 1, 12);
INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces());
assertTrue(Arrays.equals(simpleOut.shape(), new long[] {10, 3, 1}));
assertTrue(Arrays.equals(simpleOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 3, 1}:new long[]{10, 1, 3}));
INDArray rnnInput2d = Nd4j.create(10, 12);
try {
@ -84,9 +98,9 @@ public class TestRnnLayers extends BaseDL4JTest {
org.deeplearning4j.nn.layers.recurrent.LSTM lstm =
(org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1);
INDArray lstmInput3d = Nd4j.create(10, 3, 1);
INDArray lstmInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10, 3, 1):Nd4j.create(10, 1, 3);
INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces());
assertTrue(Arrays.equals(lstmOut.shape(), new long[] {10, 5, 1}));
assertTrue(Arrays.equals(lstmOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 5, 1}:new long[]{10, 1, 5}));
INDArray lstmInput2d = Nd4j.create(10, 3);
try {
@ -112,19 +126,19 @@ public class TestRnnLayers extends BaseDL4JTest {
TestDropout.CustomDropout cd = new TestDropout.CustomDropout();
switch (s){
case "graves":
layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break;
case "lstm":
layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break;
case "simple":
layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).build();
layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build();
layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build();
layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build();
break;
default:
throw new RuntimeException(s);
@ -134,21 +148,21 @@ public class TestRnnLayers extends BaseDL4JTest {
.seed(12345)
.list()
.layer(layer)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(layerD)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder()
.seed(12345)
.list()
.layer(layerD2)
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
.layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -178,7 +192,6 @@ public class TestRnnLayers extends BaseDL4JTest {
assertNotEquals(s, out2, out2D);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345);
net.fit(f.dup(), l);
netD.fit(f.dup(), l);
assertNotEquals(s, net.params(), netD.params());
@ -205,14 +218,14 @@ public class TestRnnLayers extends BaseDL4JTest {
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
.list()
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
switch (i){
case 0:
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build());
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
break;
case 1:
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build());
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).dataFormat(rnnDataFormat).build());
break;
default:
throw new RuntimeException();
@ -223,14 +236,14 @@ public class TestRnnLayers extends BaseDL4JTest {
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new Random(12345));
try{
net.fit(in,l);
} catch (Throwable t){
String msg = t.getMessage();
if(msg == null)
t.printStackTrace();
System.out.println(i);
assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
}
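TestRnnLayers is now a parameterized test: the constructor stores an RNNFormat and the Parameterized runner instantiates the class once per value returned by params(), so every @Test method is exercised for both NCW and NWC. A minimal sketch of the same pattern (illustrative class, not from this PR):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import static org.junit.Assert.assertNotNull;

@RunWith(Parameterized.class)
public class RnnFormatParameterizedSketch {

    private final RNNFormat rnnDataFormat;

    // The runner calls this constructor once for each element of params()
    public RnnFormatParameterizedSketch(RNNFormat rnnDataFormat) {
        this.rnnDataFormat = rnnDataFormat;
    }

    @Parameterized.Parameters(name = "{0}")
    public static Object[] params() {
        return RNNFormat.values(); // NCW and NWC
    }

    @Test
    public void formatIsInjected() {
        assertNotNull(rnnDataFormat); // runs twice, once per format
    }
}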

View File

@ -20,10 +20,13 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -36,8 +39,18 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
@RunWith(Parameterized.class)
public class TestSimpleRnn extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestSimpleRnn(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testSimpleRnn(){
Nd4j.getRandom().setSeed(12345);
@ -46,15 +59,21 @@ public class TestSimpleRnn extends BaseDL4JTest {
int nIn = 5;
int layerSize = 6;
int tsLength = 7;
INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
// in.get(all(), all(), interval(1,tsLength)).assign(0);
INDArray in;
if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
}
else{
in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)
.activation(Activation.TANH)
.list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).build())
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -68,7 +87,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
INDArray outLast = null;
for( int i=0; i<tsLength; i++ ){
INDArray inCurrent = in.get(all(), all(), point(i));
INDArray inCurrent;
if (rnnDataFormat == RNNFormat.NCW){
inCurrent = in.get(all(), all(), point(i));
}
else{
inCurrent = in.get(all(), point(i), all());
}
INDArray outExpCurrent = inCurrent.mmul(w);
if(outLast != null){
@ -79,7 +104,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
Transforms.tanh(outExpCurrent, false);
INDArray outActCurrent = out.get(all(), all(), point(i));
INDArray outActCurrent;
if (rnnDataFormat == RNNFormat.NCW){
outActCurrent = out.get(all(), all(), point(i));
}
else{
outActCurrent = out.get(all(), point(i), all());
}
assertEquals(String.valueOf(i), outExpCurrent, outActCurrent);
outLast = outExpCurrent;
@ -100,7 +131,7 @@ public class TestSimpleRnn extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.activation(Activation.TANH)
.list()
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize)
.layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat)
.biasInit(100)
.build())
.build();
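The TestSimpleRnn changes keep the manual forward-pass check working for both layouts by changing which axis is sliced per time step: under NCW, time is the last axis; under NWC, it is the middle one. A small sketch of that indexing (not from this PR; names are illustrative):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;

public class TimeStepSliceSketch {

    // Returns the [minibatch, features] slice for time step t in either layout
    static INDArray timeStep(INDArray in, int t, RNNFormat format) {
        return format == RNNFormat.NCW
                ? in.get(all(), all(), point(t))  // NCW: [mb, features, time]
                : in.get(all(), point(t), all()); // NWC: [mb, time, features]
    }

    public static void main(String[] args) {
        INDArray ncw = Nd4j.rand(DataType.FLOAT, 2, 5, 7); // [mb, nIn, tsLength]
        INDArray nwc = ncw.permute(0, 2, 1);               // [mb, tsLength, nIn]

        // Both slices contain the same values for the same time step
        System.out.println(timeStep(ncw, 3, RNNFormat.NCW).equals(timeStep(nwc, 3, RNNFormat.NWC))); // true
    }
}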

View File

@ -4,14 +4,21 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -22,8 +29,18 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public class TestTimeDistributed extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestTimeDistributed(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
}
@Test
public void testTimeDistributed(){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
@ -34,11 +51,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
.seed(12345)
.updater(new Adam(0.1))
.list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
.layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build())
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3))
.setInputType(InputType.recurrent(3, rnnDataFormat))
.build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
@ -47,11 +64,11 @@ public class TestTimeDistributed extends BaseDL4JTest {
.seed(12345)
.updater(new Adam(0.1))
.list()
.layer(new LSTM.Builder().nIn(3).nOut(3).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), 2))
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
.layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
.layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), rnnDataFormat))
.layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.recurrent(3))
.setInputType(InputType.recurrent(3, rnnDataFormat))
.build();
MultiLayerNetwork net1 = new MultiLayerNetwork(conf1);
@ -62,13 +79,21 @@ public class TestTimeDistributed extends BaseDL4JTest {
for( int mb : new int[]{1, 5}) {
for(char inLabelOrder : new char[]{'c', 'f'}) {
INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder);
if (rnnDataFormat == RNNFormat.NWC){
in = in.permute(0, 2, 1);
}
INDArray out1 = net1.output(in);
INDArray out2 = net2.output(in);
assertEquals(out1, out2);
INDArray labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
INDArray labels ;
if (rnnDataFormat == RNNFormat.NCW) {
labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder);
}else{
labels = TestUtils.randomOneHotTimeSeries(mb, 5, 3).dup(inLabelOrder);
}
DataSet ds = new DataSet(in, labels);
net1.fit(ds);
@ -85,4 +110,73 @@ public class TestTimeDistributed extends BaseDL4JTest {
}
}
}
@Test
public void testTimeDistributedDense(){
for( int rnnType=0; rnnType<3; rnnType++ ) {
for( int ffType=0; ffType<3; ffType++ ) {
Layer l0, l2;
switch (rnnType) {
case 0:
l0 = new LSTM.Builder().nOut(5).build();
l2 = new LSTM.Builder().nOut(5).build();
break;
case 1:
l0 = new SimpleRnn.Builder().nOut(5).build();
l2 = new SimpleRnn.Builder().nOut(5).build();
break;
case 2:
l0 = new Bidirectional(new LSTM.Builder().nOut(5).build());
l2 = new Bidirectional(new LSTM.Builder().nOut(5).build());
break;
default:
throw new RuntimeException("Not implemented: " + rnnType);
}
Layer l1;
switch (ffType){
case 0:
l1 = new DenseLayer.Builder().nOut(5).build();
break;
case 1:
l1 = new VariationalAutoencoder.Builder().nOut(5).encoderLayerSizes(5).decoderLayerSizes(5).build();
break;
case 2:
l1 = new AutoEncoder.Builder().nOut(5).build();
break;
default:
throw new RuntimeException("Not implemented: " + ffType);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.list()
.layer(l0)
.layer(l1)
.layer(l2)
.setInputType(InputType.recurrent(5, 9, rnnDataFormat))
.build();
BaseRecurrentLayer l0a;
BaseRecurrentLayer l2a;
if (rnnType < 2) {
l0a = (BaseRecurrentLayer) l0;
l2a = (BaseRecurrentLayer) l2;
} else {
l0a = (BaseRecurrentLayer) ((Bidirectional) l0).getFwd();
l2a = (BaseRecurrentLayer) ((Bidirectional) l2).getFwd();
}
assertEquals(rnnDataFormat, l0a.getRnnDataFormat());
assertEquals(rnnDataFormat, l2a.getRnnDataFormat());
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, rnnDataFormat == RNNFormat.NCW ? new long[]{2, 5, 9} : new long[]{2, 9, 5} );
net.output(in);
}
}
}
}
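In TestTimeDistributed the wrapper is now constructed with the RNNFormat rather than an explicit time-axis index, and the input and label shapes follow the chosen format. A minimal NWC configuration using the same builders (a sketch, not the test itself):

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TimeDistributedNwcSketch {
    public static void main(String[] args) {
        RNNFormat fmt = RNNFormat.NWC;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new LSTM.Builder().nIn(3).nOut(3).dataFormat(fmt).build())
                // The wrapped DenseLayer is applied independently at every time step;
                // the wrapper infers the time axis from the supplied RNNFormat
                .layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3)
                        .activation(Activation.TANH).build(), fmt))
                .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX)
                        .dataFormat(fmt).lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .setInputType(InputType.recurrent(3, fmt))
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // NWC input [minibatch, timeSteps, features]; the output keeps the same layout
        System.out.println(java.util.Arrays.toString(net.output(Nd4j.rand(DataType.FLOAT, 2, 5, 3)).shape()));
    }
}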

View File

@ -771,7 +771,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
.build();
fail("Exception expected");
} catch (IllegalStateException e){
// e.printStackTrace();
log.info(e.toString());
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
}
}

View File

@ -408,8 +408,8 @@ public class TestVariableLengthTS extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0);
assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(0.0, outRow2.getDouble(k), 0.0);
}
}
}
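The two assertEquals calls above are reordered because JUnit's signature is assertEquals(expected, actual) (with a delta for doubles); swapping the arguments does not change pass/fail behaviour, but failure messages then report the wrong value as "expected". Similar reorderings appear in TestUpdaters and the RegressionTest classes below. A tiny illustration (not from this PR):

import static org.junit.Assert.assertEquals;

public class AssertOrderSketch {
    public static void main(String[] args) {
        double actual = 0.0; // stand-in for a computed value
        // expected first, actual second, then the tolerance for floating point comparisons
        assertEquals(0.0, actual, 0.0);
    }
}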

View File

@ -336,7 +336,7 @@ public class TestUpdaters extends BaseDL4JTest {
actualM[i] = Math.round(actualM[i] * 1e2) / 1e2;
}
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(actualM, expectedM), true);
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(expectedM, actualM), true);
}

View File

@ -159,14 +159,14 @@ public class RegressionTest050 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,14 +162,14 @@ public class RegressionTest060 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,7 +162,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same);
assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());

View File

@ -176,14 +176,14 @@ public class RegressionTest080 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same);
assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Same);
assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertTrue(l2.getActivationFn() instanceof ActivationSigmoid);

View File

@ -178,8 +178,6 @@ public abstract class BaseCudnnHelper {
}
}
protected static final int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
protected final DataType nd4jDataType;
protected final int dataType;
protected final int dataTypeSize;

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -22,6 +23,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import com.jakewharton.byteunits.BinaryByteUnit;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo;
@ -138,7 +140,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel,
int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn,
AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo,
ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations); however, I have been unable to get it working
// correctly on NHWC data, even after updating all descriptors, tensor formats, etc.
//Therefore: all computation here is done in NCHW format only
//As of a future (next?) release we'll likely switch to C++ for cuDNN support
boolean origNHWC = false;
if(format == CNN2DFormat.NHWC){
input = input.permute(0,3,1,2); //NHWC to NCHW
delta = delta.permute(0,3,1,2);
origNHWC = true;
}
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
int code;
val miniBatch = input.size(0);
@ -147,7 +163,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
val kH = weights.size(2);
val kW = weights.size(3);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
@ -344,12 +360,30 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0)));
}
if(origNHWC){
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
}
return new Pair<>(retGradient, epsNext);
}
@Override
public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad,
AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format,
LayerWorkspaceMgr workspaceMgr) {
//AB 2020/04/21 - cuDNN does have NHWC support (with limitations); however, I have been unable to get it working
// correctly on NHWC data, even after updating all descriptors, tensor formats, etc.
//Therefore: all computation here is done in NCHW format only
//As of a future (next?) release we'll likely switch to C++ for cuDNN support
boolean origNHWC = false;
if(format == CNN2DFormat.NHWC){
input = input.permute(0,3,1,2); //NHWC to NCHW
origNHWC = true;
}
int TENSOR_FORMAT = CUDNN_TENSOR_NCHW;
int code;
val miniBatch = input.size(0);
@ -358,7 +392,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
val kH = weights.size(2);
val kW = weights.size(3);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null);
CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
@ -499,6 +533,10 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
context.syncOldStream();
if(origNHWC){
z = z.permute(0,2,3,1); //NCHW to NHWC
}
return z;
}
@ -593,7 +631,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
* @return
*/
public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation,
ConvolutionMode convolutionMode, PoolingType poolingType){
ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){
INDArray origInput = input;
//Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have has issues if strides
@ -602,16 +640,19 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
input = input.dup('c');
}
boolean nchw = format == CNN2DFormat.NCHW;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val inH = input.size(2);
val inW = input.size(3);
val inH = input.size(hIdx);
val inW = input.size(wIdx);
boolean manualPadBottom = false;
boolean manualPadRight = false;
int[] outSize;
if (convolutionMode == ConvolutionMode.Same) {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
if(!Arrays.equals(padding, padBottomRight)){
@ -626,9 +667,17 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
manualPadRight = (padding[1] != padBottomRight[1]);
//NCHW format
val newShape = new long[]{input.size(0), input.size(1),
long[] newShape;
if(nchw){
newShape = new long[]{input.size(0), input.size(1),
input.size(2) + (manualPadBottom ? 1 : 0),
input.size(3) + (manualPadRight ? 1 : 0)};
} else {
newShape = new long[]{input.size(0),
input.size(1) + (manualPadBottom ? 1 : 0),
input.size(2) + (manualPadRight ? 1 : 0),
input.size(3)};
}
INDArray newInput;
if(poolingType == null || poolingType != PoolingType.MAX){
newInput = Nd4j.create(input.dataType(), newShape);
@ -638,15 +687,22 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
// if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value
newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType());
}
if(nchw){
newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)),
interval(0, input.size(3))}, input);
} else {
newInput.put(new INDArrayIndex[]{all(), interval(0,input.size(1)),
interval(0, input.size(2)), all()}, input);
}
input = newInput;
//Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we
// now have the same amount of padding required for top/bottom, and left/right - which we'll let
// CuDNN handle
}
} else {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation); //Also performs validation
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
}
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
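The backpropGradient and preOutput changes above take a CNN2DFormat but, per the inline comments, still compute in NCHW: NHWC inputs are permuted to NCHW on entry and the result is permuted back on exit. A minimal sketch of that permute-compute-permute pattern (processNchwOnly is a placeholder, not a real API):

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.nd4j.linalg.api.ndarray.INDArray;

public class NhwcViaNchwSketch {

    // Placeholder for any routine that only understands NCHW ([mb, channels, height, width])
    static INDArray processNchwOnly(INDArray nchw) {
        return nchw; // identity stand-in
    }

    // Accepts either layout, runs the NCHW-only routine, returns data in the caller's layout
    static INDArray process(INDArray input, CNN2DFormat format) {
        boolean nhwc = format == CNN2DFormat.NHWC;
        INDArray nchw = nhwc ? input.permute(0, 3, 1, 2) : input; // NHWC -> NCHW
        INDArray out = processNchwOnly(nchw);
        return nhwc ? out.permute(0, 2, 3, 1) : out;              // NCHW -> NHWC
    }
}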

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.subsampling;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
@ -114,23 +115,29 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides,
int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode,
int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
if(dilation[0] != 1 || dilation[1] != 1){
//CuDNN doesn't support dilated subsampling
return null;
}
boolean nchw = format == CNN2DFormat.NCHW;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
//We require the output as one of the arguments for backprop here
//TODO we could add cache mode support here somehow...
INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, workspaceMgr);
INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr);
val miniBatch = input.size(0);
val depth = input.size(1);
val depth = input.size(chIdx);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
val inH = input.size(hIdx);
val inW = input.size(wIdx);
val srcStride = input.stride();
int[] outSize = args.getOutSize();
int outH = outSize[0];
@ -160,23 +167,26 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
epsilon = epsilon.dup('c');
}
input = input.dup();
val deltaStride = epsilon.stride();
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]));
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW,
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]));
(int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
kernel[1], pad[0], pad[1], strides[0], strides[1]));
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c');
long[] outEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c');
val dstStride = outEpsilon.stride();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]));
(int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
@ -198,9 +208,16 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
//Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon
// we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input.
if(args.isManualPadBottom() || args.isManualPadRight()) {
if(nchw){
outEpsilon = outEpsilon.get(all(), all(),
interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0)));
} else {
outEpsilon = outEpsilon.get(all(),
interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)),
interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)),
all());
}
}
return new Pair<>(retGradient, outEpsilon);
@ -209,19 +226,24 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
@Override
public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad,
PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
if(dilation[0] != 1 || dilation[1] != 1){
//CuDNN doesn't support dilated subsampling
return null;
}
val miniBatch = input.size(0);
val inDepth = input.size(1);
boolean nchw = format == CNN2DFormat.NCHW;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType);
val miniBatch = input.size(0);
val inDepth = input.size(nchw ? 1 : 3);
CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
input = args.getInput();
val inH = input.size(2);
val inW = input.size(3);
val inH = input.size(nchw ? 2 : 1);
val inW = input.size(nchw ? 3 : 2);
val srcStride = input.stride();
val outSize = args.getOutSize();
int outH = outSize[0];
@ -246,13 +268,14 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli
checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0],
kernel[1], pad[0], pad[1], strides[0], strides[1]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]));
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {(int) miniBatch, (int) inDepth, outH, outW}, 'c');
long[] outShape = nchw ? new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth};
INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c');
val dstStride = reduced.stride();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW,
(int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]));
(int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx]));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(input, reduced);
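Unlike the convolution helper, the subsampling helper handles NHWC directly: instead of hard-coding axes 1/2/3 it looks up the channel, height, and width axes from the format and uses them for sizes, strides, and output shapes. The lookup reduces to a few ternaries (a sketch, not the helper itself):

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.nd4j.linalg.api.ndarray.INDArray;

public class FormatAxisIndexSketch {

    // Reads [minibatch, channels, height, width] from a 4d activations array in either layout
    static long[] dims(INDArray input, CNN2DFormat format) {
        boolean nchw = format == CNN2DFormat.NCHW;
        int chIdx = nchw ? 1 : 3;
        int hIdx  = nchw ? 2 : 1;
        int wIdx  = nchw ? 3 : 2;
        return new long[]{input.size(0), input.size(chIdx), input.size(hIdx), input.size(wIdx)};
    }
}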

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.normalization;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
@ -124,12 +125,21 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
this.eps = eps;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
val miniBatch = (int) input.size(0);
val depth = (int) input.size(1);
val inH = (int) input.size(2);
val inW = (int) input.size(3);
val depth = (int) input.size(chIdx);
val inH = (int) input.size(hIdx);
val inW = (int) input.size(wIdx);
final boolean isHalf = (input.dataType() == DataType.HALF);
INDArray gammaOrig = null;
@ -164,16 +174,17 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]));
(int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW,
(int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]));
(int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx]));
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c');
long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth};
INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c');
val dstStride = ArrayUtil.toInts(nextEpsilon.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();
@ -215,9 +226,15 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
@Override
public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
boolean nchw = format == CNN2DFormat.NCHW;
int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int chIdx = nchw ? 1 : 3;
int hIdx = nchw ? 2 : 1;
int wIdx = nchw ? 3 : 2;
this.eps = eps;
final boolean isHalf = (x.dataType() == DataType.HALF);
final boolean isHalf = (x.dataType() == DataType.FLOAT16);
INDArray origGamma = gamma;
INDArray origBeta = beta;
INDArray origMean = mean;
@ -238,21 +255,22 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). -> 0 = "in-place modification of running mean disabled"
val miniBatch = (int) x.size(0);
val inDepth = (int) x.size(1);
val inH = (int) x.size(2);
val inW = (int) x.size(3);
val inDepth = (int) x.size(chIdx);
val inH = (int) x.size(hIdx);
val inW = (int) x.size(wIdx);
val srcStride = ArrayUtil.toInts(x.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW,
srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx]));
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c');
long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth};
INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c');
val dstStride = ArrayUtil.toInts(activations.stride());
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0],
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();
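The descriptor calls above index the stride arrays with chIdx/hIdx/wIdx instead of the fixed positions 1, 2, 3. For a C-ordered array the stride of an axis is the product of all later axis sizes, so the channel stride lands in a different slot per layout. A worked example (illustrative only):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class StrideLayoutSketch {
    public static void main(String[] args) {
        // NCHW, shape [2, 5, 4, 3]: strides are [5*4*3, 4*3, 3, 1] = [60, 12, 3, 1]
        INDArray nchw = Nd4j.create(DataType.FLOAT, 2, 5, 4, 3);
        System.out.println(java.util.Arrays.toString(nchw.stride()));

        // NHWC, shape [2, 4, 3, 5]: strides are [4*3*5, 3*5, 5, 1] = [60, 15, 5, 1],
        // so the channel stride is at index 3 (value 1) rather than index 1
        INDArray nhwc = Nd4j.create(DataType.FLOAT, 2, 4, 3, 5);
        System.out.println(java.util.Arrays.toString(nhwc.stride()));
    }
}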

View File

@ -16,74 +16,131 @@
package org.deeplearning4j;
import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer;
import org.deeplearning4j.nn.layers.normalization.BatchNormalization;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization;
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.*;
import java.lang.reflect.Field;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class TestUtils {
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
MultiLayerNetwork restored;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
assertEquals(net.params(), restored.params());
return restored;
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
//Also check the MultiLayerConfiguration is serializable (required by Spark etc)
MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
serializeDeserializeJava(conf);
return restored;
}
public static ComputationGraph testModelSerialization(ComputationGraph net){
ComputationGraph restored;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true);
restored = ModelSerializer.restoreComputationGraph(bais, true);
assertEquals(net.getConfiguration(), restored.getConfiguration());
assertEquals(net.params(), restored.params());
return restored;
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
//Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
ComputationGraphConfiguration conf = net.getConfiguration();
serializeDeserializeJava(conf);
return restored;
}
public static INDArray randomOneHot(int examples, int nOut){
private static <T> T serializeDeserializeJava(T object){
byte[] bytes;
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
oos.writeObject(object);
oos.close();
bytes = baos.toByteArray();
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
T out;
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){
out = (T)ois.readObject();
} catch (IOException | ClassNotFoundException e){
throw new RuntimeException(e);
}
assertEquals(object, out);
return out;
}
public static INDArray randomOneHot(long examples, long nOut){
return randomOneHot(examples, nOut, new Random(12345));
}
public static INDArray randomOneHot(int examples, int nOut, long rngSeed){
public static INDArray randomOneHot(DataType dataType, long examples, long nOut){
return randomOneHot(dataType, examples, nOut, new Random(12345));
}
public static INDArray randomOneHot(long examples, long nOut, long rngSeed){
return randomOneHot(examples, nOut, new Random(rngSeed));
}
public static INDArray randomOneHot(int examples, int nOut, Random rng){
INDArray arr = Nd4j.create(examples, nOut);
public static INDArray randomOneHot(long examples, long nOut, Random rng) {
return randomOneHot(Nd4j.defaultFloatingPointType(), examples,nOut, rng);
}
public static INDArray randomOneHot(DataType dataType, long examples, long nOut, Random rng){
INDArray arr = Nd4j.create(dataType, examples, nOut);
for( int i=0; i<examples; i++ ){
arr.putScalar(i, rng.nextInt(nOut), 1.0);
arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
}
return arr;
}
@ -115,4 +172,143 @@ public class TestUtils {
Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p));
return ret;
}
public static void writeStreamToFile(File out, InputStream is) throws IOException {
byte[] b = IOUtils.toByteArray(is);
try (OutputStream os = new BufferedOutputStream(new FileOutputStream(out))) {
os.write(b);
}
}
public static L1Regularization getL1Reg(List<Regularization> l){
for(Regularization r : l){
if(r instanceof L1Regularization){
return (L1Regularization) r;
}
}
return null;
}
public static L2Regularization getL2Reg(BaseLayer baseLayer){
return getL2Reg(baseLayer.getRegularization());
}
public static L2Regularization getL2Reg(List<Regularization> l){
for(Regularization r : l){
if(r instanceof L2Regularization){
return (L2Regularization) r;
}
}
return null;
}
public static WeightDecay getWeightDecayReg(BaseLayer bl){
return getWeightDecayReg(bl.getRegularization());
}
public static WeightDecay getWeightDecayReg(List<Regularization> l){
for(Regularization r : l){
if(r instanceof WeightDecay){
return (WeightDecay) r;
}
}
return null;
}
public static double getL1(BaseLayer layer) {
List<Regularization> l = layer.getRegularization();
return getL1(l);
}
public static double getL1(List<Regularization> l){
L1Regularization l1Reg = null;
for(Regularization reg : l){
if(reg instanceof L1Regularization)
l1Reg = (L1Regularization) reg;
}
assertNotNull(l1Reg);
return l1Reg.getL1().valueAt(0,0);
}
public static double getL2(BaseLayer layer) {
List<Regularization> l = layer.getRegularization();
return getL2(l);
}
public static double getL2(List<Regularization> l){
L2Regularization l2Reg = null;
for(Regularization reg : l){
if(reg instanceof L2Regularization)
l2Reg = (L2Regularization) reg;
}
assertNotNull(l2Reg);
return l2Reg.getL2().valueAt(0,0);
}
public static double getL1(AbstractSameDiffLayer layer){
return getL1(layer.getRegularization());
}
public static double getL2(AbstractSameDiffLayer layer){
return getL2(layer.getRegularization());
}
public static double getWeightDecay(BaseLayer layer) {
return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0);
}
public static void removeHelper(Layer layer) throws Exception {
removeHelpers(new Layer[]{layer});
}
public static void removeHelpers(Layer[] layers) throws Exception {
for(Layer l : layers){
if(l instanceof ConvolutionLayer){
Field f1 = ConvolutionLayer.class.getDeclaredField("helper");
f1.setAccessible(true);
f1.set(l, null);
} else if(l instanceof SubsamplingLayer){
Field f2 = SubsamplingLayer.class.getDeclaredField("helper");
f2.setAccessible(true);
f2.set(l, null);
} else if(l instanceof BatchNormalization) {
Field f3 = BatchNormalization.class.getDeclaredField("helper");
f3.setAccessible(true);
f3.set(l, null);
} else if(l instanceof LSTM){
Field f4 = LSTM.class.getDeclaredField("helper");
f4.setAccessible(true);
f4.set(l, null);
} else if(l instanceof LocalResponseNormalization){
Field f5 = LocalResponseNormalization.class.getDeclaredField("helper");
f5.setAccessible(true);
f5.set(l, null);
}
if(l.getHelper() != null){
throw new IllegalStateException("Did not remove helper for layer: " + l.getClass().getSimpleName());
}
}
}
public static void assertHelperPresent(Layer layer){
}
public static void assertHelpersPresent(Layer[] layers) throws Exception {
for(Layer l : layers){
//Don't use instanceof for the convolution check here - ConvolutionLayer has subclasses that this check should not match
if(l.getClass() == ConvolutionLayer.class || l instanceof SubsamplingLayer || l instanceof BatchNormalization || l instanceof LSTM){
Preconditions.checkNotNull(l.getHelper(), l.conf().getLayer().getLayerName());
}
}
}
public static void assertHelpersAbsent(Layer[] layers) throws Exception {
for(Layer l : layers){
Preconditions.checkState(l.getHelper() == null, l.conf().getLayer().getLayerName());
}
}
}
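The new serializeDeserializeJava helper in TestUtils round-trips each configuration through plain Java serialization and asserts equality, since Spark and other distributed use require the configuration classes to stay Serializable. The underlying pattern, stripped of DL4J types (a sketch; the names are illustrative):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

public class JavaSerializationRoundTripSketch {

    @SuppressWarnings("unchecked")
    static <T extends Serializable> T roundTrip(T object) throws Exception {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(object); // fails here if any field is not Serializable
        }
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) {
            return (T) ois.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(roundTrip("any Serializable object"));
    }
}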

View File

@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest {
ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
model = model.convertDataType(DataType.DOUBLE);
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize});
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC
CuDNNTestUtils.assertHelpersPresent(model.getLayers());
Map<String,INDArray> withCudnn = model.feedForward(in, false);

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.datasets.fetchers;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.base.EmnistFetcher;
@ -36,6 +37,7 @@ import java.util.Random;
* @author Alex Black
*
*/
@Slf4j
public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher {
protected EmnistFetcher fetcher;
@ -64,7 +66,7 @@ public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetche
try {
man = new MnistManager(images, labels, totalExamples);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
FileUtils.deleteDirectory(new File(EMNIST_ROOT));
new EmnistFetcher(dataSet).downloadAndUntar();
man = new MnistManager(images, labels, totalExamples);
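Several files in this PR replace e.printStackTrace() with a logger call; passing the exception as the final argument preserves the stack trace while routing it through the configured appenders instead of stderr. The plain SLF4J equivalent of the Lombok @Slf4j field used here (a sketch):

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ExceptionLoggingSketch {

    private static final Logger log = LoggerFactory.getLogger(ExceptionLoggingSketch.class);

    static void download() {
        try {
            throw new java.io.IOException("simulated failure");
        } catch (java.io.IOException e) {
            // The throwable argument makes the full stack trace part of the log event
            log.error("Download failed, retrying with a fresh copy", e);
        }
    }

    public static void main(String[] args) {
        download();
    }
}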

View File

@ -687,9 +687,7 @@ public class BarnesHutTsne implements Model {
* @throws IOException
*/
public void saveAsFile(List<String> labels, String path) throws IOException {
BufferedWriter write = null;
try {
write = new BufferedWriter(new FileWriter(new File(path)));
try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
for (int i = 0; i < Y.rows(); i++) {
if (i >= labels.size())
break;
@ -711,17 +709,11 @@ public class BarnesHutTsne implements Model {
}
write.flush();
write.close();
} finally {
if (write != null)
write.close();
}
}
public void saveAsFile(String path) throws IOException {
BufferedWriter write = null;
try {
write = new BufferedWriter(new FileWriter(new File(path)));
try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
for (int i = 0; i < Y.rows(); i++) {
StringBuilder sb = new StringBuilder();
INDArray wordVector = Y.getRow(i);
@ -734,10 +726,6 @@ public class BarnesHutTsne implements Model {
write.write(sb.toString());
}
write.flush();
write.close();
} finally {
if (write != null)
write.close();
}
}
/**
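Both saveAsFile overloads above move from a manual try/finally close to try-with-resources, which closes the writer on every exit path and removes the redundant close() calls. The pattern in isolation (a sketch, not the BarnesHutTsne code):

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;

public class TryWithResourcesSketch {

    static void save(String path, Iterable<String> lines) throws IOException {
        // The writer is closed automatically when the block exits, normally or exceptionally
        try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
            for (String line : lines) {
                write.write(line);
                write.newLine();
            }
        }
    }
}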

View File

@ -113,6 +113,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-python</artifactId>
<version>${datavec.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -60,7 +60,7 @@ public class Hdf5Archive implements Closeable {
/* This is necessary for the call to the BytePointer constructor below. */
Loader.load(org.bytedeco.hdf5.global.hdf5.class);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -19,6 +19,9 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -121,27 +124,29 @@ public class KerasInput extends KerasLayer {
InputType myInputType;
switch (this.inputShape.length) {
case 1:
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
break;
case 2:
if(this.dimOrder != null) {
System.out.println("Dim order: " + this.dimOrder);
System.out.println("Input shape: " + ArrayUtils.toString(this.inputShape));
switch (this.dimOrder) {
case TENSORFLOW: //NWC == channels_last
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break;
case THEANO: //NCW == channels_first
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
break;
case NONE:
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break;
default:
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
}
} else {
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
}
break;
@ -150,17 +155,17 @@ public class KerasInput extends KerasLayer {
case TENSORFLOW:
/* TensorFlow convolutional input: # rows, # cols, # channels */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
this.inputShape[2]);
this.inputShape[2], CNN2DFormat.NHWC);
break;
case THEANO:
/* Theano convolutional input: # channels, # rows, # cols */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]);
this.inputShape[0], CNN2DFormat.NCHW);
break;
default:
this.dimOrder = DimOrder.THEANO;
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]);
this.inputShape[0], CNN2DFormat.NCHW);
log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
"versions may come without specified backend, in which case we assume the model was " +
"built with theano." );

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
@ -65,6 +66,9 @@ public class TFOpLayer extends Layer {
long[] shape = inputType.getShape(true);
TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
long[] outputShape = tempLayer.getOutputShape(shape);
if (outputShape.length == 3){
return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC);
}
return InputType.inferInputType(Nd4j.create(outputShape));
}

View File

@@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
}
private INDArray runGraph(INDArray input){
if (input.rank() == 3){
// TODO make this a preprocessor
input = input.permute(0, 2, 1);
}
Map<String, INDArray> inputMap = new HashMap<>();
inputMap.put(inputNames.get(0), input);
INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
if (out.rank() == 3){
out = out.permute(0, 2, 1); // TODO post-processing?
}
return out;
}

View File

@@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
@@ -94,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution {
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
@@ -103,7 +103,7 @@ public class KerasConvolution1D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
.hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]).rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
if (hasBias)
builder.biasInit(0.0);
@@ -160,8 +160,8 @@ public class KerasConvolution1D extends KerasConvolution {
public InputPreProcessor getInputPreprocessor(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras LSTM layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
"Keras Conv1D layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], RNNFormat.NCW, layerName);
}

View File

@@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
@@ -27,6 +28,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.IWeightInit;
import java.util.Map;
@@ -93,6 +95,7 @@ public class KerasConvolution2D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
System.out.println("----" + dimOrder);
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
@@ -101,7 +104,8 @@ public class KerasConvolution2D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
.hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 2, conf));
.stride(getStrideFromConfig(layerConfig, 2, conf))
.dataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion);
if (hasBias)
builder.biasInit(0.0);
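As a usage note, a hedged sketch of a ConvolutionLayer configured the way the importer now configures one for a channels_last Keras model. The class name and layer sizes are hypothetical; this is not code from the importer itself.

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;

public class Nhwc2DConvSketch {
    public static void main(String[] args) {
        // The data format is set explicitly on the builder, matching the Keras model's channels_last layout
        ConvolutionLayer conv = new ConvolutionLayer.Builder()
                .nOut(32)
                .kernelSize(3, 3)
                .stride(1, 1)
                .dataFormat(CNN2DFormat.NHWC)
                .build();
        System.out.println(conv);
    }
}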

View File

@@ -360,8 +360,19 @@ public class KerasConvolutionUtils {
}
} else if (dimension == 1) {
Object paddingObj = innerConfig.get(layerField);
if (paddingObj instanceof List) {
    List<Integer> paddingList = (List<Integer>) paddingObj;
    padding = new int[]{paddingList.get(0), paddingList.get(1)};
} else {
    int paddingInt = (int) paddingObj;
    padding = new int[]{paddingInt, paddingInt};
}
} else {
throw new UnsupportedKerasConfigurationException(
"Keras padding layer not supported");

View File

@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
@@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import java.util.Map;
@@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer {
switch (this.getDimOrder()) {
case NONE:
case THEANO:
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels());
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW);
break;
case TENSORFLOW:
preprocessor = new TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(),
it.getChannels());
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC);
break;
default:
throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder());
@@ -111,7 +110,7 @@ public class KerasFlatten extends KerasLayer {
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null);
}
return preprocessor;
}

View File

@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer {
super(layerConfig, enforceTrainingConfig);
this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf))
.dataFormat(RNNFormat.NWC)
.name(this.layerName).build();
}

View File

@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW);
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
if (inputShape[0] != targetShape[0])
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC);
}
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
} else {
if (inputShape[0] != targetShape[0])
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
}
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
if (targetShape.length == 3) {
targetShape = targetShapeForDimOrder(inputShape, targetShape);
}
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
}
return preprocessor;
}

View File

@@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer {
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization)
.outputDataFormat(RNNFormat.NWC)
.hasBias(false);
if (embeddingConstraint != null)
builder.constrainWeights(embeddingConstraint);

View File

@@ -22,11 +22,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
@@ -37,6 +35,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -187,7 +186,7 @@ public class KerasLSTM extends KerasLayer {
.weightInitRecurrent(recurrentInit)
.biasInit(0.0) // TODO: this is incorrect
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);
@@ -266,7 +265,8 @@ public class KerasLSTM extends KerasLayer {
throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one single input " +
"or three (input to LSTM and two state tensors), but " +
"received " + inputType.length + ".");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName);
}
/**

View File
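Both the LSTM import above and the SimpleRNN import below now mark their layers as NWC, i.e. [minibatch, timeSteps, size] activations as produced by Keras. A standalone sketch of the resulting layer configuration; the class name and sizes are hypothetical.

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.LSTM;

public class NwcLstmSketch {
    public static void main(String[] args) {
        // An LSTM declared to consume [minibatch, timeSteps, features] input, as an imported Keras LSTM now is
        LSTM lstm = new LSTM.Builder()
                .nIn(8)
                .nOut(16)
                .dataFormat(RNNFormat.NWC)
                .build();
        System.out.println(lstm);
    }
}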

@@ -21,7 +21,9 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.Layer;
@@ -36,6 +38,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
@@ -155,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer {
.weightInitRecurrent(recurrentInit)
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);
@@ -227,7 +230,8 @@ public class KerasSimpleRnn extends KerasLayer {
throw new InvalidKerasConfigurationException(
"Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName);
}
/**

View File

@@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer {
break;
case "SimpleRNN":
kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
this.layer = new Bidirectional(mode, rnnLayer);
layer.setLayerName(layerName);
break;
@@ -218,7 +218,7 @@ public class KerasBidirectional extends KerasLayer {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras Bidirectional layer accepts only one input (received " + inputType.length + ")");
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName);
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], ((Bidirectional)layer).getRNNDataFormat(), layerName);
}
/**

View File

@@ -21,6 +21,9 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
private final long[] inputShape;
private final long[] targetShape;
private boolean hasMiniBatchDimension;
/**
* @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
*/
@Deprecated
public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
this(inputShape, targetShape, false);
}
private DataFormat format;
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
*/
public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) {
this(inputShape, targetShape, hasMiniBatchDimension, null);
}
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
* @param dataFormat            May be null. If non-null: the data format (e.g. CNN2DFormat or RNNFormat) used to set the layout reported by getOutputType for the reshaped activations
*/
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension,
@JsonProperty("dataFormat") DataFormat dataFormat) {
this.inputShape = inputShape;
this.targetShape = targetShape;
this.hasMiniBatchDimension = hasMiniBatchDimension;
this.format = dataFormat;
}
private long[] getShape(long[] originalShape, long minibatch) {
@@ -140,13 +148,26 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
ret = InputType.feedForward(shape[1]);
break;
case 3:
ret = InputType.recurrent(shape[2], shape[1]);
RNNFormat rnnFormat = (this.format instanceof RNNFormat) ? (RNNFormat) this.format : RNNFormat.NCW;
ret = InputType.recurrent(shape[2], shape[1], rnnFormat);
break;
case 4:
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
} else {
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
CNN2DFormat cnnFormat = (this.format instanceof CNN2DFormat) ? (CNN2DFormat) this.format : CNN2DFormat.NCHW;
if (cnnFormat == CNN2DFormat.NCHW) {
ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
} else {
ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
}
}
break;
default:

View File
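For context, a minimal sketch of the new four-argument ReshapePreprocessor constructor shown above. The shapes and the NWC format are illustrative assumptions, not values used anywhere in this diff.

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;

public class ReshapePreprocessorSketch {
    public static void main(String[] args) throws Exception {
        // Reshape a length-60 feed-forward activation into a [seqLen=20, size=3] sequence and
        // tell the preprocessor to report the result as an NWC recurrent type
        ReshapePreprocessor reshape = new ReshapePreprocessor(
                new long[]{60},      // input shape, without the minibatch dimension
                new long[]{20, 3},   // target shape, without the minibatch dimension
                false,               // hasMiniBatchDimension
                RNNFormat.NWC);
        System.out.println(reshape.getOutputType(InputType.feedForward(60)));
    }
}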

@@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/**
* Specialized CnnToFeedForwardInputPreProcessor for use with
* Convolutional layers imported from Keras using the TensorFlow
* backend.
*
* @author dave@skymind.io
* @deprecated Exists only for backward compatibility of older pretrained models. Should not be used.
* Use {@link CnnToFeedForwardPreProcessor} for all new models instead.
*/
@Slf4j
@Slf4j @Deprecated
public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor {
@JsonCreator
@JsonCreator @Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight,
@JsonProperty("inputWidth") long inputWidth,
@JsonProperty("numChannels") long numChannels) {
super(inputHeight, inputWidth, numChannels);
}
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
super(inputHeight, inputWidth);
}
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor() {
super();
}
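The replacement that the deprecation note points to looks roughly like this; the 28x28x1 dimensions and class name are placeholders for illustration.

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;

public class NhwcFlattenSketch {
    public static void main(String[] args) {
        // Instead of the deprecated TensorFlow-specific subclass, pass the NHWC format explicitly
        CnnToFeedForwardPreProcessor pre = new CnnToFeedForwardPreProcessor(28, 28, 1, CNN2DFormat.NHWC);
        System.out.println(pre);
    }
}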

View File

@@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
@@ -319,6 +320,8 @@ public class KerasLayerUtils {
layer = new KerasELU(layerConfig, enforceTrainingConfig);
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
} else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){
layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
} else if (conf instanceof Keras2LayerConfiguration){
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){

View File

@@ -1,50 +0,0 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.Arrays;
public class TFKerasTests extends BaseDL4JTest{
@Test
public void testModelWithTFOp1() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
}
@Test
public void testModelWithTFOp2() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
// dl4j's feedforward doesn't support 3D output, so the batch and time axes get squashed
long[] expectedShape = new long[]{12 * 2, 5};
Assert.assertArrayEquals(expectedShape, out.shape());
}
}

View File

@@ -0,0 +1,147 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.apache.commons.io.FileUtils;
import org.datavec.python.keras.Model;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.common.tests.ResourceUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.List;
@RunWith(Parameterized.class)
public class TestTFKerasModelImport extends BaseDL4JTest{
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
private String modelFile;
@Override
public long getTimeoutMilliseconds(){
return 300000;
} // installing TF will take a while
@Parameterized.Parameters(name = "file={0}")
public static Object[] params() throws Exception {
List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false);
return paths.toArray(new String[0]);
}
public TestTFKerasModelImport(String modelFile){
this.modelFile = modelFile;
}
@Test
public void testModelImport() throws Exception{
testModelImportWithData(modelFile);
}
private void testModelImportWithData(String path) throws Exception{
System.out.println(path);
// TODO multi input/output
INDArray inputArray;
INDArray expectedOutputArray;
File f = Resources.asFile(path); // May be in a JAR that HDF5 can't read from
File modelFile = new File(testDir.getRoot(), f.getName());
FileUtils.copyFile(f, modelFile);
synchronized (Hdf5Archive.LOCK_OBJECT){
Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath());
List<String> rootGroups = hdf5Archive.getGroups();
if (rootGroups.contains("data")){
String inputName = hdf5Archive.readAttributeAsString("input_names", "data");
String outputName = hdf5Archive.readAttributeAsString("output_names", "data");
inputArray = hdf5Archive.readDataSet(inputName, "data");
expectedOutputArray = hdf5Archive.readDataSet(outputName, "data");
}
else{
hdf5Archive.close();
return;
}
hdf5Archive.close();
}
INDArray outputArray;
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
outputArray = dl4jModel.outputSingle(inputArray);
expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT);
outputArray = outputArray.castTo(DataType.FLOAT);
if (path.contains("misc_")){
//shape relaxation
expectedOutputArray = expectedOutputArray.reshape(-1);
outputArray = outputArray.reshape(-1);
}
System.out.println(outputArray.toString());
System.out.println(expectedOutputArray.toString());
Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape());
Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3));
}
private void testModelImportWithKeras(String path) throws Exception{
Model kerasModel = new Model(path);
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays());
Assert.assertEquals(kerasModel.numOutputs(), dl4jModel.getNumOutputArrays());
INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()];
INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()];
for (int i = 0; i < kerasInputArrays.length; i ++) {
long[] shape = kerasModel.inputShapeAt(i);
for (int j = 0; j < shape.length; j++) {
if (shape[j] < 0) {
shape[j] = 1;
}
}
kerasInputArrays[i] = Nd4j.rand(shape);
dl4jInputArrays[i] = kerasInputArrays[i]; // feed the same random inputs to both models
}
INDArray[] kerasOut = kerasModel.predict(kerasInputArrays);
INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays);
Assert.assertEquals(kerasOut.length, dl4jOut.length);
for (int i = 0; i < kerasOut.length; i++){
INDArray kerasOutArr = kerasOut[i];
kerasOutArr = kerasOutArr.reshape(1, -1); // bit of relaxation on shape
kerasOutArr = kerasOutArr.castTo(DataType.DOUBLE);
Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST);
INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1);
System.out.println(kerasOutArr.shapeInfoToString());
System.out.println(dl4jOutArr.shapeInfoToString());
Assert.assertEquals(kerasOutArr, dl4jOutArr);
}
}
}

Some files were not shown because too many files have changed in this diff