Fixing issues from Sonar report (#391)

* Fixing issues from Sonar report

* Proper logger of exceptions

* Coding style fixes

* Use dup parameter

* Cleanup, minor issues

* Cuda compilation fixed and some minor fixes
master
Alexander Stoyakin 2020-04-23 01:36:49 +03:00 committed by GitHub
parent 2a488efb1b
commit ccb216a3ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
178 changed files with 376 additions and 422 deletions

View File

@ -18,6 +18,7 @@ package org.datavec.api.records.reader.impl;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator; import org.apache.commons.io.LineIterator;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
@ -43,6 +44,7 @@ import java.util.*;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j
public class LineRecordReader extends BaseRecordReader { public class LineRecordReader extends BaseRecordReader {
@ -89,7 +91,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex); iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]); onLocationOpen(locations[splitIndex]);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
if (iter.hasNext()) { if (iter.hasNext()) {
@ -120,7 +122,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex); iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]); onLocationOpen(locations[splitIndex]);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return iter.hasNext(); return iter.hasNext();

View File

@ -16,6 +16,7 @@
package org.datavec.api.split; package org.datavec.api.split;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.util.files.UriFromPathIterator; import org.datavec.api.util.files.UriFromPathIterator;
import org.datavec.api.writable.WritableType; import org.datavec.api.writable.WritableType;
@ -34,6 +35,7 @@ import java.util.regex.Pattern;
* NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent * NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent
* the integer index. * the integer index.
*/ */
@Slf4j
public class NumberedFileInputSplit implements InputSplit { public class NumberedFileInputSplit implements InputSplit {
private final String baseString; private final String baseString;
private final int minIdx; private final int minIdx;
@ -93,7 +95,7 @@ public class NumberedFileInputSplit implements InputSplit {
try { try {
writeFile.createNewFile(); writeFile.createNewFile();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }

View File

@ -162,7 +162,6 @@ public class Text extends BinaryComparable implements WritableComparable<BinaryC
return -1; // not found return -1; // not found
} catch (CharacterCodingException e) { } catch (CharacterCodingException e) {
// can't get here // can't get here
e.printStackTrace();
return -1; return -1;
} }
} }

View File

@ -17,6 +17,7 @@
package org.datavec.arrow.recordreader; package org.datavec.arrow.recordreader;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record; import org.datavec.api.records.Record;
@ -50,6 +51,7 @@ import static org.datavec.arrow.ArrowConverter.readFromBytes;
* @author Adam Gibson * @author Adam Gibson
* *
*/ */
@Slf4j
public class ArrowRecordReader implements RecordReader { public class ArrowRecordReader implements RecordReader {
private InputSplit split; private InputSplit split;
@ -132,7 +134,7 @@ public class ArrowRecordReader implements RecordReader {
currIdx++; currIdx++;
this.currentPath = url; this.currentPath = url;
}catch(Exception e) { }catch(Exception e) {
e.printStackTrace(); log.error("",e);
} }
} }
@ -242,7 +244,7 @@ public class ArrowRecordReader implements RecordReader {
try { try {
currentBatch.close(); currentBatch.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
} }

View File

@ -16,6 +16,7 @@
package org.datavec.arrow; package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.memory.RootAllocator;
@ -56,7 +57,7 @@ import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@Slf4j
public class ArrowConverterTest extends BaseND4JTest { public class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@ -343,7 +344,7 @@ public class ArrowConverterTest extends BaseND4JTest {
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) { try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch(); arrowFileWriter.writeBatch();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
byte[] arr = byteArrayOutputStream.toByteArray(); byte[] arr = byteArrayOutputStream.toByteArray();

View File

@ -61,7 +61,7 @@ public class Wave implements Serializable {
initWaveWithInputStream(inputStream); initWaveWithInputStream(inputStream);
inputStream.close(); inputStream.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); System.out.println(e.toString());
} }
} }
@ -96,7 +96,7 @@ public class Wave implements Serializable {
data = new byte[inputStream.available()]; data = new byte[inputStream.available()];
inputStream.read(data); inputStream.read(data);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); System.err.println(e.toString());
} }
// end load data // end load data
} else { } else {

View File

@ -16,10 +16,12 @@
package org.datavec.audio; package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
@Slf4j
public class WaveFileManager { public class WaveFileManager {
private Wave wave; private Wave wave;
@ -78,7 +80,7 @@ public class WaveFileManager {
fos.write(wave.getBytes()); fos.write(wave.getBytes());
fos.close(); fos.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }

View File

@ -16,6 +16,8 @@
package org.datavec.audio; package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -25,6 +27,7 @@ import java.io.InputStream;
* *
* @author Jacquet Wong * @author Jacquet Wong
*/ */
@Slf4j
public class WaveHeader { public class WaveHeader {
public static final String RIFF_HEADER = "RIFF"; public static final String RIFF_HEADER = "RIFF";
@ -109,7 +112,7 @@ public class WaveHeader {
// dis.close(); // dis.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
return false; return false;
} }

View File

@ -17,6 +17,7 @@
package org.datavec.audio.fingerprint; package org.datavec.audio.fingerprint;
import lombok.extern.slf4j.Slf4j;
import org.datavec.audio.Wave; import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader; import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler; import org.datavec.audio.dsp.Resampler;
@ -38,6 +39,7 @@ import java.util.List;
* @author jacquet * @author jacquet
* *
*/ */
@Slf4j
public class FingerprintManager { public class FingerprintManager {
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
@ -153,7 +155,7 @@ public class FingerprintManager {
fingerprint = getFingerprintFromInputStream(fis); fingerprint = getFingerprintFromInputStream(fis);
fis.close(); fis.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return fingerprint; return fingerprint;
} }
@ -170,7 +172,7 @@ public class FingerprintManager {
fingerprint = new byte[inputStream.available()]; fingerprint = new byte[inputStream.available()];
inputStream.read(fingerprint); inputStream.read(fingerprint);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return fingerprint; return fingerprint;
} }
@ -190,7 +192,7 @@ public class FingerprintManager {
fileOutputStream.write(fingerprint); fileOutputStream.write(fingerprint);
fileOutputStream.close(); fileOutputStream.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }

View File

@ -186,7 +186,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
labels.add(line); labels.add(line);
} }
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
@ -300,7 +300,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
try { try {
FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd); FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
meanStdStored = true; meanStdStored = true;
} else if (uMean == 0 && meanStdStored) { } else if (uMean == 0 && meanStdStored) {
@ -312,7 +312,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
vStd = Double.parseDouble(values[3]); vStd = Double.parseDouble(values[3]);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
for (int i = 0; i < result.numExamples(); i++) { for (int i = 0; i < result.numExamples(); i++) {
@ -356,12 +356,12 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
dataSets.add(new DataSet(asMatrix(matConversion.getSecond()), matConversion.getFirst())); dataSets.add(new DataSet(asMatrix(matConversion.getSecond()), matConversion.getFirst()));
batchNumCount++; batchNumCount++;
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
break; break;
} }
} }
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
if(dataSets.size() == 0){ if(dataSets.size() == 0){

View File

@ -235,7 +235,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1]; InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data); recordReader.initialize(data);
} catch (IOException | InterruptedException e) { } catch (IOException | InterruptedException e) {
e.printStackTrace(); log.error("",e);
} }
return recordReader; return recordReader;
} }
@ -250,7 +250,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1]; InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data); recordReader.initialize(data);
} catch (IOException | InterruptedException e) { } catch (IOException | InterruptedException e) {
e.printStackTrace(); log.error("",e);
} }
return recordReader; return recordReader;
} }

View File

