Fixing issues from Sonar report (#391)

* Fixing issues from Sonar report

* Proper logger of exceptions

* Coding style fixes

* Use dup parameter

* Cleanup, minor issues

* Cuda compilation fixed and some minor fixes
master
Alexander Stoyakin 2020-04-23 01:36:49 +03:00 committed by GitHub
parent 2a488efb1b
commit ccb216a3ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
178 changed files with 376 additions and 422 deletions

View File

@ -18,6 +18,7 @@ package org.datavec.api.records.reader.impl;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.datavec.api.conf.Configuration;
@ -43,6 +44,7 @@ import java.util.*;
*
* @author Adam Gibson
*/
@Slf4j
public class LineRecordReader extends BaseRecordReader {
@ -89,7 +91,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
if (iter.hasNext()) {
@ -120,7 +122,7 @@ public class LineRecordReader extends BaseRecordReader {
iter = getIterator(splitIndex);
onLocationOpen(locations[splitIndex]);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return iter.hasNext();

View File

@ -16,6 +16,7 @@
package org.datavec.api.split;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.util.files.UriFromPathIterator;
import org.datavec.api.writable.WritableType;
@ -34,6 +35,7 @@ import java.util.regex.Pattern;
* NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent
* the integer index.
*/
@Slf4j
public class NumberedFileInputSplit implements InputSplit {
private final String baseString;
private final int minIdx;
@ -93,7 +95,7 @@ public class NumberedFileInputSplit implements InputSplit {
try {
writeFile.createNewFile();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}

View File

@ -162,7 +162,6 @@ public class Text extends BinaryComparable implements WritableComparable<BinaryC
return -1; // not found
} catch (CharacterCodingException e) {
// can't get here
e.printStackTrace();
return -1;
}
}

View File

@ -17,6 +17,7 @@
package org.datavec.arrow.recordreader;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
@ -50,6 +51,7 @@ import static org.datavec.arrow.ArrowConverter.readFromBytes;
* @author Adam Gibson
*
*/
@Slf4j
public class ArrowRecordReader implements RecordReader {
private InputSplit split;
@ -132,7 +134,7 @@ public class ArrowRecordReader implements RecordReader {
currIdx++;
this.currentPath = url;
}catch(Exception e) {
e.printStackTrace();
log.error("",e);
}
}
@ -242,7 +244,7 @@ public class ArrowRecordReader implements RecordReader {
try {
currentBatch.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
}

View File

@ -16,6 +16,7 @@
package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
@ -56,7 +57,7 @@ import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@Slf4j
public class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@ -343,7 +344,7 @@ public class ArrowConverterTest extends BaseND4JTest {
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
byte[] arr = byteArrayOutputStream.toByteArray();

View File

@ -61,7 +61,7 @@ public class Wave implements Serializable {
initWaveWithInputStream(inputStream);
inputStream.close();
} catch (IOException e) {
e.printStackTrace();
System.out.println(e.toString());
}
}
@ -96,7 +96,7 @@ public class Wave implements Serializable {
data = new byte[inputStream.available()];
inputStream.read(data);
} catch (IOException e) {
e.printStackTrace();
System.err.println(e.toString());
}
// end load data
} else {

View File

@ -16,10 +16,12 @@
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.FileOutputStream;
import java.io.IOException;
@Slf4j
public class WaveFileManager {
private Wave wave;
@ -78,7 +80,7 @@ public class WaveFileManager {
fos.write(wave.getBytes());
fos.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -16,6 +16,8 @@
package org.datavec.audio;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException;
import java.io.InputStream;
@ -25,6 +27,7 @@ import java.io.InputStream;
*
* @author Jacquet Wong
*/
@Slf4j
public class WaveHeader {
public static final String RIFF_HEADER = "RIFF";
@ -109,7 +112,7 @@ public class WaveHeader {
// dis.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
return false;
}

View File

@ -17,6 +17,7 @@
package org.datavec.audio.fingerprint;
import lombok.extern.slf4j.Slf4j;
import org.datavec.audio.Wave;
import org.datavec.audio.WaveHeader;
import org.datavec.audio.dsp.Resampler;
@ -38,6 +39,7 @@ import java.util.List;
* @author jacquet
*
*/
@Slf4j
public class FingerprintManager {
private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
@ -153,7 +155,7 @@ public class FingerprintManager {
fingerprint = getFingerprintFromInputStream(fis);
fis.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return fingerprint;
}
@ -170,7 +172,7 @@ public class FingerprintManager {
fingerprint = new byte[inputStream.available()];
inputStream.read(fingerprint);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return fingerprint;
}
@ -190,7 +192,7 @@ public class FingerprintManager {
fileOutputStream.write(fingerprint);
fileOutputStream.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -186,7 +186,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
labels.add(line);
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
@ -300,7 +300,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
try {
FileUtils.write(meanVarPath, uMean + "," + uStd + "," + vMean + "," + vStd);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
meanStdStored = true;
} else if (uMean == 0 && meanStdStored) {
@ -312,7 +312,7 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
vStd = Double.parseDouble(values[3]);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
for (int i = 0; i < result.numExamples(); i++) {
@ -356,12 +356,12 @@ public class CifarLoader extends NativeImageLoader implements Serializable {
dataSets.add(new DataSet(asMatrix(matConversion.getSecond()), matConversion.getFirst()));
batchNumCount++;
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
break;
}
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
if(dataSets.size() == 0){

View File

@ -235,7 +235,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data);
} catch (IOException | InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
return recordReader;
}
@ -250,7 +250,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
InputSplit data = train ? inputSplit[0] : inputSplit[1];
recordReader.initialize(data);
} catch (IOException | InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
return recordReader;
}

View File

@ -225,7 +225,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
finishedInputStreamSplit = true;
return Arrays.<Writable>asList(ndArrayWritable);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
if (iter != null) {

View File

@ -16,6 +16,7 @@
package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
@ -53,6 +54,7 @@ import static org.junit.Assert.fail;
*
* @author saudet
*/
@Slf4j
public class TestNativeImageLoader {
static final long seed = 10;
static final Random rng = new Random(seed);
@ -123,7 +125,7 @@ public class TestNativeImageLoader {
try {
array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
fail();
}
assertEquals(5, array6.rank());
@ -156,7 +158,7 @@ public class TestNativeImageLoader {
try {
array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath());
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
assertEquals(5, array8.rank());
assertEquals(pages2, array8.size(0));
@ -172,7 +174,7 @@ public class TestNativeImageLoader {
try {
array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath());
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
fail();
}
assertEquals(5, array9.rank());

View File

@ -66,7 +66,7 @@ public class UimaTokenizer implements Tokenizer {
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
throw new RuntimeException(e);
}

View File

@ -17,6 +17,7 @@
package org.datavec.local.transforms.functions;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.function.Function;
@ -32,6 +33,7 @@ import java.util.List;
* sequence data into a {@code List<List<Writable>>}
* @author Alex Black
*/
@Slf4j
public class SequenceRecordReaderFunction
implements Function<Pair<String, InputStream>, List<List<Writable>>> {
protected SequenceRecordReader sequenceRecordReader;
@ -46,7 +48,7 @@ public class SequenceRecordReaderFunction
try (DataInputStream dis = (DataInputStream) value.getRight()) {
return sequenceRecordReader.sequenceRecord(uri, dis);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
throw new IllegalStateException("Something went wrong");

View File

@ -74,7 +74,6 @@ public class DataVecTransformClient implements DataVecTransformService {
} catch (UnirestException e) {
log.error("Error in setCSVTransformProcess()", e);
e.printStackTrace();
}
}
@ -94,7 +93,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return TransformProcess.fromJson(s);
} catch (UnirestException e) {
log.error("Error in getCSVTransformProcess()",e);
e.printStackTrace();
}
return null;
@ -119,7 +117,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord;
} catch (UnirestException e) {
log.error("Error in transformIncremental(SingleCSVRecord)",e);
e.printStackTrace();
}
return null;
}
@ -140,8 +137,7 @@ public class DataVecTransformClient implements DataVecTransformService {
.getBody();
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e);
e.printStackTrace();
log.error("",e);
}
return null;
@ -162,7 +158,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e);
e.printStackTrace();
}
return null;
@ -181,7 +176,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1;
} catch (UnirestException e) {
log.error("Error in transformArray(BatchCSVRecord)",e);
e.printStackTrace();
}
return null;
@ -200,7 +194,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array;
} catch (UnirestException e) {
log.error("Error in transformArrayIncremental(SingleCSVRecord)",e);
e.printStackTrace();
}
return null;
@ -231,7 +224,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return array;
} catch (UnirestException e) {
log.error("Error in transformSequenceArrayIncremental",e);
e.printStackTrace();
}
return null;
@ -252,7 +244,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchArray1;
} catch (UnirestException e) {
log.error("Error in transformSequenceArray",e);
e.printStackTrace();
}
return null;
@ -274,7 +265,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transformSequence");
e.printStackTrace();
}
return null;
@ -295,7 +285,6 @@ public class DataVecTransformClient implements DataVecTransformService {
return singleCsvRecord;
} catch (UnirestException e) {
log.error("Error in transformSequenceIncremental");
e.printStackTrace();
}
return null;
}

View File

@ -18,6 +18,7 @@ package org.datavec.spark.transform;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
@ -53,7 +54,7 @@ import static org.datavec.local.transforms.LocalTransformExecutor.executeToSeque
* @author Adan Gibson
*/
@AllArgsConstructor
@Slf4j
public class CSVSparkTransform {
@Getter
private TransformProcess transformProcess;
@ -252,7 +253,7 @@ public class CSVSparkTransform {
try {
return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return null;

View File

@ -88,7 +88,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -100,7 +100,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -112,7 +112,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest();
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -130,7 +130,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -142,7 +142,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});
@ -169,7 +169,7 @@ public class ImageSparkTransformServer extends SparkTransformServer {
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
return internalServerError();
}
});

View File

@ -16,6 +16,7 @@
package org.datavec.spark;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.After;
@ -23,6 +24,7 @@ import org.junit.Before;
import java.io.Serializable;
@Slf4j
public abstract class BaseSparkTest implements Serializable {
protected static JavaSparkContext sc;
@ -40,7 +42,7 @@ public abstract class BaseSparkTest implements Serializable {
try {
Thread.sleep(100L);
} catch (InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
} else {
break;

View File

@ -43,7 +43,8 @@ import java.util.concurrent.atomic.AtomicLong;
*/
@Slf4j
public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable {
private static final String ROUTE_IS_DOWN = "Info posted to RemoteUIStatsStorageRouter but router is shut down.";
private static final String MAX_WARNINGS_REACHED = "RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.";
/**
* Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar
*/
@ -163,10 +164,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
log.warn(ROUTE_IS_DOWN);
}
if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
log.warn(MAX_WARNINGS_REACHED);
}
} else {
for (StorageMetaData m : storageMetaData) {
@ -186,10 +187,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
log.warn(ROUTE_IS_DOWN);
}
if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
log.warn(MAX_WARNINGS_REACHED);
}
} else {
for (Persistable p : staticInfo) {
@ -209,10 +210,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa
if (shutdown.get()) {
long count = shutdownWarnCount.getAndIncrement();
if (count <= MAX_SHUTDOWN_WARN_COUNT) {
log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down.");
log.warn(ROUTE_IS_DOWN);
}
if (count == MAX_SHUTDOWN_WARN_COUNT) {
log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced.");
log.warn(MAX_WARNINGS_REACHED);
}
} else {
for (Persistable p : updates) {

View File

@ -67,10 +67,7 @@ public class AsyncIterator<T extends Object> implements Iterator<T> {
nextElement = buffer.take();
// same on this run
if (nextElement == terminator)
return false;
return true;
return (nextElement != terminator);
} catch (Exception e) {
log.error("Premature end of loop!");
return false;

View File

@ -43,7 +43,7 @@ public class SystemInfoPrintListener implements TrainingListener {
private boolean printOnBackwardPass;
private boolean printOnGradientCalculation;
private static final String SYSTEM_INFO = "System info on epoch end: ";
@Override
public void iterationDone(Model model, int iteration, int epoch) {
@ -65,7 +65,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -75,7 +75,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -85,7 +85,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -95,7 +95,7 @@ public class SystemInfoPrintListener implements TrainingListener {
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
@ -104,7 +104,7 @@ public class SystemInfoPrintListener implements TrainingListener {
if(!printOnBackwardPass)
return;
SystemInfo systemInfo = new SystemInfo();
log.info("System info on epoch end: ");
log.info(SYSTEM_INFO);
log.info(systemInfo.toPrettyJSON());
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.perf.listener;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import oshi.json.SystemInfo;
@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit;
*
* @author Adam Gibson
*/
@Slf4j
public class SystemPolling {
private ScheduledExecutorService scheduledExecutorService;
@ -66,7 +68,7 @@ public class SystemPolling {
try {
objectMapper.writeValue(hardwareFile,hardwareMetric);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
},0,pollEveryMillis, TimeUnit.MILLISECONDS);

View File

@ -27,7 +27,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
import java.io.*;
import java.nio.file.Files;
import java.util.UUID;
/**

View File

@ -493,7 +493,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
out.println();
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
return new Pair<>(dArr,temp);

View File

@ -78,10 +78,10 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
assertEquals(earlyEndIter.hasNext(), false);
assertEquals(false, earlyEndIter.hasNext());
earlyEndIter.reset();
assertEquals(earlyEndIter.hasNext(), true);
assertEquals(true, earlyEndIter.hasNext());
}

View File

@ -98,7 +98,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest {
while (multiIter.hasNext()) {
DataSet path = multiIter.next(10);
assertNotNull(path);
assertEquals(path.numExamples(), 10, 0.0);
assertEquals(10, path.numExamples(), 0.0);
}
assertEquals(epochs, multiIter.epochs);

View File

@ -33,7 +33,7 @@ public class SamplingTest extends BaseDL4JTest {
DataSetIterator iter = new MnistDataSetIterator(10, 10);
//batch size and total
DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
assertEquals(sampling.next().numExamples(), 10);
assertEquals(10, sampling.next().numExamples());
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ConvolutionMode;
@ -34,6 +35,7 @@ import static org.junit.Assert.fail;
/**
* A set of tests to ensure that useful exceptions are thrown on invalid network configurations
*/
@Slf4j
public class TestInvalidConfigurations extends BaseDL4JTest {
public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) {
@ -78,7 +80,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testDenseNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -96,7 +98,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testDenseNout0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -109,7 +111,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -122,7 +124,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testRnnOutputLayerNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -135,7 +137,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testLSTMNIn0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -153,7 +155,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testLSTMNOut0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -166,7 +168,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testConvolutionalNin0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -185,7 +187,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testConvolutionalNOut0(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -216,7 +218,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -245,7 +247,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -277,7 +279,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -318,7 +320,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}
@ -347,7 +349,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.exceptions;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -36,6 +37,7 @@ import static org.junit.Assert.*;
/**
* A set of tests to ensure that useful exceptions are thrown on invalid input
*/
@Slf4j
public class TestInvalidInput extends BaseDL4JTest {
@Test
@ -53,7 +55,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchDense(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -73,7 +75,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -94,7 +96,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function
System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -115,7 +117,7 @@ public class TestInvalidInput extends BaseDL4JTest {
//From loss function
System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -142,7 +144,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -169,7 +171,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -195,7 +197,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinRank2Subsampling(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -217,7 +219,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchLSTM(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -238,7 +240,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
@ -260,7 +262,7 @@ public class TestInvalidInput extends BaseDL4JTest {
} catch (DL4JException e) {
System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Expected DL4JException");
}
}
@ -305,7 +307,7 @@ public class TestInvalidInput extends BaseDL4JTest {
net.rnnTimeStep(Nd4j.create(5, 5, 10));
fail("Expected Exception - " + layerType);
} catch (Exception e) {
// e.printStackTrace();
log.error("",e);
String msg = e.getMessage();
assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
}

View File

@ -216,7 +216,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failed.add(testName + "\t" + "EXCEPTION");
continue;
}
@ -383,7 +383,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failed.add(testName + "\t" + "EXCEPTION");
continue;
}
@ -693,7 +693,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failed.add(testName + "\t" + "EXCEPTION");
continue;
}

View File

@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -47,6 +48,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphConfigurationTest extends BaseDL4JTest {
@Test
@ -150,7 +152,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
// Use appendLayer on first layer
@ -162,7 +164,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test no network inputs
@ -174,7 +176,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test no network outputs
@ -185,7 +187,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test: invalid input
@ -197,7 +199,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test: graph with cycles
@ -215,7 +217,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
//Test: input != inputType count mismatch
@ -241,7 +243,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalArgumentException e) {
//OK - exception is good
//e.printStackTrace();
log.info(e.toString());
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.conf;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Layer;
@ -46,6 +47,7 @@ import static org.junit.Assert.*;
/**
* Created by agibsonccc on 11/27/14.
*/
@Slf4j
public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
@Rule
@ -272,9 +274,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
e.printStackTrace();
log.error("",e);
} catch (Throwable e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception thrown for invalid config");
}
@ -288,9 +290,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
e.printStackTrace();
log.info(e.toString());
} catch (Throwable e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception thrown for invalid config");
}
@ -304,9 +306,9 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK
e.printStackTrace();
log.info(e.toString());
} catch (Throwable e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception thrown for invalid config");
}
}

View File

@ -96,8 +96,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) {
assertTrue(RW0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(RW0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(RW0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, RW0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, RW0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -149,8 +149,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) {
assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -201,8 +201,8 @@ public class TestConstraints extends BaseDL4JTest {
} else if (lc instanceof NonNegativeConstraint) {
assertTrue(w0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -259,10 +259,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -320,10 +320,10 @@ public class TestConstraints extends BaseDL4JTest {
assertTrue(w0.minNumber().doubleValue() >= 0.0);
assertTrue(b0.minNumber().doubleValue() >= 0.0);
} else if (lc instanceof UnitNormConstraint) {
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6);
assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6);
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6);
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
TestUtils.testModelSerialization(net);
@ -378,10 +378,10 @@ public class TestConstraints extends BaseDL4JTest {
} else if(lc instanceof NonNegativeConstraint ){
assertTrue(w0.minNumber().doubleValue() >= 0.0 );
} else if(lc instanceof UnitNormConstraint ){
assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(w1.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(w1.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 );
assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6 );
assertEquals(1.0, w1.norm2(1).minNumber().doubleValue(), 1e-6 );
assertEquals(1.0, w1.norm2(1).maxNumber().doubleValue(), 1e-6 );
}
TestUtils.testModelSerialization(net);

View File

@ -156,7 +156,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
checkSerialization(glstm);
assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0);
assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0);
assertEquals(glstm.nIn, numIn);
assertEquals(glstm.nOut, numOut);
assertTrue(glstm.getActivationFn() instanceof ActivationTanH);

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.graph;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
@ -54,6 +55,7 @@ import java.util.Map;
import static org.junit.Assert.*;
@Slf4j
public class ComputationGraphTestRNN extends BaseDL4JTest {
@Test
@ -618,7 +620,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest {
.build();
fail("Exception expected");
} catch (IllegalStateException e){
// e.printStackTrace();
log.error("",e);
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
}
}

View File

@ -1394,7 +1394,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
} catch (Exception e) {
//e.printStackTrace();
log.error("",e);
if(allowDisconnected){
fail("No exception expected");
} else {

View File

@ -416,8 +416,8 @@ public class TestVariableLengthTSCG extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0);
assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(0.0, outRow2.getDouble(k), 0.0);
}
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.convolution;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@ -45,6 +46,7 @@ import static org.junit.Assert.*;
/**
* Created by Alex on 15/11/2016.
*/
@Slf4j
public class TestConvolutionModes extends BaseDL4JTest {
@Test
@ -106,12 +108,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
}
} catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}
continue; //Expected exception
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}
@ -184,12 +186,12 @@ public class TestConvolutionModes extends BaseDL4JTest {
}
} catch (DL4JException e) {
if (inSize == 9 || cm != ConvolutionMode.Strict) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}
continue; //Expected exception
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail("Unexpected exception");
}

View File

@ -79,13 +79,13 @@ public class MaskZeroLayerTest extends BaseDL4JTest {
INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all());
INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
assertEquals(firstExampleOutput.getDouble(0), 0.0, 1e-6);
assertEquals(firstExampleOutput.getDouble(1), 1.0, 1e-6);
assertEquals(firstExampleOutput.getDouble(2), 2.0, 1e-6);
assertEquals(0.0, firstExampleOutput.getDouble(0), 1e-6);
assertEquals(1.0, firstExampleOutput.getDouble(1), 1e-6);
assertEquals(2.0, firstExampleOutput.getDouble(2), 1e-6);
assertEquals(secondExampleOutput.getDouble(0), 0.0, 1e-6);
assertEquals(secondExampleOutput.getDouble(1), 0.0, 1e-6);
assertEquals(secondExampleOutput.getDouble(2), 1.0, 1e-6);
assertEquals(0.0, secondExampleOutput.getDouble(0), 1e-6);
assertEquals(0.0, secondExampleOutput.getDouble(1), 1e-6);
assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
}

View File

@ -771,7 +771,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
.build();
fail("Exception expected");
} catch (IllegalStateException e){
// e.printStackTrace();
log.info(e.toString());
assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig"));
}
}

View File

@ -408,8 +408,8 @@ public class TestVariableLengthTS extends BaseDL4JTest {
INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(),
NDArrayIndex.point(j));
for (int k = 0; k < nOut; k++) {
assertEquals(outRow.getDouble(k), 0.0, 0.0);
assertEquals(outRow2.getDouble(k), 0.0, 0.0);
assertEquals(0.0, outRow.getDouble(k), 0.0);
assertEquals(0.0, outRow2.getDouble(k), 0.0);
}
}
}

View File

@ -336,7 +336,7 @@ public class TestUpdaters extends BaseDL4JTest {
actualM[i] = Math.round(actualM[i] * 1e2) / 1e2;
}
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(actualM, expectedM), true);
assertEquals("Wrong weight gradient after first iteration's update", Arrays.equals(expectedM, actualM), true);
}