@ -225,7 +225,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
finishedInputStreamSplit = true; finishedInputStreamSplit = true;
return Arrays.<Writable>asList(ndArrayWritable); return Arrays.<Writable>asList(ndArrayWritable);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
if (iter != null) { if (iter != null) {

View File

@ -16,6 +16,7 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader; import org.bytedeco.javacpp.Loader;
@ -53,6 +54,7 @@ import static org.junit.Assert.fail;
* *
* @author saudet * @author saudet
*/ */
@Slf4j
public class TestNativeImageLoader { public class TestNativeImageLoader {
static final long seed = 10; static final long seed = 10;
static final Random rng = new Random(seed); static final Random rng = new Random(seed);
@ -123,7 +125,7 @@ public class TestNativeImageLoader {
try { try {
array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
assertEquals(5, array6.rank()); assertEquals(5, array6.rank());
@ -156,7 +158,7 @@ public class TestNativeImageLoader {
try { try {
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
assertEquals(5, array8.rank()); assertEquals(5, array8.rank());
assertEquals(pages2, array8.size(0)); assertEquals(pages2, array8.size(0));
@ -172,7 +174,7 @@ public class TestNativeImageLoader {
try { try {
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath()); array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath());
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
assertEquals(5, array9.rank()); assertEquals(5, array9.rank());

View File

@ -66,7 +66,7 @@ public class UimaTokenizer implements Tokenizer {
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -17,6 +17,7 @@
package org.datavec.local.transforms.functions; package org.datavec.local.transforms.functions;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.linalg.function.Function; import org.nd4j.linalg.function.Function;
@ -32,6 +33,7 @@ import java.util.List;
* sequence data into a {@code List<List<Writable>>} * sequence data into a {@code List<List<Writable>>}
* @author Alex Black * @author Alex Black
*/ */
@Slf4j
public class SequenceRecordReaderFunction public class SequenceRecordReaderFunction
implements Function<Pair<String, InputStream>, List<List<Writable>>> { implements Function<Pair<String, InputStream>, List<List<Writable>>> {
protected SequenceRecordReader sequenceRecordReader; protected SequenceRecordReader sequenceRecordReader;
@ -46,7 +48,7 @@ public class SequenceRecordReaderFunction
try (DataInputStream dis = (DataInputStream) value.getRight()) { try (DataInputStream dis = (DataInputStream) value.getRight()) {
return sequenceRecordReader.sequenceRecord(uri, dis); return sequenceRecordReader.sequenceRecord(uri, dis);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
throw new IllegalStateException("Something went wrong"); throw new IllegalStateException("Something went wrong");

View File

@ -74,7 +74,6 @@ public class DataVecTransformClient implements DataVecTransformService {
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in setCSVTransformProcess()", e); log.error("Error in setCSVTransformProcess()", e);
e.printStackTrace();
} }
} }
@ -94,7 +93,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return TransformProcess.fromJson(s); return TransformProcess.fromJson(s);
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in getCSVTransformProcess()",e); log.error("Error in getCSVTransformProcess()",e);
e.printStackTrace();
} }
return null; return null;
@ -119,7 +117,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord; return singleCsvRecord;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformIncremental(SingleCSVRecord)",e); log.error("Error in transformIncremental(SingleCSVRecord)",e);
e.printStackTrace();
} }
return null; return null;
} }
@ -140,8 +137,7 @@ public class DataVecTransformClient implements DataVecTransformService {
.getBody(); .getBody();
return batchCSVRecord1; return batchCSVRecord1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e); log.error("",e);
e.printStackTrace();
} }
return null; return null;
@ -162,7 +158,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1; return batchCSVRecord1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e); log.error("Error in transform(BatchCSVRecord)", e);
e.printStackTrace();
} }
return null; return null;
@ -181,7 +176,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1; return batchArray1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformArray(BatchCSVRecord)",e); log.error("Error in transformArray(BatchCSVRecord)",e);
e.printStackTrace();
} }
return null; return null;
@ -200,7 +194,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array; return array;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformArrayIncremental(SingleCSVRecord)",e); log.error("Error in transformArrayIncremental(SingleCSVRecord)",e);
e.printStackTrace();
} }
return null; return null;
@ -231,7 +224,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array; return array;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformSequenceArrayIncremental",e); log.error("Error in transformSequenceArrayIncremental",e);
e.printStackTrace();
} }
return null; return null;
@ -252,7 +244,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1; return batchArray1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformSequenceArray",e); log.error("Error in transformSequenceArray",e);
e.printStackTrace();
} }
return null; return null;
@ -274,7 +265,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1; return batchCSVRecord1;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformSequence"); log.error("Error in transformSequence");
e.printStackTrace();
} }
return null; return null;
@ -295,7 +285,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord; return singleCsvRecord;
} catch (UnirestException e) { } catch (UnirestException e) {
log.error("Error in transformSequenceIncremental"); log.error("Error in transformSequenceIncremental");
e.printStackTrace();
} }
return null; return null;
} }

View File

@ -18,6 +18,7 @@ package org.datavec.spark.transform;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.memory.RootAllocator;
@ -53,7 +54,7 @@ import static org.datavec.local.transforms.LocalTransformExecutor.executeToSeque
* @author Adan Gibson * @author Adan Gibson
*/ */
@AllArgsConstructor @AllArgsConstructor
@Slf4j
public class CSVSparkTransform { public class CSVSparkTransform {
@Getter @Getter
private TransformProcess transformProcess; private TransformProcess transformProcess;
@ -252,7 +253,7 @@ public class CSVSparkTransform {
try { try {
return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return null; return null;

View File

@ -88,7 +88,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized"); log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType); return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -100,7 +100,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized"); log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -112,7 +112,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -130,7 +130,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -142,7 +142,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });
@ -169,7 +169,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
return internalServerError(); return internalServerError();
} }
}); });

View File

@ -16,6 +16,7 @@
package org.datavec.spark; package org.datavec.spark;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf; import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.junit.After; import org.junit.After;
@ -23,6 +24,7 @@ import org.junit.Before;
import java.io.Serializable; import java.io.Serializable;
@Slf4j
public abstract class BaseSparkTest implements Serializable { public abstract class BaseSparkTest implements Serializable {
protected static JavaSparkContext sc; protected static JavaSparkContext sc;
@ -40,7 +42,7 @@ public abstract class BaseSparkTest implements Serializable {
try { try {
Thread.sleep(100L); Thread.sleep(100L);
} catch (InterruptedException e) { } catch (InterruptedException e) {
e.printStackTrace(); log.error("",e);
} }
} else { } else {
break; break;

View File

@ -43,7 +43,8 @@ import java.util.concurrent.atomic.AtomicLong;
*/ */
@Slf4j @Slf4j
public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable { public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable {
private static final String ROUTE_IS_DOWN = "Info posted to RemoteUIStatsStorageRouter but router is shut down.";
private static final String MAX_WARNINGS_REACHED = "RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.";
/** /**
* Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar * Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar
*/ */
@ -163,10 +164,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) { if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement(); long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) { if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down."); log.warn(ROUTE_IS_DOWN);
} }
if (count == MAX_SHUTDOWN_WARN_COUNT) { if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced."); log.warn(MAX_WARNINGS_REACHED);
} }
} else { } else {
for (StorageMetaData m : storageMetaData) { for (StorageMetaData m : storageMetaData) {
@ -186,10 +187,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) { if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement(); long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) { if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down."); log.warn(ROUTE_IS_DOWN);
} }
if (count == MAX_SHUTDOWN_WARN_COUNT) { if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced."); log.warn(MAX_WARNINGS_REACHED);
} }
} else { } else {
for (Persistable p : staticInfo) { for (Persistable p : staticInfo) {
@ -209,10 +210,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) { if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement(); long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) { if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down."); log.warn(ROUTE_IS_DOWN);
} }
if (count == MAX_SHUTDOWN_WARN_COUNT) { if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced."); log.warn(MAX_WARNINGS_REACHED);
} }
} else { } else {
for (Persistable p : updates) { for (Persistable p : updates) {

View File

@ -67,10 +67,7 @@ public class AsyncIterator<T extends Object> implements Iterator<T> {
nextElement = buffer.take(); nextElement = buffer.take();
// same on this run // same on this run
if (nextElement == terminator) return (nextElement != terminator);
return false;
return true;
} catch (Exception e) { } catch (Exception e) {
log.error("Premature end of loop!"); log.error("Premature end of loop!");
return false; return false;

View File

@ -43,7 +43,7 @@ public class SystemInfoPrintListener implements TrainingListener {
private boolean printOnBackwardPass; private boolean printOnBackwardPass;
private boolean printOnGradientCalculation; private boolean printOnGradientCalculation;
private static final String SYSTEM_INFO = "System info on epoch end: ";
@Override @Override
public void iterationDone(Model model, int iteration, int epoch) { public void iterationDone(Model model, int iteration, int epoch) {
@ -65,7 +65,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
@ -75,7 +75,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
@ -85,7 +85,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
@ -95,7 +95,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
@ -104,7 +104,7 @@ public class SystemInfoPrintListener implements TrainingListener {
if(!printOnBackwardPass) if(!printOnBackwardPass)
return; return;
SystemInfo systemInfo = new SystemInfo(); SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: "); log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON()); log.info(systemInfo.toPrettyJSON());
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.perf.listener; package org.deeplearning4j.perf.listener;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import oshi.json.SystemInfo; import oshi.json.SystemInfo;
@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j
public class SystemPolling { public class SystemPolling {
private ScheduledExecutorService scheduledExecutorService; private ScheduledExecutorService scheduledExecutorService;
@ -66,7 +68,7 @@ public class SystemPolling {
try { try {
objectMapper.writeValue(hardwareFile,hardwareMetric); objectMapper.writeValue(hardwareFile,hardwareMetric);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
},0,pollEveryMillis, TimeUnit.MILLISECONDS); },0,pollEveryMillis, TimeUnit.MILLISECONDS);

View File

@ -27,7 +27,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import java.io.*; import java.io.*;
import java.nio.file.Files;
import java.util.UUID; import java.util.UUID;
/** /**

View File

@ -493,7 +493,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
out.println(); out.println();
} }
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
return new Pair<>(dArr,temp); return new Pair<>(dArr,temp);

View File

@ -78,10 +78,10 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10); earlyEndIter.next(10);
assertEquals(earlyEndIter.hasNext(), false); assertEquals(false, earlyEndIter.hasNext());
earlyEndIter.reset(); earlyEndIter.reset();
assertEquals(earlyEndIter.hasNext(), true); assertEquals(true, earlyEndIter.hasNext());
} }

View File

@ -98,7 +98,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
while (multiIter.hasNext()) { while (multiIter.hasNext()) {
DataSet path = multiIter.next(10); DataSet path = multiIter.next(10);
assertNotNull(path); assertNotNull(path);
assertEquals(path.numExamples(), 10, 0.0); assertEquals(10, path.numExamples(), 0.0);
} }
assertEquals(epochs, multiIter.epochs); assertEquals(epochs, multiIter.epochs);

View File

@ -33,7 +33,7 @@ public class SamplingTest extends BaseDL4JTest {
DataSetIterator iter = new MnistDataSetIterator(10, 10); DataSetIterator iter = new MnistDataSetIterator(10, 10);
//batch size and total //batch size and total
DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10); DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
assertEquals(sampling.next().numExamples(), 10); assertEquals(10, sampling.next().numExamples());
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions; package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
@ -34,6 +35,7 @@ import static org.junit.Assert.fail;
/** /**
* A set of tests to ensure that useful exceptions are thrown on invalid network configurations * A set of tests to ensure that useful exceptions are thrown on invalid network configurations
*/ */
@Slf4j
public class TestInvalidConfigurations extends BaseDL4JTest { public class TestInvalidConfigurations extends BaseDL4JTest {
public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) { public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) {
@ -78,7 +80,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testDenseNin0(): " + e.getMessage()); System.out.println("testDenseNin0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -96,7 +98,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testDenseNout0(): " + e.getMessage()); System.out.println("testDenseNout0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -109,7 +111,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testOutputLayerNin0(): " + e.getMessage()); System.out.println("testOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -122,7 +124,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testRnnOutputLayerNin0(): " + e.getMessage()); System.out.println("testRnnOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -135,7 +137,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testLSTMNIn0(): " + e.getMessage()); System.out.println("testLSTMNIn0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -153,7 +155,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testLSTMNOut0(): " + e.getMessage()); System.out.println("testLSTMNOut0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -166,7 +168,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testConvolutionalNin0(): " + e.getMessage()); System.out.println("testConvolutionalNin0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -185,7 +187,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testConvolutionalNOut0(): " + e.getMessage()); System.out.println("testConvolutionalNOut0(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -216,7 +218,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage()); System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -245,7 +247,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage()); System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -277,7 +279,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage()); System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -318,7 +320,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage()); System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }
@ -347,7 +349,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage()); System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions; package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -36,6 +37,7 @@ import static org.junit.Assert.*;
/** /**
* A set of tests to ensure that useful exceptions are thrown on invalid input * A set of tests to ensure that useful exceptions are thrown on invalid input
*/ */
@Slf4j
public class TestInvalidInput extends BaseDL4JTest { public class TestInvalidInput extends BaseDL4JTest {
@Test @Test
@ -53,7 +55,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchDense(): " + e.getMessage()); System.out.println("testInputNinMismatchDense(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -73,7 +75,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage()); System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -94,7 +96,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function //From loss function
System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage()); System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -115,7 +117,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function //From loss function
System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage()); System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -142,7 +144,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage()); System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -169,7 +171,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinRank2Convolutional(): " + e.getMessage()); System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -195,7 +197,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinRank2Subsampling(): " + e.getMessage()); System.out.println("testInputNinRank2Subsampling(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -217,7 +219,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchLSTM(): " + e.getMessage()); System.out.println("testInputNinMismatchLSTM(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -238,7 +240,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage()); System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
@ -260,7 +262,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) { } catch (DL4JException e) {
System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage()); System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Expected DL4JException"); fail("Expected DL4JException");
} }
} }
@ -305,7 +307,7 @@ public class TestInvalidInput extends BaseDL4JTest {
net.rnnTimeStep(Nd4j.create(5, 5, 10)); net.rnnTimeStep(Nd4j.create(5, 5, 10));
fail("Expected Exception - " + layerType); fail("Expected Exception - " + layerType);
} catch (Exception e) { } catch (Exception e) {
// e.printStackTrace(); log.error("",e);
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch")); assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
} }

View File

@ -216,7 +216,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
failed.add(testName + "\t" + "EXCEPTION"); failed.add(testName + "\t" + "EXCEPTION");
continue; continue;
} }
@ -383,7 +383,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
failed.add(testName + "\t" + "EXCEPTION"); failed.add(testName + "\t" + "EXCEPTION");
continue; continue;
} }
@ -693,7 +693,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
failed.add(testName + "\t" + "EXCEPTION"); failed.add(testName + "\t" + "EXCEPTION");
continue; continue;
} }

View File

@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -47,6 +48,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphConfigurationTest extends BaseDL4JTest { public class ComputationGraphConfigurationTest extends BaseDL4JTest {
@Test @Test
@ -150,7 +152,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
// Use appendLayer on first layer // Use appendLayer on first layer
@ -162,7 +164,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test no network inputs //Test no network inputs
@ -174,7 +176,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test no network outputs //Test no network outputs
@ -185,7 +187,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test: invalid input //Test: invalid input
@ -197,7 +199,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test: graph with cycles //Test: graph with cycles
@ -215,7 +217,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
//Test: input != inputType count mismatch //Test: input != inputType count mismatch
@ -241,7 +243,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
//OK - exception is good //OK - exception is good
//e.printStackTrace(); log.info(e.toString());
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.conf; package org.deeplearning4j.nn.conf;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
@ -46,6 +47,7 @@ import static org.junit.Assert.*;
/** /**
* Created by agibsonccc on 11/27/14. * Created by agibsonccc on 11/27/14.
*/ */
@Slf4j
public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
@Rule @Rule
@ -272,9 +274,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK //OK
e.printStackTrace(); log.error("",e);
} catch (Throwable e) { } catch (Throwable e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception thrown for invalid config"); fail("Unexpected exception thrown for invalid config");
} }
@ -288,9 +290,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK //OK
e.printStackTrace(); log.info(e.toString());
} catch (Throwable e) { } catch (Throwable e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception thrown for invalid config"); fail("Unexpected exception thrown for invalid config");
} }
@ -304,9 +306,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration"); fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
//OK //OK
e.printStackTrace(); log.info(e.toString());
} catch (Throwable e) { } catch (Throwable e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception thrown for invalid config"); fail("Unexpected exception thrown for invalid config");
} }
} }