View File

@ -159,14 +159,14 @@ public class RegressionTest050 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,14 +162,14 @@ public class RegressionTest060 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l0.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Truncate); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
assertEquals(ConvolutionMode.Truncate, l1.getConvolutionMode()); //Pre-0.7.0: no ConvolutionMode. Want to default to truncate here if not set
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertEquals("sigmoid", l2.getActivationFn().toString());

View File

@ -162,7 +162,7 @@ public class RegressionTest071 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same);
assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());

View File

@ -176,14 +176,14 @@ public class RegressionTest080 extends BaseDL4JTest {
assertArrayEquals(new int[] {2, 2}, l0.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l0.getStride());
assertArrayEquals(new int[] {0, 0}, l0.getPadding());
assertEquals(l0.getConvolutionMode(), ConvolutionMode.Same);
assertEquals(ConvolutionMode.Same, l0.getConvolutionMode());
SubsamplingLayer l1 = (SubsamplingLayer) conf.getConf(1).getLayer();
assertArrayEquals(new int[] {2, 2}, l1.getKernelSize());
assertArrayEquals(new int[] {1, 1}, l1.getStride());
assertArrayEquals(new int[] {0, 0}, l1.getPadding());
assertEquals(PoolingType.MAX, l1.getPoolingType());
assertEquals(l1.getConvolutionMode(), ConvolutionMode.Same);
assertEquals(ConvolutionMode.Same, l1.getConvolutionMode());
OutputLayer l2 = (OutputLayer) conf.getConf(2).getLayer();
assertTrue(l2.getActivationFn() instanceof ActivationSigmoid);

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.datasets.fetchers;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.base.EmnistFetcher;
@ -36,6 +37,7 @@ import java.util.Random;
* @author Alex Black
*
*/
@Slf4j
public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher {
protected EmnistFetcher fetcher;
@ -64,7 +66,7 @@ public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetche
try {
man = new MnistManager(images, labels, totalExamples);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
FileUtils.deleteDirectory(new File(EMNIST_ROOT));
new EmnistFetcher(dataSet).downloadAndUntar();
man = new MnistManager(images, labels, totalExamples);

View File

@ -687,9 +687,7 @@ public class BarnesHutTsne implements Model {
* @throws IOException
*/
public void saveAsFile(List<String> labels, String path) throws IOException {
BufferedWriter write = null;
try {
write = new BufferedWriter(new FileWriter(new File(path)));
try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
for (int i = 0; i < Y.rows(); i++) {
if (i >= labels.size())
break;
@ -711,17 +709,11 @@ public class BarnesHutTsne implements Model {
}
write.flush();
write.close();
} finally {
if (write != null)
write.close();
}
}
public void saveAsFile(String path) throws IOException {
BufferedWriter write = null;
try {
write = new BufferedWriter(new FileWriter(new File(path)));
try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
for (int i = 0; i < Y.rows(); i++) {
StringBuilder sb = new StringBuilder();
INDArray wordVector = Y.getRow(i);
@ -734,10 +726,6 @@ public class BarnesHutTsne implements Model {
write.write(sb.toString());
}
write.flush();
write.close();
} finally {
if (write != null)
write.close();
}
}
/**

View File

@ -60,7 +60,7 @@ public class Hdf5Archive implements Closeable {
/* This is necessary for the call to the BytePointer constructor below. */
Loader.load(org.bytedeco.hdf5.global.hdf5.class);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -69,7 +69,7 @@ public class KerasModelImportTest extends BaseDL4JTest {
network = KerasModelImport.importKerasSequentialModelAndWeights(Resources.asFile(modelJsonFilename).getAbsolutePath(),
Resources.asFile(modelWeightFilename).getAbsolutePath(), false);
} catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) {
e.printStackTrace();
log.error("",e);
}
return network;
@ -80,7 +80,7 @@ public class KerasModelImportTest extends BaseDL4JTest {
try {
model = KerasModelImport.importKerasSequentialModelAndWeights(Resources.asFile(modelFilename).getAbsolutePath());
} catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) {
e.printStackTrace();
log.error("",e);
}
return model;

View File

@ -210,7 +210,6 @@ public class NearestNeighborsServer extends AbstractVerticle {
return;
} catch (Throwable e) {
log.error("Error in POST /knn",e);
e.printStackTrace();
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;
@ -265,7 +264,6 @@ public class NearestNeighborsServer extends AbstractVerticle {
.end(j);
} catch (Throwable e) {
log.error("Error in POST /knnnew",e);
e.printStackTrace();
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;

View File

@ -75,7 +75,6 @@ public class StopRecognition implements Recognition {
try {
regexList.add(Pattern.compile(regex));
} catch (Exception e) {
e.printStackTrace();
LOG.error("regex err : " + regex, e);
}
}

View File

@ -20,10 +20,12 @@ import com.atilika.kuromoji.dict.CharacterDefinitions;
import com.atilika.kuromoji.dict.ConnectionCosts;
import com.atilika.kuromoji.dict.UnknownDictionary;
import com.atilika.kuromoji.trie.DoubleArrayTrie;
import lombok.extern.slf4j.Slf4j;
import java.io.*;
import java.util.List;
@Slf4j
public abstract class DictionaryCompilerBase {
public void build(String inputDirname, String outputDirname, String encoding, boolean compactTries)
@ -66,7 +68,7 @@ public abstract class DictionaryCompilerBase {
}
}
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
}
ProgressLog.end();

View File

@ -33,7 +33,6 @@ public class FileResourceResolver implements ResourceResolver {
try {
KuromojiBinFilesFetcher.downloadAndUntar();
} catch (IOException e) {
e.printStackTrace();
log.error("IOException : ", e);
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.text.corpora.sentiwordnet;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.shade.guava.collect.Sets;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.cas.CAS;
@ -37,6 +38,7 @@ import java.util.*;
* @author Adam Gibson
*
*/
@Slf4j
public class SWN3 implements Serializable {
/**
*
@ -120,7 +122,7 @@ public class SWN3 implements Serializable {
try {
csv.close();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.text.tokenization.tokenizer;
import lombok.extern.slf4j.Slf4j;
import org.apache.uima.cas.CAS;
import org.apache.uima.fit.util.JCasUtil;
import org.cleartk.token.type.Token;
@ -32,6 +33,7 @@ import java.util.List;
* @author Adam Gibson
*
*/
@Slf4j
public class UimaTokenizer implements Tokenizer {
private List<String> tokens;
@ -66,7 +68,7 @@ public class UimaTokenizer implements Tokenizer {
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
throw new RuntimeException(e);
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.models.word2vec;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
@ -64,6 +65,7 @@ import static org.junit.Assert.*;
/**
* @author jeffreytang
*/
@Slf4j
public class Word2VecTests extends BaseDL4JTest {
private static final Logger log = LoggerFactory.getLogger(Word2VecTests.class);
@ -621,7 +623,7 @@ public class Word2VecTests extends BaseDL4JTest {
unserialized = Word2Vec.fromJson(json);
}
catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.models.sequencevectors.listeners;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.enums.ListenerEvent;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
@ -34,6 +35,7 @@ import java.util.concurrent.Semaphore;
*
* @author raver119@gmail.com
*/
@Slf4j
public class SerializingListener<T extends SequenceElement> implements VectorsListener<T> {
private File targetFolder = new File("./");
private String modelPrefix = "Model_";
@ -96,7 +98,7 @@ public class SerializingListener<T extends SequenceElement> implements VectorsLi
}
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
} finally {
locker.release();
}

View File

@ -147,7 +147,7 @@ public class ParallelTransformerIterator extends BasicTransformerIterator {
try {
buffer.put(futureSequence);
} catch (InterruptedException e) {
e.printStackTrace();
log.error("",e);
}
}
/* else

View File

@ -21,6 +21,7 @@ import lombok.NonNull;
import org.deeplearning4j.parallelism.AsyncIterator;
import java.util.Iterator;
import java.util.NoSuchElementException;
/**
* @author raver119@gmail.com
@ -77,7 +78,7 @@ public class AsyncLabelAwareIterator implements LabelAwareIterator, Iterator<Lab
}
@Override
public LabelledDocument next() {
public LabelledDocument next() throws NoSuchElementException {
return nextDocument();
}
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.text.sentenceiterator;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import java.io.*;
import java.util.Iterator;
@ -29,6 +30,7 @@ import java.util.Iterator;
*
* @author raver119@gmail.com
*/
@Slf4j
public class BasicLineIterator implements SentenceIterator, Iterable<String> {
private BufferedReader reader;
@ -113,7 +115,7 @@ public class BasicLineIterator implements SentenceIterator, Iterable<String> {
reader.close();
} catch (Exception e) {
// do nothing here
e.printStackTrace();
log.error("",e);
}
super.finalize();
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.text.sentenceiterator;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -35,6 +36,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
*
* @author raver119@gmail.com
*/
@Slf4j
public class StreamLineIterator implements SentenceIterator {
private DocumentIterator iterator;
private int linesToFetch;
@ -64,7 +66,7 @@ public class StreamLineIterator implements SentenceIterator {
currentReader = null;
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
throw new RuntimeException(e);
}
}
@ -145,7 +147,7 @@ public class StreamLineIterator implements SentenceIterator {
try {
this.onlyStream.reset();
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
throw new RuntimeException(e);
}
}

View File

@ -183,7 +183,7 @@ public class FastTextTest extends BaseDL4JTest {
fastText.loadIterator();
} catch (IOException e) {
log.error(e.toString());
log.error("",e);
}
}

View File

@ -1164,7 +1164,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
unserialized = ParagraphVectors.fromJson(json);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.models.sequencevectors.serialization;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang.StringUtils;
import org.deeplearning4j.BaseDL4JTest;
@ -48,6 +49,7 @@ import java.util.Collections;
import static org.junit.Assert.*;
@Slf4j
public class WordVectorSerializerTest extends BaseDL4JTest {
private AbstractCache<VocabWord> cache;
@ -97,7 +99,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
byte[] bytesResult = baos.toByteArray();
deser = WordVectorSerializer.readSequenceVectors(new ByteArrayInputStream(bytesResult), true);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
@ -175,7 +177,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
byte[] bytesResult = baos.toByteArray();
deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
@ -223,7 +225,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
byte[] bytesResult = baos.toByteArray();
deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
@ -268,7 +270,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
deser = WordVectorSerializer.readLookupTable(file);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
assertEquals(lookupTable.getVocab().totalWordOccurrences(), ((InMemoryLookupTable<VocabWord>)deser).getVocab().totalWordOccurrences());
@ -306,7 +308,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data"));
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}

View File

@ -140,7 +140,7 @@ public class AbstractCacheTest extends BaseDL4JTest {
unserialized = AbstractCache.fromJson(json);
}
catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
assertEquals(cache.totalWordOccurrences(),unserialized.totalWordOccurrences());
@ -175,7 +175,7 @@ public class AbstractCacheTest extends BaseDL4JTest {
unserialized = AbstractCache.fromJson(json);
}
catch (Exception e) {
e.printStackTrace();
log.error("",e);
fail();
}
assertEquals(cache.totalWordOccurrences(),unserialized.totalWordOccurrences());

View File

@ -63,7 +63,7 @@ public abstract class AbstractLayer<LayerConfT extends org.deeplearning4j.nn.con
public AbstractLayer(NeuralNetConfiguration conf, DataType dataType) {
this.conf = conf;
if (conf != null)
cacheMode = conf.getCacheMode();
cacheMode = conf.getCacheMode();
this.dataType = dataType;
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -43,6 +44,7 @@ import java.util.*;
* A layer with parameters
* @author Adam Gibson
*/
@Slf4j
public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseLayer>
extends AbstractLayer<LayerConfT> {
@ -371,7 +373,7 @@ public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.la
}
layer.setParamTable(linkedTable);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
}
return layer;

View File

@ -1,5 +1,6 @@
package org.deeplearning4j.remote;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.datavec.image.loader.Java2DNativeImageLoader;
import org.deeplearning4j.BaseDL4JTest;
@ -32,6 +33,7 @@ import java.util.concurrent.TimeUnit;
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
import static org.junit.Assert.*;
@Slf4j
public class BinaryModelServerTest extends BaseDL4JTest {
private final int PORT = 18080;
@ -120,7 +122,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
assertEquals(new Integer(1), result);
} catch (Exception e){
e.printStackTrace();
log.error("",e);
throw e;
} finally {
server.stop();
@ -189,7 +191,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
assertEquals(new Integer(1), results[2].get());
} catch (Exception e){
e.printStackTrace();
log.error("",e);
throw e;
} finally {
server.stop();
@ -244,7 +246,7 @@ public class BinaryModelServerTest extends BaseDL4JTest {
assertEquals(28, result.getWidth());
} catch (Exception e){
e.printStackTrace();
log.error("",e);
throw e;
} finally {
server.stop();

View File

@ -585,7 +585,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
assertEquals(exp.argMax().getInt(0), out);
}
} catch (Exception e){
e.printStackTrace();
log.error("",e);
throw e;
} finally {
server.stop();
@ -640,7 +640,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
server.start();
//client.predict(new float[]{0.0f, 1.0f, 2.0f});
} catch (Exception e){
e.printStackTrace();
log.error("",e);
throw e;
} finally {
server.stop();
@ -700,7 +700,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
assertNotNull(result);
} catch (Exception e){
e.printStackTrace();
log.error("",e);
throw e;
} finally {
server.stop();

View File

@ -660,7 +660,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
//OK
System.out.println("Expected exception: " + e.getMessage());
} catch (Exception e){
e.printStackTrace();
log.error("",e);
fail("Expected other exception type");
}
@ -903,7 +903,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
int idx = t.getRight();
act[idx] = inf.output(t.getFirst(), t.getSecond());
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
failedCount.incrementAndGet();
}
}
@ -955,7 +955,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
act[j] = inf.output(in.get(j), inMask);
counter.incrementAndGet();
} catch (Exception e){
e.printStackTrace();
log.error("",e);
failedCount.incrementAndGet();
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.spark.text.functions;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.function.Function;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory;
@ -29,6 +30,7 @@ import java.util.List;
* @author Adam Gibson
*/
@SuppressWarnings("unchecked")
@Slf4j
public class TokenizerFunction implements Function<String, List<String>> {
private String tokenizerFactoryClazz;
private String tokenizerPreprocessorClazz;
@ -69,7 +71,7 @@ public class TokenizerFunction implements Function<String, List<String>> {
tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams);
}
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
}
return tokenizerFactory;
}

View File

@ -120,7 +120,7 @@ public class UpdatesConsumer implements UpdatesHandler {
//log.info("Putting update to the queue, current size: [{}]", updatesBuffer.size());
updatesBuffer.put(array);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
throw new RuntimeException(e);
}
} else if (params != null && stepFunction != null) {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.spark.datavec;
import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Text;
import org.apache.spark.api.java.function.PairFunction;
@ -35,6 +36,7 @@ import java.util.List;
/**
*/
@Slf4j
public class DataVecByteDataSetFunction implements PairFunction<Tuple2<Text, BytesWritable>, Double, DataSet> {
private int labelIndex = 0;
@ -104,7 +106,7 @@ public class DataVecByteDataSetFunction implements PairFunction<Tuple2<Text, Byt
featureVector = Nd4j.create(lenFeatureVector);
}
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
List<INDArray> inputs = new ArrayList<>();

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.spark.datavec;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.function.Function;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.WritableConverterException;
@ -36,6 +37,7 @@ import java.util.List;
* Analogous to {@link RecordReaderDataSetIterator}, but in the context of Spark.
* @author Alex Black
*/
@Slf4j
public class DataVecDataSetFunction implements Function<List<Writable>, DataSet>, Serializable {
private final int labelIndex;
@ -129,7 +131,8 @@ public class DataVecDataSetFunction implements Function<List<Writable>, DataSet>
try {
current = converter.convert(current);
} catch (WritableConverterException e) {
e.printStackTrace();
log.error("",e);
}
}
if (regression) {

View File

@ -33,7 +33,7 @@ public class BalancedPartitionerTest {
// the 10 first elements should go in the 1st partition
for (int i = 0; i < 10; i++) {
int p = bp.getPartition(i);
assertEquals("Found wrong partition output " + p + ", not 0", p, 0);
assertEquals("Found wrong partition output " + p + ", not 0", 0, p);
}
}
@ -43,7 +43,7 @@ public class BalancedPartitionerTest {
// the 10 first elements should go in the 1st partition
for (int i = 0; i < 10; i++) {
int p = bp.getPartition(i);
assertEquals("Found wrong partition output " + p + ", not 0", p, 0);
assertEquals("Found wrong partition output " + p + ", not 0", 0, p);
}
}
@ -56,7 +56,7 @@ public class BalancedPartitionerTest {
countPerPartition[p] += 1;
}
for (int i = 0; i < 10; i++) {
assertEquals(countPerPartition[i], 10);
assertEquals(10, countPerPartition[i]);
}
}
@ -70,9 +70,9 @@ public class BalancedPartitionerTest {
}
for (int i = 0; i < 10; i++) {
if (i < 7)
assertEquals(countPerPartition[i], 10 + 1);
assertEquals(10 + 1, countPerPartition[i]);
else
assertEquals(countPerPartition[i], 10);
assertEquals(10, countPerPartition[i]);
}
}

View File

@ -385,7 +385,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener {
gpuMaxBytes[i] = nativeOps.getDeviceTotalMemory(0);
gpuCurrentBytes[i] = gpuMaxBytes[i] - nativeOps.getDeviceFreeMemory(0);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
}
}
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.ui.weights;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.datavec.image.loader.ImageLoader;
import org.deeplearning4j.api.storage.Persistable;
@ -49,6 +50,7 @@ import java.util.List;
/**
* @author raver119@gmail.com
*/
@Slf4j
public class ConvolutionalIterationListener extends BaseTrainingListener {
private enum Orientation {
@ -661,7 +663,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener {
try {
ImageIO.write(renderImageGrayscale(array), "png", file);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}
@ -670,7 +672,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener {
try {
ImageIO.write(image, "png", file);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.ui;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.datavec.image.loader.LFWLoader;
import org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator;
@ -82,6 +83,7 @@ import static org.junit.Assert.fail;
* @author raver119@gmail.com
*/
@Ignore
@Slf4j
public class ManualTests {
private static Logger log = LoggerFactory.getLogger(ManualTests.class);
@ -258,7 +260,7 @@ public class ManualTests {
try {
ImageIO.write(imageToRender, "png", file);
} catch (IOException e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -926,7 +926,7 @@ public class TrainModule implements UIModule {
NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
return new Triple<>(null, null, layer);
} catch (Exception e) {
e.printStackTrace();
log.error("",e);
}
}
return null;

View File

@ -18,6 +18,7 @@
package org.deeplearning4j.ui;
import io.netty.handler.codec.http.HttpResponseStatus;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
@ -53,6 +54,7 @@ import static org.junit.Assert.*;
* @author Tamas Fenyvesi
*/
@Ignore
@Slf4j
public class TestVertxUIMultiSession extends BaseDL4JTest {
@Before
public void setUp() throws Exception {
@ -121,7 +123,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(ss));
} catch (InterruptedException | IOException e) {
e.printStackTrace();
log.error("",e);
fail(e.getMessage());
} finally {
uIServer.detach(ss);
@ -294,7 +296,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
" after " + autoDetachTimeoutMillis + " ms ");
Thread.sleep(autoDetachTimeoutMillis);
} catch (InterruptedException e) {
e.printStackTrace();
log.error("",e);
} finally {
System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " +
autoDetachTimeoutMillis + " ms.");

View File

@ -21,7 +21,6 @@ import lombok.Builder;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;

View File

@ -1045,7 +1045,7 @@ public class IntegrationTestRunner {
act[j] = inf.output(in.get(j).getFirst(), inMask);
counter.incrementAndGet();
} catch (Exception e){
e.printStackTrace();
log.error("",e);
failedCount.incrementAndGet();
}
}

View File

@ -189,7 +189,7 @@ public abstract class DifferentialFunction {
try {
return property.get(this);
} catch (IllegalAccessException e) {
e.printStackTrace();
log.error("",e);
}
return null;

View File

@ -80,7 +80,7 @@ public class Loss {
public static Loss sum(List<Loss> losses) {
if (losses.size() == 0)
if (losses.isEmpty())
return new Loss(Collections.<String>emptyList(), new double[0]);
double[] lossValues = new double[losses.get(0).losses.length];

View File

@ -17,10 +17,7 @@
package org.nd4j.autodiff.listeners.impl;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.Getter;
import lombok.Setter;
@ -34,9 +31,6 @@ import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
* HistoryListener is mainly used internally to collect information such as the loss curve and evaluations,

View File

@ -154,11 +154,9 @@ public abstract class AbstractDependencyTracker<T, D> {
}
}
if (allSatisfied) {
if (!this.allSatisfied.contains(t)) {
this.allSatisfied.add(t);
this.allSatisfiedQueue.add(t);
}
if (allSatisfied && !this.allSatisfied.contains(t)) {
this.allSatisfied.add(t);
this.allSatisfiedQueue.add(t);
}
}
}
@ -278,25 +276,25 @@ public abstract class AbstractDependencyTracker<T, D> {
protected boolean isAllSatisfied(@NonNull T y) {
Set<D> set1 = dependencies.get(y);
boolean allSatisfied = true;
boolean retVal = true;
if (set1 != null) {
for (D d : set1) {
allSatisfied = isSatisfied(d);
if (!allSatisfied)
retVal = isSatisfied(d);
if (!retVal)
break;
}
}
if (allSatisfied) {
if (retVal) {
Set<Pair<D, D>> set2 = orDependencies.get(y);
if (set2 != null) {
for (Pair<D, D> p : set2) {
allSatisfied = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond());
if (!allSatisfied)
retVal = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond());
if (!retVal)
break;
}
}
}
return allSatisfied;
return retVal;
}

View File

@ -16,8 +16,6 @@
package org.nd4j.evaluation.custom;
import org.nd4j.shade.guava.collect.Lists;
import java.util.List;
/**

View File

@ -21,10 +21,8 @@ import lombok.Getter;
import lombok.val;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish;
import org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

View File

@ -19,7 +19,6 @@ import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

View File

@ -24,7 +24,6 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

View File

@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.List;

View File

@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.List;

View File

@ -68,7 +68,7 @@ public class AvgPooling3D extends Pooling3D {
return config.toProperties();
}
@Override
public String getPoolingPrefix() {
return "avg";
}

View File

@ -82,7 +82,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
@Override
public Map<String, Object> propertiesForFunction() {
if(config == null && iArguments.size() > 0){
if(config == null && !iArguments.isEmpty()){
//Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object
config = Pooling2DConfig.builder()
.kH(iArguments.get(0))

View File

@ -85,7 +85,7 @@ public class MaxPooling2D extends DynamicCustomOp {
@Override
public Map<String, Object> propertiesForFunction() {
if(config == null && iArguments.size() > 0){
if(config == null && !iArguments.isEmpty()){
//Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object
config = Pooling2DConfig.builder()
.kH(iArguments.get(0))

View File

@ -75,7 +75,7 @@ public class MaxPooling3D extends Pooling3D {
return config.toProperties();
}
@Override
public String getPoolingPrefix() {
return "max";
}

View File

@ -18,11 +18,14 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import java.lang.reflect.Field;
@Slf4j
public abstract class BaseConvolutionConfig {
public abstract Map<String, Object> toProperties();
@ -61,7 +64,7 @@ public abstract class BaseConvolutionConfig {
try {
target.set(this, value);
} catch (IllegalAccessException e) {
e.printStackTrace();
log.error("",e);
}
}

View File

@ -24,12 +24,10 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.shade.guava.primitives.Booleans;
import javax.xml.crypto.Data;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

Some files were not shown because too many files have changed in this diff Show More