View File

@ -96,8 +96,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) { } else if (lc instanceof NonNegativeConstraint) {
assertTrue(RW0.minNumber().doubleValue() >= 0.0); assertTrue(RW0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(RW0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, RW0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(RW0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, RW0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -149,8 +149,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) { } else if (lc instanceof NonNegativeConstraint) {
assertTrue(b0.minNumber().doubleValue() >= 0.0); assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -201,8 +201,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) { } else if (lc instanceof NonNegativeConstraint) {
assertTrue(w0.minNumber().doubleValue() >= 0.0); assertTrue(w0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -259,10 +259,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0); assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0); assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -320,10 +320,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0); assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0); assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) { } else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -378,10 +378,10 @@ public class TestConstraints extends BaseDL4JTest {
} else if(lc instanceof NonNegativeConstraint ){ } else if(lc instanceof NonNegativeConstraint ){
assertTrue(w0.minNumber().doubleValue() >= 0.0 ); assertTrue(w0.minNumber().doubleValue() >= 0.0 );
} else if(lc instanceof UnitNormConstraint ){ } else if(lc instanceof UnitNormConstraint ){
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 ); assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 ); assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6 );
assertEquals(w1.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 ); assertEquals(1.0, w1.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(w1.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 ); assertEquals(1.0, w1.norm2(1).maxNumber().doubleValue(), 1e-6 );
} }
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);

View File

@ -156,7 +156,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
checkSerialization(glstm); checkSerialization(glstm);
assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0); assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0);
assertEquals(glstm.nIn, numIn); assertEquals(glstm.nIn, numIn);
assertEquals(glstm.nOut, numOut); assertEquals(glstm.nOut, numOut);
assertTrue(glstm.getActivationFn() instanceof ActivationTanH); assertTrue(glstm.getActivationFn() instanceof ActivationTanH);

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.graph; package org.deeplearning4j.nn.graph;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
@ -54,6 +55,7 @@ import java.util.Map;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphTestRNN extends BaseDL4JTest { public class ComputationGraphTestRNN extends BaseDL4JTest {
@Test @Test
@ -618,7 +620,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
.build(); .build();
fail("Exception expected"); fail("Exception expected");
} catch (IllegalStateException e){ } catch (IllegalStateException e){
// e.printStackTrace(); log.error("",e);
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig")); assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
} }
} }

View File

@ -1394,7 +1394,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
} catch (Exception e) { } catch (Exception e) {
//e.printStackTrace(); log.error("",e);
if(allowDisconnected){ if(allowDisconnected){
fail("No exception expected"); fail("No exception expected");
} else { } else {

View File

@ -416,8 +416,8 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(), INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j)); NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) { for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0); assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0); assertEquals(0.0, outRow2.getDouble(k), 0.0);
} }
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.convolution; package org.deeplearning4j.nn.layers.convolution;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@ -45,6 +46,7 @@ import static org.junit.Assert.*;
/** /**
* Created by Alex on 15/11/2016. * Created by Alex on 15/11/2016.
*/ */
@Slf4j
public class TestConvolutionModes extends BaseDL4JTest { public class TestConvolutionModes extends BaseDL4JTest {
@Test @Test
@ -106,12 +108,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
} }
} catch (DL4JException e) { } catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) { if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception"); fail("Unexpected exception");
} }
continue; //Expected exception continue; //Expected exception
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception"); fail("Unexpected exception");
} }
@ -184,12 +186,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
} }
} catch (DL4JException e) { } catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) { if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception"); fail("Unexpected exception");
} }
continue; //Expected exception continue; //Expected exception
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail("Unexpected exception"); fail("Unexpected exception");
} }

View File

@ -79,13 +79,13 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()); INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()); INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
assertEquals(firstExampleOutput.getDouble(0), 0.0, 1e-6); assertEquals(0.0, firstExampleOutput.getDouble(0), 1e-6);
assertEquals(firstExampleOutput.getDouble(1), 1.0, 1e-6); assertEquals(1.0, firstExampleOutput.getDouble(1), 1e-6);
assertEquals(firstExampleOutput.getDouble(2), 2.0, 1e-6); assertEquals(2.0, firstExampleOutput.getDouble(2), 1e-6);
assertEquals(secondExampleOutput.getDouble(0), 0.0, 1e-6); assertEquals(0.0, secondExampleOutput.getDouble(0), 1e-6);
assertEquals(secondExampleOutput.getDouble(1), 0.0, 1e-6); assertEquals(0.0, secondExampleOutput.getDouble(1), 1e-6);
assertEquals(secondExampleOutput.getDouble(2), 1.0, 1e-6); assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
} }

View File

@ -771,7 +771,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
.build(); .build();
fail("Exception expected"); fail("Exception expected");
} catch (IllegalStateException e){ } catch (IllegalStateException e){
// e.printStackTrace(); log.info(e.toString());
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig")); assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
} }
} }

View File

@ -408,8 +408,8 @@ public class TestVariableLengthTS extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(), INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j)); NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) { for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0); assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0); assertEquals(0.0, outRow2.getDouble(k), 0.0);
} }
} }
} }

View File

@ -336,7 +336,7 @@ public class TestUpdaters extends BaseDL4JTest {
actualM[i] = Math.round(actualM[i] * 1e2) / 1e2; actualM[i] = Math.round(actualM[i] * 1e2) / 1e2;
} }
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(actualM, expectedM), true); assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(expectedM, actualM), true);
} }

View File

@ -159,14 +159,14 @@ public class RegressionTest050 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer(); SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride()); assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding()); assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType()); assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer(); OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString()); assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,14 +162,14 @@ public class RegressionTest060 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer(); SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride()); assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding()); assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType()); assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer(); OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString()); assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,7 +162,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same); assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer(); SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());

View File

@ -176,14 +176,14 @@ public class RegressionTest080 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride()); assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding()); assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same); assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer(); SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize()); assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride()); assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding()); assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType()); assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Same); assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer(); OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertTrue(l2.getActivationFn() instanceof ActivationSigmoid); assertTrue(l2.getActivationFn() instanceof ActivationSigmoid);

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.datasets.fetchers; package org.deeplearning4j.datasets.fetchers;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.base.EmnistFetcher; import org.deeplearning4j.base.EmnistFetcher;
@ -36,6 +37,7 @@ import java.util.Random;
* @author Alex Black * @author Alex Black
* *
*/ */
@Slf4j
public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher { public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher {
protected EmnistFetcher fetcher; protected EmnistFetcher fetcher;
@ -64,7 +66,7 @@ public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetche
try { try {
man = new MnistManager(images, labels, totalExamples); man = new MnistManager(images, labels, totalExamples);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
FileUtils.deleteDirectory(new File(EMNIST_ROOT)); FileUtils.deleteDirectory(new File(EMNIST_ROOT));
new EmnistFetcher(dataSet).downloadAndUntar(); new EmnistFetcher(dataSet).downloadAndUntar();
man = new MnistManager(images, labels, totalExamples); man = new MnistManager(images, labels, totalExamples);

View File

@ -687,9 +687,7 @@ public class BarnesHutTsne implements Model {
* @throws IOException * @throws IOException
*/ */
public void saveAsFile(List<String> labels, String path) throws IOException { public void saveAsFile(List<String> labels, String path) throws IOException {
BufferedWriter write = null; try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
try {
write = new BufferedWriter(new FileWriter(new File(path)));
for (int i = 0; i < Y.rows(); i++) { for (int i = 0; i < Y.rows(); i++) {
if (i >= labels.size()) if (i >= labels.size())
break; break;
@ -711,17 +709,11 @@ public class BarnesHutTsne implements Model {
} }
write.flush(); write.flush();
write.close();
} finally {
if (write != null)
write.close();
} }
} }
public void saveAsFile(String path) throws IOException { public void saveAsFile(String path) throws IOException {
BufferedWriter write = null; try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
try {
write = new BufferedWriter(new FileWriter(new File(path)));
for (int i = 0; i < Y.rows(); i++) { for (int i = 0; i < Y.rows(); i++) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
INDArray wordVector = Y.getRow(i); INDArray wordVector = Y.getRow(i);
@ -734,10 +726,6 @@ public class BarnesHutTsne implements Model {
write.write(sb.toString()); write.write(sb.toString());
} }
write.flush(); write.flush();
write.close();
} finally {
if (write != null)
write.close();
} }
} }
/** /**

View File

@ -60,7 +60,7 @@ public class Hdf5Archive implements Closeable {
/* This is necessary for the call to the BytePointer constructor below. */ /* This is necessary for the call to the BytePointer constructor below. */
Loader.load(org.bytedeco.hdf5.global.hdf5.class); Loader.load(org.bytedeco.hdf5.global.hdf5.class);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
} }
} }

View File

@ -69,7 +69,7 @@ public class KerasModelImportTest extends BaseDL4JTest {
network = KerasModelImport.importKerasSequentialModelAndWeights(Resources.asFile(modelJsonFilename).getAbsolutePath(), network = KerasModelImport.importKerasSequentialModelAndWeights(Resources.asFile(modelJsonFilename).getAbsolutePath(),
Resources.asFile(modelWeightFilename).getAbsolutePath(), false); Resources.asFile(modelWeightFilename).getAbsolutePath(), false);
} catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) { } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) {
e.printStackTrace(); log.error("",e);
} }
return network; return network;
@ -80,7 +80,7 @@ public class KerasModelImportTest extends BaseDL4JTest {
try { try {
model = KerasModelImport.importKerasSequentialModelAndWeights(Resources.asFile(modelFilename).getAbsolutePath()); model = KerasModelImport.importKerasSequentialModelAndWeights(Resources.asFile(modelFilename).getAbsolutePath());
} catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) { } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) {
e.printStackTrace(); log.error("",e);
} }
return model; return model;

View File

@ -210,7 +210,6 @@ public class NearestNeighborsServer extends AbstractVerticle {
return; return;
} catch (Throwable e) { } catch (Throwable e) {
log.error("Error in POST /knn",e); log.error("Error in POST /knn",e);
e.printStackTrace();
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage()); .end("Error parsing request - " + e.getMessage());
return; return;
@ -265,7 +264,6 @@ public class NearestNeighborsServer extends AbstractVerticle {
.end(j); .end(j);
} catch (Throwable e) { } catch (Throwable e) {
log.error("Error in POST /knnnew",e); log.error("Error in POST /knnnew",e);
e.printStackTrace();
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage()); .end("Error parsing request - " + e.getMessage());
return; return;

View File

@ -75,7 +75,6 @@ public class StopRecognition implements Recognition {
try { try {
regexList.add(Pattern.compile(regex)); regexList.add(Pattern.compile(regex));
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace();
LOG.error("regex err : " + regex, e); LOG.error("regex err : " + regex, e);
} }
} }

View File

@ -20,10 +20,12 @@ import com.atilika.kuromoji.dict.CharacterDefinitions;
import com.atilika.kuromoji.dict.ConnectionCosts; import com.atilika.kuromoji.dict.ConnectionCosts;
import com.atilika.kuromoji.dict.UnknownDictionary; import com.atilika.kuromoji.dict.UnknownDictionary;
import com.atilika.kuromoji.trie.DoubleArrayTrie; import com.atilika.kuromoji.trie.DoubleArrayTrie;
import lombok.extern.slf4j.Slf4j;
import java.io.*; import java.io.*;
import java.util.List; import java.util.List;
@Slf4j
public abstract class DictionaryCompilerBase { public abstract class DictionaryCompilerBase {
public void build(String inputDirname, String outputDirname, String encoding, boolean compactTries) public void build(String inputDirname, String outputDirname, String encoding, boolean compactTries)
@ -66,7 +68,7 @@ public abstract class DictionaryCompilerBase {
} }
} }
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
} }
ProgressLog.end(); ProgressLog.end();

View File

@ -33,7 +33,6 @@ public class FileResourceResolver implements ResourceResolver {
try { try {
KuromojiBinFilesFetcher.downloadAndUntar(); KuromojiBinFilesFetcher.downloadAndUntar();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace();
log.error("IOException : ", e); log.error("IOException : ", e);
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.text.corpora.sentiwordnet; package org.deeplearning4j.text.corpora.sentiwordnet;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.shade.guava.collect.Sets; import org.nd4j.shade.guava.collect.Sets;
import org.apache.uima.analysis_engine.AnalysisEngine; import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.cas.CAS; import org.apache.uima.cas.CAS;
@ -37,6 +38,7 @@ import java.util.*;
* @author Adam Gibson * @author Adam Gibson
* *
*/ */
@Slf4j
public class SWN3 implements Serializable { public class SWN3 implements Serializable {
/** /**
* *
@ -120,7 +122,7 @@ public class SWN3 implements Serializable {
try { try {
csv.close(); csv.close();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.text.tokenization.tokenizer; package org.deeplearning4j.text.tokenization.tokenizer;
import lombok.extern.slf4j.Slf4j;
import org.apache.uima.cas.CAS; import org.apache.uima.cas.CAS;
import org.apache.uima.fit.util.JCasUtil; import org.apache.uima.fit.util.JCasUtil;
import org.cleartk.token.type.Token; import org.cleartk.token.type.Token;
@ -32,6 +33,7 @@ import java.util.List;
* @author Adam Gibson * @author Adam Gibson
* *
*/ */
@Slf4j
public class UimaTokenizer implements Tokenizer { public class UimaTokenizer implements Tokenizer {
private List<String> tokens; private List<String> tokens;
@ -66,7 +68,7 @@ public class UimaTokenizer implements Tokenizer {
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.models.word2vec; package org.deeplearning4j.models.word2vec;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator; import org.apache.commons.io.LineIterator;
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator; import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
@ -64,6 +65,7 @@ import static org.junit.Assert.*;
/** /**
* @author jeffreytang * @author jeffreytang
*/ */
@Slf4j
public class Word2VecTests extends BaseDL4JTest { public class Word2VecTests extends BaseDL4JTest {
private static final Logger log = LoggerFactory.getLogger(Word2VecTests.class); private static final Logger log = LoggerFactory.getLogger(Word2VecTests.class);
@ -621,7 +623,7 @@ public class Word2VecTests extends BaseDL4JTest {
unserialized = Word2Vec.fromJson(json); unserialized = Word2Vec.fromJson(json);
} }
catch (Exception e) { catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.models.sequencevectors.listeners; package org.deeplearning4j.models.sequencevectors.listeners;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.sequencevectors.SequenceVectors; import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.enums.ListenerEvent; import org.deeplearning4j.models.sequencevectors.enums.ListenerEvent;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener; import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
@ -34,6 +35,7 @@ import java.util.concurrent.Semaphore;
* *
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
@Slf4j
public class SerializingListener<T extends SequenceElement> implements VectorsListener<T> { public class SerializingListener<T extends SequenceElement> implements VectorsListener<T> {
private File targetFolder = new File("./"); private File targetFolder = new File("./");
private String modelPrefix = "Model_"; private String modelPrefix = "Model_";
@ -96,7 +98,7 @@ public class SerializingListener<T extends SequenceElement> implements VectorsLi
} }
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
} finally { } finally {
locker.release(); locker.release();
} }

View File

@ -147,7 +147,7 @@ public class ParallelTransformerIterator extends BasicTransformerIterator {
try { try {
buffer.put(futureSequence); buffer.put(futureSequence);
} catch (InterruptedException e) { } catch (InterruptedException e) {
e.printStackTrace(); log.error("",e);
} }
} }
/* else /* else

View File

@ -21,6 +21,7 @@ import lombok.NonNull;
import org.deeplearning4j.parallelism.AsyncIterator; import org.deeplearning4j.parallelism.AsyncIterator;
import java.util.Iterator; import java.util.Iterator;
import java.util.NoSuchElementException;
/** /**
* @author raver119@gmail.com * @author raver119@gmail.com
@ -77,7 +78,7 @@ public class AsyncLabelAwareIterator implements LabelAwareIterator, Iterator<Lab
} }
@Override @Override
public LabelledDocument next() { public LabelledDocument next() throws NoSuchElementException {
return nextDocument(); return nextDocument();
} }
} }

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.text.sentenceiterator; package org.deeplearning4j.text.sentenceiterator;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import java.io.*; import java.io.*;
import java.util.Iterator; import java.util.Iterator;
@ -29,6 +30,7 @@ import java.util.Iterator;
* *
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
@Slf4j
public class BasicLineIterator implements SentenceIterator, Iterable<String> { public class BasicLineIterator implements SentenceIterator, Iterable<String> {
private BufferedReader reader; private BufferedReader reader;
@ -113,7 +115,7 @@ public class BasicLineIterator implements SentenceIterator, Iterable<String> {
reader.close(); reader.close();
} catch (Exception e) { } catch (Exception e) {
// do nothing here // do nothing here
e.printStackTrace(); log.error("",e);
} }
super.finalize(); super.finalize();
} }

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.text.sentenceiterator; package org.deeplearning4j.text.sentenceiterator;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.text.documentiterator.DocumentIterator; import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -35,6 +36,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
* *
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
@Slf4j
public class StreamLineIterator implements SentenceIterator { public class StreamLineIterator implements SentenceIterator {
private DocumentIterator iterator; private DocumentIterator iterator;
private int linesToFetch; private int linesToFetch;
@ -64,7 +66,7 @@ public class StreamLineIterator implements SentenceIterator {
currentReader = null; currentReader = null;
} }
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
@ -145,7 +147,7 @@ public class StreamLineIterator implements SentenceIterator {
try { try {
this.onlyStream.reset(); this.onlyStream.reset();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }

View File

@ -183,7 +183,7 @@ public class FastTextTest extends BaseDL4JTest {
fastText.loadIterator(); fastText.loadIterator();
} catch (IOException e) { } catch (IOException e) {
log.error(e.toString()); log.error("",e);
} }
} }

View File

@ -1164,7 +1164,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
unserialized = ParagraphVectors.fromJson(json); unserialized = ParagraphVectors.fromJson(json);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.models.sequencevectors.serialization; package org.deeplearning4j.models.sequencevectors.serialization;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.StringUtils;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -48,6 +49,7 @@ import java.util.Collections;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@Slf4j
public class WordVectorSerializerTest extends BaseDL4JTest { public class WordVectorSerializerTest extends BaseDL4JTest {
private AbstractCache<VocabWord> cache; private AbstractCache<VocabWord> cache;
@ -97,7 +99,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
byte[] bytesResult = baos.toByteArray(); byte[] bytesResult = baos.toByteArray();
deser = WordVectorSerializer.readSequenceVectors(new ByteArrayInputStream(bytesResult), true); deser = WordVectorSerializer.readSequenceVectors(new ByteArrayInputStream(bytesResult), true);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
@ -175,7 +177,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
byte[] bytesResult = baos.toByteArray(); byte[] bytesResult = baos.toByteArray();
deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true); deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
@ -223,7 +225,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
byte[] bytesResult = baos.toByteArray(); byte[] bytesResult = baos.toByteArray();
deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true); deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
@ -268,7 +270,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
deser = WordVectorSerializer.readLookupTable(file); deser = WordVectorSerializer.readLookupTable(file);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
assertEquals(lookupTable.getVocab().totalWordOccurrences(), ((InMemoryLookupTable<VocabWord>)deser).getVocab().totalWordOccurrences()); assertEquals(lookupTable.getVocab().totalWordOccurrences(), ((InMemoryLookupTable<VocabWord>)deser).getVocab().totalWordOccurrences());
@ -306,7 +308,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data")); deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data"));
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }

View File

@ -140,7 +140,7 @@ public class AbstractCacheTest extends BaseDL4JTest {
unserialized = AbstractCache.fromJson(json); unserialized = AbstractCache.fromJson(json);
} }
catch (Exception e) { catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
assertEquals(cache.totalWordOccurrences(),unserialized.totalWordOccurrences()); assertEquals(cache.totalWordOccurrences(),unserialized.totalWordOccurrences());
@ -175,7 +175,7 @@ public class AbstractCacheTest extends BaseDL4JTest {
unserialized = AbstractCache.fromJson(json); unserialized = AbstractCache.fromJson(json);
} }
catch (Exception e) { catch (Exception e) {
e.printStackTrace(); log.error("",e);
fail(); fail();
} }
assertEquals(cache.totalWordOccurrences(),unserialized.totalWordOccurrences()); assertEquals(cache.totalWordOccurrences(),unserialized.totalWordOccurrences());

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers; package org.deeplearning4j.nn.layers;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -43,6 +44,7 @@ import java.util.*;
* A layer with parameters * A layer with parameters
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j
public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseLayer> public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseLayer>
extends AbstractLayer<LayerConfT> { extends AbstractLayer<LayerConfT> {
@ -371,7 +373,7 @@ public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.la
} }
layer.setParamTable(linkedTable); layer.setParamTable(linkedTable);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
} }
return layer; return layer;

View File

@ -1,5 +1,6 @@
package org.deeplearning4j.remote; package org.deeplearning4j.remote;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.datavec.image.loader.Java2DNativeImageLoader; import org.datavec.image.loader.Java2DNativeImageLoader;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -32,6 +33,7 @@ import java.util.concurrent.TimeUnit;
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL; import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@Slf4j
public class BinaryModelServerTest extends BaseDL4JTest { public class BinaryModelServerTest extends BaseDL4JTest {
private final int PORT = 18080; private final int PORT = 18080;
@ -120,7 +122,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
assertEquals(new Integer(1), result); assertEquals(new Integer(1), result);
} catch (Exception e){ } catch (Exception e){
e.printStackTrace(); log.error("",e);
throw e; throw e;
} finally { } finally {
server.stop(); server.stop();
@ -189,7 +191,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
assertEquals(new Integer(1), results[2].get()); assertEquals(new Integer(1), results[2].get());
} catch (Exception e){ } catch (Exception e){
e.printStackTrace(); log.error("",e);
throw e; throw e;
} finally { } finally {
server.stop(); server.stop();
@ -244,7 +246,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
assertEquals(28, result.getWidth()); assertEquals(28, result.getWidth());
} catch (Exception e){ } catch (Exception e){
e.printStackTrace(); log.error("",e);
throw e; throw e;
} finally { } finally {
server.stop(); server.stop();

View File

@ -585,7 +585,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
assertEquals(exp.argMax().getInt(0), out); assertEquals(exp.argMax().getInt(0), out);
} }
} catch (Exception e){ } catch (Exception e){
e.printStackTrace(); log.error("",e);
throw e; throw e;
} finally { } finally {
server.stop(); server.stop();
@ -640,7 +640,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
server.start(); server.start();
//client.predict(new float[]{0.0f, 1.0f, 2.0f}); //client.predict(new float[]{0.0f, 1.0f, 2.0f});
} catch (Exception e){ } catch (Exception e){
e.printStackTrace(); log.error("",e);
throw e; throw e;
} finally { } finally {
server.stop(); server.stop();
@ -700,7 +700,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
assertNotNull(result); assertNotNull(result);
} catch (Exception e){ } catch (Exception e){
e.printStackTrace(); log.error("",e);
throw e; throw e;
} finally { } finally {
server.stop(); server.stop();

View File

@ -660,7 +660,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
//OK //OK
System.out.println("Expected exception: " + e.getMessage()); System.out.println("Expected exception: " + e.getMessage());
} catch (Exception e){ } catch (Exception e){
e.printStackTrace(); log.error("",e);
fail("Expected other exception type"); fail("Expected other exception type");
} }
@ -903,7 +903,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
int idx = t.getRight(); int idx = t.getRight();
act[idx] = inf.output(t.getFirst(), t.getSecond()); act[idx] = inf.output(t.getFirst(), t.getSecond());
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
failedCount.incrementAndGet(); failedCount.incrementAndGet();
} }
} }
@ -955,7 +955,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
act[j] = inf.output(in.get(j), inMask); act[j] = inf.output(in.get(j), inMask);
counter.incrementAndGet(); counter.incrementAndGet();
} catch (Exception e){ } catch (Exception e){
e.printStackTrace(); log.error("",e);
failedCount.incrementAndGet(); failedCount.incrementAndGet();
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.spark.text.functions; package org.deeplearning4j.spark.text.functions;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory; import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory;
@ -29,6 +30,7 @@ import java.util.List;
* @author Adam Gibson * @author Adam Gibson
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Slf4j
public class TokenizerFunction implements Function<String, List<String>> { public class TokenizerFunction implements Function<String, List<String>> {
private String tokenizerFactoryClazz; private String tokenizerFactoryClazz;
private String tokenizerPreprocessorClazz; private String tokenizerPreprocessorClazz;
@ -69,7 +71,7 @@ public class TokenizerFunction implements Function<String, List<String>> {
tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams); tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams);
} }
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
} }
return tokenizerFactory; return tokenizerFactory;
} }

View File

@ -120,7 +120,7 @@ public class UpdatesConsumer implements UpdatesHandler {
//log.info("Putting update to the queue, current size: [{}]", updatesBuffer.size()); //log.info("Putting update to the queue, current size: [{}]", updatesBuffer.size());
updatesBuffer.put(array); updatesBuffer.put(array);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} else if (params != null && stepFunction != null) { } else if (params != null && stepFunction != null) {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.spark.datavec; package org.deeplearning4j.spark.datavec;
import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Text;
import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.function.PairFunction;
@ -35,6 +36,7 @@ import java.util.List;
/** /**
*/ */
@Slf4j
public class DataVecByteDataSetFunction implements PairFunction<Tuple2<Text, BytesWritable>, Double, DataSet> { public class DataVecByteDataSetFunction implements PairFunction<Tuple2<Text, BytesWritable>, Double, DataSet> {
private int labelIndex = 0; private int labelIndex = 0;
@ -104,7 +106,7 @@ public class DataVecByteDataSetFunction implements PairFunction<Tuple2<Text, Byt
featureVector = Nd4j.create(lenFeatureVector); featureVector = Nd4j.create(lenFeatureVector);
} }
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
List<INDArray> inputs = new ArrayList<>(); List<INDArray> inputs = new ArrayList<>();

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.spark.datavec; package org.deeplearning4j.spark.datavec;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function;
import org.datavec.api.io.WritableConverter; import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.WritableConverterException; import org.datavec.api.io.converters.WritableConverterException;
@ -36,6 +37,7 @@ import java.util.List;
* Analogous to {@link RecordReaderDataSetIterator}, but in the context of Spark. * Analogous to {@link RecordReaderDataSetIterator}, but in the context of Spark.
* @author Alex Black * @author Alex Black
*/ */
@Slf4j
public class DataVecDataSetFunction implements Function<List<Writable>, DataSet>, Serializable { public class DataVecDataSetFunction implements Function<List<Writable>, DataSet>, Serializable {
private final int labelIndex; private final int labelIndex;
@ -129,7 +131,8 @@ public class DataVecDataSetFunction implements Function<List<Writable>, DataSet>
try { try {
current = converter.convert(current); current = converter.convert(current);
} catch (WritableConverterException e) { } catch (WritableConverterException e) {
e.printStackTrace();
log.error("",e);
} }
} }
if (regression) { if (regression) {

View File

@ -33,7 +33,7 @@ public class BalancedPartitionerTest {
// the 10 first elements should go in the 1st partition // the 10 first elements should go in the 1st partition
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
int p = bp.getPartition(i); int p = bp.getPartition(i);
assertEquals("Found wrong partition output " + p + ", not 0", p, 0); assertEquals("Found wrong partition output " + p + ", not 0", 0, p);
} }
} }
@ -43,7 +43,7 @@ public class BalancedPartitionerTest {
// the 10 first elements should go in the 1st partition // the 10 first elements should go in the 1st partition
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
int p = bp.getPartition(i); int p = bp.getPartition(i);
assertEquals("Found wrong partition output " + p + ", not 0", p, 0); assertEquals("Found wrong partition output " + p + ", not 0", 0, p);
} }
} }
@ -56,7 +56,7 @@ public class BalancedPartitionerTest {
countPerPartition[p] += 1; countPerPartition[p] += 1;
} }
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
assertEquals(countPerPartition[i], 10); assertEquals(10, countPerPartition[i]);
} }
} }
@ -70,9 +70,9 @@ public class BalancedPartitionerTest {
} }
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
if (i < 7) if (i < 7)
assertEquals(countPerPartition[i], 10 + 1); assertEquals(10 + 1, countPerPartition[i]);
else else
assertEquals(countPerPartition[i], 10); assertEquals(10, countPerPartition[i]);
} }
} }

View File

@ -385,7 +385,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
gpuMaxBytes[i] = nativeOps.getDeviceTotalMemory(0); gpuMaxBytes[i] = nativeOps.getDeviceTotalMemory(0);
gpuCurrentBytes[i] = gpuMaxBytes[i] - nativeOps.getDeviceFreeMemory(0); gpuCurrentBytes[i] = gpuMaxBytes[i] - nativeOps.getDeviceFreeMemory(0);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
} }
} }
} }

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.ui.weights; package org.deeplearning4j.ui.weights;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.datavec.image.loader.ImageLoader; import org.datavec.image.loader.ImageLoader;
import org.deeplearning4j.api.storage.Persistable; import org.deeplearning4j.api.storage.Persistable;
@ -49,6 +50,7 @@ import java.util.List;
/** /**
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
@Slf4j
public class ConvolutionalIterationListener extends BaseTrainingListener { public class ConvolutionalIterationListener extends BaseTrainingListener {
private enum Orientation { private enum Orientation {
@ -661,7 +663,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener {
try { try {
ImageIO.write(renderImageGrayscale(array), "png", file); ImageIO.write(renderImageGrayscale(array), "png", file);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }
@ -670,7 +672,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener {
try { try {
ImageIO.write(image, "png", file); ImageIO.write(image, "png", file);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.ui; package org.deeplearning4j.ui;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.datavec.image.loader.LFWLoader; import org.datavec.image.loader.LFWLoader;
import org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator;
@ -82,6 +83,7 @@ import static org.junit.Assert.fail;
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
@Ignore @Ignore
@Slf4j
public class ManualTests { public class ManualTests {
private static Logger log = LoggerFactory.getLogger(ManualTests.class); private static Logger log = LoggerFactory.getLogger(ManualTests.class);
@ -258,7 +260,7 @@ public class ManualTests {
try { try {
ImageIO.write(imageToRender, "png", file); ImageIO.write(imageToRender, "png", file);
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); log.error("",e);
} }
} }

View File

@ -926,7 +926,7 @@ public class TrainModule implements UIModule {
NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class); NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
return new Triple<>(null, null, layer); return new Triple<>(null, null, layer);
} catch (Exception e) { } catch (Exception e) {
e.printStackTrace(); log.error("",e);
} }
} }
return null; return null;

View File

@ -18,6 +18,7 @@
package org.deeplearning4j.ui; package org.deeplearning4j.ui;
import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpResponseStatus;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
@ -53,6 +54,7 @@ import static org.junit.Assert.*;
* @author Tamas Fenyvesi * @author Tamas Fenyvesi
*/ */
@Ignore @Ignore
@Slf4j
public class TestVertxUIMultiSession extends BaseDL4JTest { public class TestVertxUIMultiSession extends BaseDL4JTest {
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
@ -121,7 +123,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(ss)); assertTrue(uIServer.isAttached(ss));
} catch (InterruptedException | IOException e) { } catch (InterruptedException | IOException e) {
e.printStackTrace(); log.error("",e);
fail(e.getMessage()); fail(e.getMessage());
} finally { } finally {
uIServer.detach(ss); uIServer.detach(ss);
@ -294,7 +296,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
" after " + autoDetachTimeoutMillis + " ms "); " after " + autoDetachTimeoutMillis + " ms ");
Thread.sleep(autoDetachTimeoutMillis); Thread.sleep(autoDetachTimeoutMillis);
} catch (InterruptedException e) { } catch (InterruptedException e) {
e.printStackTrace(); log.error("",e);
} finally { } finally {
System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " + System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " +
autoDetachTimeoutMillis + " ms."); autoDetachTimeoutMillis + " ms.");

View File

@ -21,7 +21,6 @@ import lombok.Builder;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;

View File

@ -1045,7 +1045,7 @@ public class IntegrationTestRunner {
act[j] = inf.output(in.get(j).getFirst(), inMask); act[j] = inf.output(in.get(j).getFirst(), inMask);
counter.incrementAndGet(); counter.incrementAndGet();
} catch (Exception e){ } catch (Exception e){
e.printStackTrace(); log.error("",e);
failedCount.incrementAndGet(); failedCount.incrementAndGet();
} }
} }

View File

@ -189,7 +189,7 @@ public abstract class DifferentialFunction {
try { try {
return property.get(this); return property.get(this);
} catch (IllegalAccessException e) { } catch (IllegalAccessException e) {
e.printStackTrace(); log.error("",e);
} }
return null; return null;

View File

@ -80,7 +80,7 @@ public class Loss {
public static Loss sum(List<Loss> losses) { public static Loss sum(List<Loss> losses) {
if (losses.size() == 0) if (losses.isEmpty())
return new Loss(Collections.<String>emptyList(), new double[0]); return new Loss(Collections.<String>emptyList(), new double[0]);
double[] lossValues = new double[losses.get(0).losses.length]; double[] lossValues = new double[losses.get(0).losses.length];

View File

@ -17,10 +17,7 @@
package org.nd4j.autodiff.listeners.impl; package org.nd4j.autodiff.listeners.impl;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
@ -34,9 +31,6 @@ import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/** /**
* HistoryListener is mainly used internally to collect information such as the loss curve and evaluations, * HistoryListener is mainly used internally to collect information such as the loss curve and evaluations,

View File

@ -154,14 +154,12 @@ public abstract class AbstractDependencyTracker<T, D> {
} }
} }
if (allSatisfied) { if (allSatisfied && !this.allSatisfied.contains(t)) {
if (!this.allSatisfied.contains(t)) {
this.allSatisfied.add(t); this.allSatisfied.add(t);
this.allSatisfiedQueue.add(t); this.allSatisfiedQueue.add(t);
} }
} }
} }
}
} else { } else {
satisfiedDependencies.remove(x); satisfiedDependencies.remove(x);
@ -278,25 +276,25 @@ public abstract class AbstractDependencyTracker<T, D> {
protected boolean isAllSatisfied(@NonNull T y) { protected boolean isAllSatisfied(@NonNull T y) {
Set<D> set1 = dependencies.get(y); Set<D> set1 = dependencies.get(y);
boolean allSatisfied = true; boolean retVal = true;
if (set1 != null) { if (set1 != null) {
for (D d : set1) { for (D d : set1) {
allSatisfied = isSatisfied(d); retVal = isSatisfied(d);
if (!allSatisfied) if (!retVal)
break; break;
} }
} }
if (allSatisfied) { if (retVal) {
Set<Pair<D, D>> set2 = orDependencies.get(y); Set<Pair<D, D>> set2 = orDependencies.get(y);
if (set2 != null) { if (set2 != null) {
for (Pair<D, D> p : set2) { for (Pair<D, D> p : set2) {
allSatisfied = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond()); retVal = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond());
if (!allSatisfied) if (!retVal)
break; break;
} }
} }
} }
return allSatisfied; return retVal;
} }

View File

@ -16,8 +16,6 @@
package org.nd4j.evaluation.custom; package org.nd4j.evaluation.custom;
import org.nd4j.shade.guava.collect.Lists;
import java.util.List; import java.util.List;
/** /**

View File

@ -21,10 +21,8 @@ import lombok.Getter;
import lombok.val; import lombok.val;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish; import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish;
import org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative; import org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;

View File

@ -19,7 +19,6 @@ import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;

View File

@ -24,7 +24,6 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;

View File

@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation; import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;

View File

@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation; import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;

View File

@ -68,7 +68,7 @@ public class AvgPooling3D extends Pooling3D {
return config.toProperties(); return config.toProperties();
} }
@Override
public String getPoolingPrefix() { public String getPoolingPrefix() {
return "avg"; return "avg";
} }

View File

@ -82,7 +82,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
@Override @Override
public Map<String, Object> propertiesForFunction() { public Map<String, Object> propertiesForFunction() {
if(config == null && iArguments.size() > 0){ if(config == null && !iArguments.isEmpty()){
//Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object //Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object
config = Pooling2DConfig.builder() config = Pooling2DConfig.builder()
.kH(iArguments.get(0)) .kH(iArguments.get(0))

View File

@ -85,7 +85,7 @@ public class MaxPooling2D extends DynamicCustomOp {
@Override @Override
public Map<String, Object> propertiesForFunction() { public Map<String, Object> propertiesForFunction() {
if(config == null && iArguments.size() > 0){ if(config == null && !iArguments.isEmpty()){
//Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object //Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object
config = Pooling2DConfig.builder() config = Pooling2DConfig.builder()
.kH(iArguments.get(0)) .kH(iArguments.get(0))

View File

@ -75,7 +75,7 @@ public class MaxPooling3D extends Pooling3D {
return config.toProperties(); return config.toProperties();
} }
@Override
public String getPoolingPrefix() { public String getPoolingPrefix() {
return "max"; return "max";
} }

View File

@ -18,11 +18,14 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import java.lang.reflect.Field; import java.lang.reflect.Field;
@Slf4j
public abstract class BaseConvolutionConfig { public abstract class BaseConvolutionConfig {
public abstract Map<String, Object> toProperties(); public abstract Map<String, Object> toProperties();
@ -61,7 +64,7 @@ public abstract class BaseConvolutionConfig {
try { try {
target.set(this, value); target.set(this, value);
} catch (IllegalAccessException e) { } catch (IllegalAccessException e) {
e.printStackTrace(); log.error("",e);
} }
} }

View File

@ -24,12 +24,10 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.shade.guava.primitives.Booleans; import org.nd4j.shade.guava.primitives.Booleans;
import javax.xml.crypto.Data;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;

View File

@ -19,8 +19,6 @@ import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;

Some files were not shown because too many files have changed in this diff Show More