Various fixes (#143)

* #8568 ArrayUtil optimization

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #6171 Keras ReLU and ELU support

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Keras softmax layer import

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8549 Webjars dependency management

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for TF import names ':0' suffix issue / NPE

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* BiasAdd: fix default data format for TF import

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update zoo test ignores

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8509 SameDiff Listener API - provide frame + iteration

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8520 ND4J Environment

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Deconv3d

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Deconv3d fixes + gradient check

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Conv3d fixes + deconv3d DType test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix issue with deconv3d gradient check weight init

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8579 Fix BaseCudaDataBuffer constructor for UINT16

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DataType.isNumerical() returns false for BOOL type

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8504 Reduce Spark log spam for tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Clean up DL4J gradient check test spam

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More Gradient check spam reduction

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff test spam reduction

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes for FlatBuffers mapping

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff log spam cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tests should extend BaseNd4jTest

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove debug line in c++ op

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ND4J test spam cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J test spam reduction

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More DL4J and DataVec test spam cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for bad conv3d test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Additional test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Embedding layers: don't inherit global default activation function

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Trigger CI

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Test fixes and timeout increases

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Timeouts and PReLU fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Restore libnd4j build threads arg for CUDA build

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Increase timeouts on a few tests to avoid spurious failures on some CI machines

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More timeout fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More test timeout fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweak timeout for one more test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Final tweaks

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* One more ignore

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Author: Alex Black, 2020-01-04 13:45:07 +11:00 (committed by GitHub)
Parent: ac0d249f07
Commit: 29104083cc
375 changed files with 3944 additions and 4078 deletions

@@ -311,7 +311,11 @@ public class CSVRecordReaderTest {
            rr.reset();
            fail("Expected exception");
        } catch (Exception e){
-            e.printStackTrace();
+            String msg = e.getMessage();
+            String msg2 = e.getCause().getMessage();
+            assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
+            assertTrue(msg2, msg2.contains("Reset not supported from streams"));
+            // e.printStackTrace();
        }
    }

@@ -55,8 +55,7 @@ public class LineReaderTest {
    @Test
    public void testLineReader() throws Exception {
-        String tempDir = System.getProperty("java.io.tmpdir");
-        File tmpdir = new File(tempDir, "tmpdir-testLineReader");
+        File tmpdir = testDir.newFolder();
        if (tmpdir.exists())
            tmpdir.delete();
        tmpdir.mkdir();

@@ -84,12 +83,6 @@ public class LineReaderTest {
        }
        assertEquals(9, count);
-
-        try {
-            FileUtils.deleteDirectory(tmpdir);
-        } catch (Exception e) {
-            e.printStackTrace();
-        }
    }

    @Test

@@ -145,13 +138,6 @@ public class LineReaderTest {
        assertEquals(2, subset.size());
        assertEquals(out3.get(4), subset.get(0));
        assertEquals(out3.get(7), subset.get(1));
-
-        try {
-            FileUtils.deleteDirectory(tmpdir);
-        } catch (Exception e) {
-            e.printStackTrace();
-        }
    }

    @Test

@@ -177,11 +163,5 @@ public class LineReaderTest {
        }
        assertEquals(9, count);
-
-        try {
-            FileUtils.deleteDirectory(tmpdir);
-        } catch (Exception e) {
-            e.printStackTrace();
-        }
    }
}

@@ -66,9 +66,9 @@ public class JsonYamlTest {
        String asJson = itp.toJson();
        String asYaml = itp.toYaml();

-        System.out.println(asJson);
-        System.out.println("\n\n\n");
-        System.out.println(asYaml);
+        // System.out.println(asJson);
+        // System.out.println("\n\n\n");
+        // System.out.println(asYaml);

        ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
        ImageWritable imgJson = new ImageWritable(img.getFrame().clone());

@@ -60,7 +60,7 @@ public class CSVSparkTransformTest {
        Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values));
        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        assertTrue(fromBase64.isVector());
-        System.out.println("Base 64ed array " + fromBase64);
+        // System.out.println("Base 64ed array " + fromBase64);
    }

    @Test

@@ -125,7 +125,7 @@ public class CSVSparkTransformTest {
        SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord);
        assertNotNull(transformed.getRecords());
-        System.out.println(transformed);
+        // System.out.println(transformed);
    }

@@ -153,7 +153,8 @@ public class CSVSparkTransformTest {
                new SingleCSVRecord(data2)));
        final CSVSparkTransform transform = new CSVSparkTransform(transformProcess);
-        System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
+        // System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
+        transform.transformSequenceIncremental(batchCsvRecord);
        assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank());
    }

@@ -54,7 +54,7 @@ public class ImageSparkTransformTest {
        Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord);
        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());

-        System.out.println("Base 64ed array " + fromBase64);
+        // System.out.println("Base 64ed array " + fromBase64);
        assertEquals(1, fromBase64.size(0));
    }

@@ -78,7 +78,7 @@ public class ImageSparkTransformTest {
        Base64NDArrayBody body = imgSparkTransform.toArray(batch);
        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());

-        System.out.println("Base 64ed array " + fromBase64);
+        // System.out.println("Base 64ed array " + fromBase64);
        assertEquals(3, fromBase64.size(0));
    }
}

@@ -120,7 +120,7 @@ public class ImageSparkTransformServerTest {
        INDArray batchResult = getNDArray(jsonNodeBatch);
        assertEquals(3, batchResult.size(0));

-        System.out.println(array);
+        // System.out.println(array);
    }

    @Test

@@ -136,7 +136,7 @@ public class ImageSparkTransformServerTest {
        INDArray batchResult = getNDArray(jsonNode);
        assertEquals(3, batchResult.size(0));

-        System.out.println(batchResult);
+        // System.out.println(batchResult);
    }

    @Test

@@ -153,7 +153,7 @@ public class ImageSparkTransformServerTest {
        INDArray result = getNDArray(jsonNode);
        assertEquals(1, result.size(0));

-        System.out.println(result);
+        // System.out.println(result);
    }

    public INDArray getNDArray(JsonNode node) throws IOException {

@@ -72,7 +72,9 @@ public class TestAnalysis extends BaseSparkTest {
        DataAnalysis da = AnalyzeSpark.analyze(schema, rdd);
        String daString = da.toString();

-        System.out.println(da);
+        // System.out.println(da);
+        da.toJson();
+        da.toString();

        List<ColumnAnalysis> ca = da.getColumnAnalysis();
        assertEquals(5, ca.size());

@@ -151,7 +153,7 @@ public class TestAnalysis extends BaseSparkTest {
        assertEquals(1, countD[countD.length - 1]);

        File f = Files.createTempFile("datavec_spark_analysis_UITest", ".html").toFile();
-        System.out.println(f.getAbsolutePath());
+        // System.out.println(f.getAbsolutePath());
        f.deleteOnExit();
        HtmlAnalysis.createHtmlAnalysisFile(da, f);
    }

@@ -210,7 +212,7 @@ public class TestAnalysis extends BaseSparkTest {
        for( int i=1; i<10; i++ ){
            counter.merge(counters.get(i));
            sparkCounter.merge(sparkCounters.get(i));
-            System.out.println();
+            // System.out.println();
        }
        assertEquals(sc1.sampleStdev(), counter.getStddev(false), 1e-6);
        assertEquals(sparkCounter.sampleStdev(), counter.getStddev(false), 1e-6);

@@ -356,7 +358,9 @@ public class TestAnalysis extends BaseSparkTest {
        JavaRDD<List<Writable>> rdd = sc.parallelize(data);

        DataAnalysis da = AnalyzeSpark.analyze(s, rdd);
-        System.out.println(da);
+        // System.out.println(da);
+        da.toString();
+        da.toJson();
    }
}

@@ -0,0 +1,68 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2020 Konduit K.K.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>deeplearning4j-parent</artifactId>
<groupId>org.deeplearning4j</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>deeplearning4j-common-tests</artifactId>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.2</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>

@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019-2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -23,7 +24,7 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
-import org.nd4j.linalg.api.buffer.DataBuffer;
+import org.nd4j.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;

@@ -31,24 +32,28 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;

import java.lang.management.ManagementFactory;
-import java.lang.management.ThreadMXBean;
import java.util.List;
import java.util.Map;
import java.util.Properties;

+import static org.junit.Assert.assertNull;

@Slf4j
-public class BaseDL4JTest {
+public abstract class BaseDL4JTest {

    @Rule
    public TestName name = new TestName();

    @Rule
-    public Timeout timeout = Timeout.seconds(30);
+    public Timeout timeout = Timeout.millis(getTimeoutMilliseconds());

    protected long startTime;
    protected int threadCountBefore;

+    /**
+     * Override this method to set the default timeout for methods in the test class
+     */
+    public long getTimeoutMilliseconds(){
+        return 30000;
+    }

    /**
     * Override this to set the profiling mode for the tests defined in the child class
     */

@@ -70,6 +75,9 @@ public class BaseDL4JTest {
    @Before
    public void beforeTest(){
        log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
+        //Suppress ND4J initialization - don't need this logged for every test...
+        System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
+        System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
        Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
        Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
        Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
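
The consolidated base class above turns the JUnit timeout into a per-class hook: the @Rule now reads getTimeoutMilliseconds(), so a slow test class overrides that single method instead of redefining the rule (DataSetIteratorTest below does exactly this, returning 90000). A minimal sketch of such an override; the class name and 2-minute value here are illustrative, not part of this commit:

    import org.junit.Test;

    public class MySlowIntegrationTest extends BaseDL4JTest {

        @Override
        public long getTimeoutMilliseconds() {
            return 120_000L;    // raise the 30s default to 2 minutes for this class only
        }

        @Test
        public void testSomethingSlow() throws Exception {
            // long-running test body goes here
        }
    }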

@@ -95,6 +95,12 @@
            <artifactId>junit</artifactId>
            <scope>test</scope>
        </dependency>
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-common-tests</artifactId>
+            <version>${project.version}</version>
+            <scope>test</scope>
+        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>

@@ -75,7 +75,6 @@ public class TestUtils {
    }

    public static ComputationGraph testModelSerialization(ComputationGraph net){
        ComputationGraph restored;
-
        try {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();

@@ -1006,7 +1006,9 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
        for (RecordMetaData m : meta) {
            Record r = csv.loadFromMetaData(m);
            INDArray row = ds.getFeatures().getRow(i);
-            System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
+            if(i <= 3) {
+                System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
+            }

            for (int j = 0; j < 4; j++) {
                double exp = r.getRecord().get(j).toDouble();

@@ -183,7 +183,7 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
            }
            adsi.reset();
-            log.info("Epoch {} finished...", e);
+            // log.info("Epoch {} finished...", e);
        }
    }

@@ -215,7 +215,7 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
            }
            adsi.reset();
-            log.info("Epoch {} finished...", e);
+            // log.info("Epoch {} finished...", e);
        }
    }
}

@@ -57,6 +57,11 @@ import static org.junit.Assert.*;
public class DataSetIteratorTest extends BaseDL4JTest {

+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000;
+    }

    @Test
    public void testBatchSizeOfOneIris() throws Exception {
        //Test for (a) iterators returning correct number of examples, and

@@ -190,7 +195,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
        INDArray output = model.output(dataTest.getFeatures());
        Evaluation eval = new Evaluation(outputNum);
        eval.eval(dataTest.getLabels(), output);
-        System.out.println(eval.stats());
+        // System.out.println(eval.stats());
    }

    @Test

@@ -257,7 +262,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
            INDArray output = model.output(testDS.getFeatures());
            eval.eval(testDS.getLabels(), output);
        }
-        System.out.println(eval.stats(true));
+        // System.out.println(eval.stats(true));
        listener.exportScores(System.out);
    }

@@ -68,8 +68,8 @@ public class VariableMultiTimeseriesGenerator implements MultiDataSetIterator {
            int localMaxima = isFirst && firstMaxima > 0 ? firstMaxima
                    : minTS == maxTS ? minTS : rng.nextInt(maxTS - minTS) + minTS;

-            if (isFirst)
-                log.info("Local maxima: {}", localMaxima);
+            // if (isFirst)
+            //     log.info("Local maxima: {}", localMaxima);

            isFirst = false;

@@ -69,8 +69,8 @@ public class VariableTimeseriesGenerator implements DataSetIterator {
            int localMaxima = isFirst && firstMaxima > 0 ? firstMaxima
                    : minTS == maxTS ? minTS : rng.nextInt(maxTS - minTS) + minTS;

-            if (isFirst)
-                log.info("Local maxima: {}", localMaxima);
+            // if (isFirst)
+            //     log.info("Local maxima: {}", localMaxima);

            isFirst = false;

@@ -54,7 +54,7 @@ public class EvalJsonTest extends BaseDL4JTest {

    @Test
    public void testSerde() {
-        boolean print = true;
+        boolean print = false;
        Nd4j.getRandom().setSeed(12345);

        Evaluation evaluation = new Evaluation();

@@ -105,7 +105,7 @@ public class EvalJsonTest extends BaseDL4JTest {
    @Test
    public void testSerdeExactRoc() {
        Nd4j.getRandom().setSeed(12345);
-        boolean print = true;
+        boolean print = false;

        ROC roc = new ROC(0);
        ROCBinary roc2 = new ROCBinary(0);

@@ -131,11 +131,15 @@ public class EvalTest extends BaseDL4JTest {
        org.nd4j.evaluation.classification.Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test)));
        checkEvaluationEquality(eval, evalViaMethod);

-        System.out.println(eval.getConfusionMatrix().toString());
-        System.out.println(eval.getConfusionMatrix().toCSV());
-        System.out.println(eval.getConfusionMatrix().toHTML());
-        System.out.println(eval.confusionToString());
+        // System.out.println(eval.getConfusionMatrix().toString());
+        // System.out.println(eval.getConfusionMatrix().toCSV());
+        // System.out.println(eval.getConfusionMatrix().toHTML());
+        // System.out.println(eval.confusionToString());
+        eval.getConfusionMatrix().toString();
+        eval.getConfusionMatrix().toCSV();
+        eval.getConfusionMatrix().toHTML();
+        eval.confusionToString();
    }

@@ -205,9 +209,10 @@ public class EvalTest extends BaseDL4JTest {
            e.eval(ds.getLabels(), out, meta); //*** New - evaluate and also store metadata ***
        }

-        System.out.println(e.stats());
+        // System.out.println(e.stats());
+        e.stats();

-        System.out.println("\n\n*** Prediction Errors: ***");
+        // System.out.println("\n\n*** Prediction Errors: ***");

        List<org.nd4j.evaluation.meta.Prediction> errors = e.getPredictionErrors(); //*** New - get list of prediction errors from evaluation ***
        List<RecordMetaData> metaForErrors = new ArrayList<>();

@@ -219,10 +224,11 @@ public class EvalTest extends BaseDL4JTest {
        int count = 0;
        for (org.nd4j.evaluation.meta.Prediction t : errors) {
-            System.out.println(t + "\t\tRaw Data: "
+            String s = t + "\t\tRaw Data: "
                            + csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord() //*** New - load subset of data from MetaData object (usually batched for efficiency) ***
                            + "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: "
-                            + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count));
+                            + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count);
+            // System.out.println(s);
            count++;
        }

@@ -322,9 +328,9 @@ public class EvalTest extends BaseDL4JTest {
        List<DataSet> l = Arrays.asList(new DataSet(in1, out1, null, lMask1), new DataSet(in2, out2, null, lMask2));
        DataSetIterator iter = new ExistingDataSetIterator(l);

-        System.out.println("Net 1 eval");
+        // System.out.println("Net 1 eval");
        org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
-        System.out.println("Net 2 eval");
+        // System.out.println("Net 2 eval");
        org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());

        assertEquals(e1[0], e2[0]);

@@ -403,9 +409,9 @@ public class EvalTest extends BaseDL4JTest {
        List<DataSet> l = Arrays.asList(new DataSet(in1, out1), new DataSet(in2, out2));
        DataSetIterator iter = new ExistingDataSetIterator(l);

-        System.out.println("Eval net 1");
+        // System.out.println("Eval net 1");
        org.nd4j.evaluation.IEvaluation[] e1 = net1.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());
-        System.out.println("Eval net 2");
+        // System.out.println("Eval net 2");
        org.nd4j.evaluation.IEvaluation[] e2 = net2.doEvaluation(iter, new org.nd4j.evaluation.classification.Evaluation(), new org.nd4j.evaluation.classification.ROCMultiClass(), new org.nd4j.evaluation.regression.RegressionEvaluation());

        assertEquals(e1[0], e2[0]);

@@ -117,7 +117,7 @@ public class EvaluationToolsTests extends BaseDL4JTest {

        String str = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica"));
-        System.out.println(str);
+        // System.out.println(str);
    }
}

@@ -46,12 +46,6 @@ public class AttentionLayerTest extends BaseDL4JTest {
    @Rule
    public ExpectedException exceptionRule = ExpectedException.none();

-    private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;

    @Test
    public void testSelfAttentionLayer() {
        int nIn = 3;

@@ -104,8 +98,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
                MultiLayerNetwork net = new MultiLayerNetwork(conf);
                net.init();

-                boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 100);
+                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                        .labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
                assertTrue(name, gradOK);
            }
        }

@@ -165,8 +159,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
                MultiLayerNetwork net = new MultiLayerNetwork(conf);
                net.init();

-                boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 100);
+                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                        .labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
                assertTrue(name, gradOK);
            }
        }

@@ -226,8 +220,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
                String name = "testLearnedSelfAttentionLayer() - mb=" + mb + ", tsLength = " + tsLength + ", maskType=" + maskType + ", projectInput = " + projectInput;
                System.out.println("Starting test: " + name);

-                boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 100);
+                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                        .labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
                assertTrue(name, gradOK);
            }
        }

@@ -320,8 +314,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
            net.init();

            //System.out.println("Original");
-            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 100, null);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                    .labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
            assertTrue(name, gradOK);
        }
    }

@@ -383,8 +377,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
            ComputationGraph net = new ComputationGraph(graph);
            net.init();

-            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in}, new INDArray[]{labels}, inMask != null ? new INDArray[]{inMask} : null, null);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in})
+                    .labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null).subset(true).maxPerParam(100));
            assertTrue(name, gradOK);
        }
    }

@@ -445,9 +439,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
            ComputationGraph net = new ComputationGraph(graph);
            net.init();

-            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in},
-                    new INDArray[]{labels}, inMask != null ? new INDArray[]{inMask} : null, null);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in})
+                    .labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null));
            assertTrue(name, gradOK);
        }
    }
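
The gradient-check migrations in this file and the ones below all follow one pattern: the long positional checkGradients(...) overload is replaced by the MLNConfig / GraphConfig builder visible in the diffs. A fragment sketching the two styles side by side; net, in, labels and inMask are assumed to be defined as in the tests above, and the epsilon/tolerance literals are the per-class constants this commit deletes:

    // Old style: a dozen positional arguments (eps, maxRelError, minAbsError,
    // printResults, returnOnFirstFailure, ...), easy to misorder silently
    boolean gradOK = GradientCheckUtil.checkGradients(net, 1e-6, 1e-3, 1e-8,
            true, false, in, labels, inMask, null, true, 100);

    // New builder style used throughout this commit: every setting is named
    boolean gradOK2 = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig()
            .net(net).input(in).labels(labels).inputMask(inMask)
            .subset(true).maxPerParam(100));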

@@ -56,11 +56,6 @@ import static org.junit.Assert.assertTrue;
 *
 */
public class BNGradientCheckTest extends BaseDL4JTest {
-    private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-5;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-5;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-9;

    static {
        Nd4j.setDataType(DataType.DOUBLE);

@@ -93,17 +88,15 @@ public class BNGradientCheckTest extends BaseDL4JTest {
        MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
        mln.init();

-        if (PRINT_RESULTS) {
-            for (int j = 0; j < mln.getnLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
-        }
+        // for (int j = 0; j < mln.getnLayers(); j++)
+        //     System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());

        //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
        //i.e., runningMean = decay * runningMean + (1-decay) * batchMean
        //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
        Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
-        boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                .labels(labels).excludeParams(excludeParams));

        assertTrue(gradOK);
        TestUtils.testModelSerialization(mln);

@@ -140,17 +133,15 @@ public class BNGradientCheckTest extends BaseDL4JTest {
        MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
        mln.init();

-        if (PRINT_RESULTS) {
-            for (int j = 0; j < mln.getnLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
-        }
+        // for (int j = 0; j < mln.getnLayers(); j++)
+        //     System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());

        //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
        //i.e., runningMean = decay * runningMean + (1-decay) * batchMean
        //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
        Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
-        boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                .labels(labels).excludeParams(excludeParams));

        assertTrue(gradOK);
        TestUtils.testModelSerialization(mln);

@@ -220,7 +211,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                        String name = new Object() {
                        }.getClass().getEnclosingMethod().getName();

-                        System.out.println("Num params: " + mln.numParams());
+                        // System.out.println("Num params: " + mln.numParams());

                        if (doLearningFirst) {
                            //Run a number of iterations of learning

@@ -241,20 +232,18 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                            assertTrue(msg, scoreAfter < 0.9 * scoreBefore);
                        }

-                        if (PRINT_RESULTS) {
-                            System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
-                                    + ", outputActivation=" + outputActivation + ", doLearningFirst="
-                                    + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
-                            for (int k = 0; k < mln.getnLayers(); k++)
-                                System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
-                        }
+                        System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
+                                + ", outputActivation=" + outputActivation + ", doLearningFirst="
+                                + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
+                        // for (int k = 0; k < mln.getnLayers(); k++)
+                        //     System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());

                        //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
                        //i.e., runningMean = decay * runningMean + (1-decay) * batchMean
                        //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
                        Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
-                        boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 25, excludeParams); //Most params are in output layer, only these should be skipped with this threshold
+                        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                                .labels(labels).excludeParams(excludeParams).subset(true).maxPerParam(25)); //Most params are in output layer, only these should be skipped with this threshold

                        assertTrue(gradOK);
                        TestUtils.testModelSerialization(mln);

@@ -347,20 +336,18 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                            assertTrue(msg, scoreAfter < 0.8 * scoreBefore);
                        }

-                        if (PRINT_RESULTS) {
-                            System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
-                                    + ", outputActivation=" + outputActivation + ", doLearningFirst="
-                                    + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
-                            for (int k = 0; k < mln.getnLayers(); k++)
-                                System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());
-                        }
+                        System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
+                                + ", outputActivation=" + outputActivation + ", doLearningFirst="
+                                + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
+                        // for (int k = 0; k < mln.getnLayers(); k++)
+                        //     System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams());

                        //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
                        //i.e., runningMean = decay * runningMean + (1-decay) * batchMean
                        //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
                        Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
-                        boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams);
+                        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                                .labels(labels).excludeParams(excludeParams));

                        assertTrue(gradOK);
                        TestUtils.testModelSerialization(mln);

@@ -396,17 +383,15 @@ public class BNGradientCheckTest extends BaseDL4JTest {
        MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
        mln.init();

-        if (PRINT_RESULTS) {
-            for (int j = 0; j < mln.getnLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
-        }
+        // for (int j = 0; j < mln.getnLayers(); j++)
+        //     System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());

        //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
        //i.e., runningMean = decay * runningMean + (1-decay) * batchMean
        //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
        Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
-        boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                .labels(labels).excludeParams(excludeParams));

        assertTrue(gradOK);
        TestUtils.testModelSerialization(mln);

@@ -443,17 +428,15 @@ public class BNGradientCheckTest extends BaseDL4JTest {
        MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
        mln.init();

-        if (PRINT_RESULTS) {
-            for (int j = 0; j < mln.getnLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
-        }
+        // for (int j = 0; j < mln.getnLayers(); j++)
+        //     System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());

        //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
        //i.e., runningMean = decay * runningMean + (1-decay) * batchMean
        //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
        Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
-        boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                .labels(labels).excludeParams(excludeParams));

        assertTrue(gradOK);
        TestUtils.testModelSerialization(mln);

@@ -496,9 +479,8 @@ public class BNGradientCheckTest extends BaseDL4JTest {
        //i.e., runningMean = decay * runningMean + (1-decay) * batchMean
        //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
        Set<String> excludeParams = new HashSet<>(Arrays.asList("bn_mean", "bn_var"));
-        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{input},
-                new INDArray[]{labels}, null, null, excludeParams);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}).excludeParams(excludeParams));

        assertTrue(gradOK);
        TestUtils.testModelSerialization(net);

@@ -585,21 +567,18 @@ public class BNGradientCheckTest extends BaseDL4JTest {
                            assertTrue(msg, scoreAfter < 0.9 * scoreBefore);
                        }

-                        if (PRINT_RESULTS) {
-                            System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
-                                    + ", outputActivation=" + outputActivation + ", doLearningFirst="
-                                    + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
-                            for (int k = 0; k < net.getNumLayers(); k++)
-                                System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams());
-                        }
+                        System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf
+                                + ", outputActivation=" + outputActivation + ", doLearningFirst="
+                                + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]);
+                        // for (int k = 0; k < net.getNumLayers(); k++)
+                        //     System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams());

                        //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc
                        //i.e., runningMean = decay * runningMean + (1-decay) * batchMean
                        //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
                        Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "3_mean", "3_var", "1_log10stdev", "3_log10stdev"));
-                        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE,
-                                new INDArray[]{input}, new INDArray[]{labels}, null, null, excludeParams);
+                        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{input})
+                                .labels(new INDArray[]{labels}).excludeParams(excludeParams));

                        assertTrue(gradOK);
                        TestUtils.testModelSerialization(net);

@@ -108,8 +108,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                    if (PRINT_RESULTS) {
                        System.out.println(msg);
-                        for (int j = 0; j < net.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+                        // for (int j = 0; j < net.getnLayers(); j++)
+                        //     System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                    }

                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,

@@ -188,8 +188,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                    if (PRINT_RESULTS) {
                        System.out.println(msg);
-                        for (int j = 0; j < net.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+                        // for (int j = 0; j < net.getnLayers(); j++)
+                        //     System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                    }

                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,

@@ -272,8 +272,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                    if (PRINT_RESULTS) {
                        System.out.println(msg);
-                        for (int j = 0; j < net.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+                        // for (int j = 0; j < net.getnLayers(); j++)
+                        //     System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                    }

                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,

@@ -349,8 +349,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                    if (PRINT_RESULTS) {
                        System.out.println(msg);
-                        for (int j = 0; j < net.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+                        // for (int j = 0; j < net.getnLayers(); j++)
+                        //     System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                    }

                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,

@@ -414,8 +414,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                INDArray label = TestUtils.randomOneHot(2, finalNOut);

-                boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null);
+                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
+                        .labels(label).inputMask(fm));

                assertTrue(s, gradOK);
                TestUtils.testModelSerialization(net);

@@ -509,8 +509,8 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
                INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2);

-                boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null);
+                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
+                        .labels(label).inputMask(fm));

                assertTrue(s, gradOK);
                TestUtils.testModelSerialization(net);

@@ -144,14 +144,13 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
                            if (PRINT_RESULTS) {
                                log.info(msg);
-                                for (int j = 0; j < net.getnLayers(); j++) {
-                                    log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
-                                }
+                                // for (int j = 0; j < net.getnLayers(); j++) {
+                                //     log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
+                                // }
                            }

-                            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
-                                    DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS,
-                                    RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 128);
+                            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                                    .labels(labels).subset(true).maxPerParam(128));

                            assertTrue(msg, gradOK);

@@ -248,14 +247,13 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
                            if (PRINT_RESULTS) {
                                log.info(msg);
-                                for (int j = 0; j < net.getnLayers(); j++) {
-                                    log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
-                                }
+                                // for (int j = 0; j < net.getnLayers(); j++) {
+                                //     log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
+                                // }
                            }

-                            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
-                                    DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS,
-                                    RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 512);
+                            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                                    .labels(labels).subset(true).maxPerParam(512));

                            assertTrue(msg, gradOK);

@@ -431,9 +429,9 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
                        if (PRINT_RESULTS) {
                            log.info(msg);
-                            for (int j = 0; j < net.getnLayers(); j++) {
-                                log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
-                            }
+                            // for (int j = 0; j < net.getnLayers(); j++) {
+                            //     log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
+                            // }
                        }

                        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,

@@ -530,9 +528,9 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
                        if (PRINT_RESULTS) {
                            log.info(msg);
-                            for (int j = 0; j < net.getnLayers(); j++) {
-                                log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
-                            }
+                            // for (int j = 0; j < net.getnLayers(); j++) {
+                            //     log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
+                            // }
                        }

                        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,

@@ -547,4 +545,95 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
            }
        }
    }
@Test
public void testDeconv3d() {
Nd4j.getRandom().setSeed(12345);
// Note: we checked this with a variety of parameters, but it takes a lot of time.
int[] depths = {8, 8, 9};
int[] heights = {8, 9, 9};
int[] widths = {8, 8, 9};
int[][] kernels = {{2, 2, 2}, {3, 3, 3}, {2, 3, 2}};
int[][] strides = {{1, 1, 1}, {1, 1, 1}, {2, 2, 2}};
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.IDENTITY};
ConvolutionMode[] modes = {ConvolutionMode.Truncate, ConvolutionMode.Same, ConvolutionMode.Same};
int[] mbs = {1, 3, 2};
Convolution3D.DataFormat[] dataFormats = new Convolution3D.DataFormat[]{Convolution3D.DataFormat.NCDHW, Convolution3D.DataFormat.NDHWC, Convolution3D.DataFormat.NCDHW};
int convNIn = 2;
int finalNOut = 2;
int[] deconvOut = {2, 3, 4};
for (int i = 0; i < activations.length; i++) {
Activation afn = activations[i];
int miniBatchSize = mbs[i];
int depth = depths[i];
int height = heights[i];
int width = widths[i];
ConvolutionMode mode = modes[i];
int[] kernel = kernels[i];
int[] stride = strides[i];
Convolution3D.DataFormat df = dataFormats[i];
int dOut = deconvOut[i];
INDArray input;
if (df == Convolution3D.DataFormat.NDHWC) {
input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn});
} else {
input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width});
}
INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut);
for (int j = 0; j < miniBatchSize; j++) {
labels.putScalar(new int[]{j, j % finalNOut}, 1.0);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.updater(new NoOp())
.weightInit(new NormalDistribution(0, 0.1))
.list()
.layer(0, new Convolution3D.Builder().activation(afn).kernelSize(kernel)
.stride(stride).nIn(convNIn).nOut(dOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(1, new Deconvolution3D.Builder().activation(afn).kernelSize(kernel)
.stride(stride).nOut(dOut).hasBias(false)
.convolutionMode(mode).dataFormat(df)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(finalNOut).build())
.setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build();
String json = conf.toJson();
MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, c2);
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn
+ ", kernel = " + Arrays.toString(kernel) + ", stride = "
+ Arrays.toString(stride) + ", mode = " + mode.toString()
+ ", input depth " + depth + ", input height " + height
+ ", input width " + width;
if (PRINT_RESULTS) {
log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
.labels(labels).subset(true).maxPerParam(128));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
}
}
} }
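
For reference, the two data formats the new testDeconv3d exercises order the input dimensions differently; a small sketch (shape values illustrative, taken from the branch values in the test above):

    // NCDHW: [minibatch, channels, depth, height, width]
    INDArray ncdhw = Nd4j.rand(new int[]{2, 2, 8, 8, 8});
    // NDHWC: [minibatch, depth, height, width, channels]
    INDArray ndhwc = Nd4j.rand(new int[]{3, 8, 9, 8, 2});
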


@@ -122,8 +122,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                     if (PRINT_RESULTS) {
                         System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation="
                                 + outputActivation + ", doLearningFirst=" + doLearningFirst);
-                        for (int j = 0; j < mln.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                        for (int j = 0; j < mln.getnLayers(); j++)
+//                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                     }
 
                     boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -213,8 +213,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                         System.out.println(testName + "- activationFn=" + afn + ", lossFn=" + lf
                                 + ", outputActivation=" + outputActivation + ", doLearningFirst="
                                 + doLearningFirst);
-                        for (int j = 0; j < mln.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                        for (int j = 0; j < mln.getnLayers(); j++)
+//                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                     }
 
                     boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -275,8 +275,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                 if (PRINT_RESULTS) {
                     System.out.println(msg);
-                    for (int j = 0; j < net.getnLayers(); j++)
-                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+//                    for (int j = 0; j < net.getnLayers(); j++)
+//                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                 }
 
                 boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                         DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
@@ -336,8 +336,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                 if (PRINT_RESULTS) {
                     System.out.println(msg);
-                    for (int j = 0; j < net.getnLayers(); j++)
-                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+//                    for (int j = 0; j < net.getnLayers(); j++)
+//                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                 }
 
                 boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                         DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
@@ -346,8 +346,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                 //Also check compgraph:
                 ComputationGraph cg = net.toComputationGraph();
-                gradOK = GradientCheckUtil.checkGradients(cg, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{input}, new INDArray[]{labels});
+                gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(cg).inputs(new INDArray[]{input})
+                        .labels(new INDArray[]{labels}));
                 assertTrue(msg + " - compgraph", gradOK);
 
                 TestUtils.testModelSerialization(net);
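
The ComputationGraph variant of the migration uses GraphConfig instead of MLNConfig; inputs and labels are arrays because a graph may have several of each. A sketch, assuming the cg, input and labels variables of the test above:

    gradOK = GradientCheckUtil.checkGradients(
            new GradientCheckUtil.GraphConfig()
                    .net(cg)                              // ComputationGraph under test
                    .inputs(new INDArray[]{input})        // one entry per graph input
                    .labels(new INDArray[]{labels}));     // one entry per graph output
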
@@ -399,8 +399,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                 if (PRINT_RESULTS) {
                     System.out.println(msg);
-                    for (int j = 0; j < net.getnLayers(); j++)
-                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+//                    for (int j = 0; j < net.getnLayers(); j++)
+//                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                 }
 
                 boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -468,8 +468,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                 if (PRINT_RESULTS) {
                     System.out.println(msg);
-                    for (int j = 0; j < net.getnLayers(); j++)
-                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+//                    for (int j = 0; j < net.getnLayers(); j++)
+//                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                 }
 
                 boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -793,9 +793,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                             + convFirst;
                     System.out.println(msg);
 
-                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                            DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input,
-                            labels, null, null, true, 128);
+                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                            .labels(labels).subset(true).maxPerParam(128));
 
                     assertTrue(msg, gradOK);
@@ -863,8 +862,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                 if (PRINT_RESULTS) {
                     System.out.println(msg);
-                    for (int j = 0; j < net.getnLayers(); j++)
-                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+//                    for (int j = 0; j < net.getnLayers(); j++)
+//                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                 }
 
                 boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -937,8 +936,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                             + k + ", s=" + s + ", d=" + d + ", cm=" + cm;
                     System.out.println(msg);
 
-                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                            DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 100);
+                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                            .labels(labels).subset(true).maxPerParam(100));
 
                     assertTrue(msg, gradOK);
@@ -1009,8 +1008,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                             + k + ", s=" + s + ", d=" + d + ", cm=" + cm;
                     System.out.println(msg);
 
-                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                            DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 50); //Most params are in output layer
+                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                            .labels(labels).subset(true).maxPerParam(50)); //Most params are in output layer
 
                     assertTrue(msg, gradOK);
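
The subset(true).maxPerParam(n) pattern in these hunks limits the numeric check to at most n randomly selected entries per parameter tensor, trading exhaustiveness for runtime; the value is tuned per test, e.g. 50 above where most parameters sit in the output layer. A sketch of the knob in isolation (variables assumed from the surrounding test; without subset mode, every entry is presumably checked):

    // Spot-check at most 50 entries per parameter tensor instead of all of them:
    new GradientCheckUtil.MLNConfig().net(net).input(input).labels(labels)
            .subset(true).maxPerParam(50);
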
@@ -1160,12 +1159,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                 if (PRINT_RESULTS) {
                     System.out.println(msg);
-                    for (int j = 0; j < net.getnLayers(); j++)
-                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+//                    for (int j = 0; j < net.getnLayers(); j++)
+//                        System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
                 }
 
-                boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 160);
+                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                        .labels(labels).subset(true).maxPerParam(160));
 
                 assertTrue(msg, gradOK);
@@ -1235,8 +1234,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                             + k + ", nIn=" + nIn + ", depthMul=" + depthMultiplier + ", s=" + s + ", cm=" + cm;
                     System.out.println(msg);
 
-                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                            DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 256);
+                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                            .labels(labels).subset(true).maxPerParam(256));
 
                     assertTrue(msg, gradOK);


@@ -110,10 +110,8 @@ public class CapsnetGradientCheckTest extends BaseDL4JTest {
                             " capsules with " + capsuleDim + " dimensions and " + routing + " routings";
                     System.out.println(msg);
 
-                    boolean gradOK = GradientCheckUtil
-                            .checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input,
-                                    labels, null, null, true, 100);
+                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                            .labels(labels).subset(true).maxPerParam(100));
 
                     assertTrue(msg, gradOK);


@@ -36,6 +36,7 @@ import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.function.Consumer;
 import org.nd4j.linalg.learning.config.NoOp;
 import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
@@ -171,10 +172,15 @@ public class DropoutGradientCheck extends BaseDL4JTest {
             INDArray[] in = new INDArray[]{Nd4j.rand(mb, 5)};
             INDArray[] l = new INDArray[]{TestUtils.randomOneHot(mb, 5)};
 
-            boolean ok = GradientCheckUtil.checkGradients(cg, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, l, null, null, null, 12345);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(cg).inputs(in)
+                    .labels(l).callEachIter(new Consumer<ComputationGraph>() {
+                        @Override
+                        public void accept(ComputationGraph net) {
+                            Nd4j.getRandom().setSeed(12345);
+                        }
+                    }));
 
-            assertTrue(ok);
+            assertTrue(gradOK);
         }
     }
 }
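
The callEachIter hook here re-seeds the RNG before each gradient-check iteration, so the dropout masks are identical for the perturbed and unperturbed forward passes; without that, numeric and analytic gradients would diverge for stochastic layers. Since org.nd4j.linalg.function.Consumer has a single abstract method (accept, as seen above), the anonymous class can also be written as a lambda on a Java 8+ toolchain; a sketch with the same variables:

    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig()
            .net(cg).inputs(in).labels(l)
            // fixed seed => same dropout mask for every +eps / -eps evaluation
            .callEachIter(net -> Nd4j.getRandom().setSeed(12345)));
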


@@ -92,8 +92,8 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = "
                         + miniBatchSize);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
 
             boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -150,8 +150,8 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println(
                         "testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
 
             boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -209,12 +209,12 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, featuresMask, null);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                    .labels(labels).inputMask(featuresMask));
 
             assertTrue(gradOK);
             TestUtils.testModelSerialization(mln);
@@ -292,12 +292,12 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = "
                         + miniBatchSize);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, inputMask, null);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                    .labels(labels).inputMask(inputMask));
 
             assertTrue(gradOK);
             TestUtils.testModelSerialization(mln);
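
The inputMask(...) calls above carry the per-timestep masks through the config object. Such a mask marks real timesteps with 1 and padding with 0, as built elsewhere in this PR; a sketch with illustrative shapes:

    INDArray featuresMask = Nd4j.zeros(3, 5);   // [miniBatchSize, timeSeriesLength]
    featuresMask.putRow(0, Nd4j.create(new double[]{1, 1, 1, 0, 0}));  // 3 valid steps
    featuresMask.putRow(1, Nd4j.create(new double[]{1, 1, 1, 1, 0}));  // 4 valid steps
    featuresMask.putRow(2, Nd4j.create(new double[]{1, 1, 1, 1, 1}));  // full length
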


@@ -120,8 +120,8 @@ public class GradientCheckTests extends BaseDL4JTest {
                         System.out.println("testMinibatchApplication() - activationFn=" + afn + ", lossFn="
                                 + lf + ", outputActivation=" + outputActivation + ", doLearningFirst="
                                 + doLearningFirst);
-                        for (int j = 0; j < mln.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                        for (int j = 0; j < mln.getnLayers(); j++)
+//                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                     }
 
                     boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -200,8 +200,8 @@ public class GradientCheckTests extends BaseDL4JTest {
                         System.out.println("testGradientMLP2LayerIrisSimpleRandom() - activationFn=" + afn + ", lossFn="
                                 + lf + ", outputActivation=" + outputActivation + ", doLearningFirst="
                                 + doLearningFirst);
-                        for (int j = 0; j < mln.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                        for (int j = 0; j < mln.getnLayers(); j++)
+//                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                     }
 
                     boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -294,8 +294,8 @@ public class GradientCheckTests extends BaseDL4JTest {
                         System.out.println("testGradientMLP2LayerIrisSimpleRandom() - activationFn=" + afn
                                 + ", lossFn=" + lf + ", outputActivation=" + outputActivation
                                 + ", doLearningFirst=" + doLearningFirst + ", l2=" + l2 + ", l1=" + l1);
-                        for (int j = 0; j < mln.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                        for (int j = 0; j < mln.getnLayers(); j++)
+//                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                     }
 
                     boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -339,8 +339,8 @@ public class GradientCheckTests extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testEmbeddingLayerSimple");
-            for (int j = 0; j < mln.getnLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//            for (int j = 0; j < mln.getnLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
         }
 
         boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -379,8 +379,8 @@ public class GradientCheckTests extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testEmbeddingLayerSimple");
-            for (int j = 0; j < mln.getnLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//            for (int j = 0; j < mln.getnLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
         }
 
         boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -469,8 +469,8 @@ public class GradientCheckTests extends BaseDL4JTest {
                             + doLearningFirst + ", l2=" + l2 + ", l1=" + l1;
                     if (PRINT_RESULTS) {
                         System.out.println(msg);
-                        for (int j = 0; j < mln.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                        for (int j = 0; j < mln.getnLayers(); j++)
+//                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                     }
 
                     boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -539,8 +539,8 @@ public class GradientCheckTests extends BaseDL4JTest {
         // expectation in case linear regression(with only element wise multiplication layer): large weight for the fourth weight
         log.info("params after learning: " + netGraph.getLayer(1).paramTable());
 
-        boolean gradOK = checkGradients(netGraph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{features}, new INDArray[]{labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(netGraph).inputs(new INDArray[]{features})
+                .labels(new INDArray[]{labels}));
 
         msg = "elementWiseMultiplicationLayerTest() - activationFn=" + "ID" + ", lossFn=" + "Cos-sim"
                 + ", outputActivation=" + "Id" + ", doLearningFirst=" + "true";
@@ -592,8 +592,8 @@ public class GradientCheckTests extends BaseDL4JTest {
             }
 
             String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
 
-            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, label, fMask, null);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                    .labels(label).inputMask(fMask));
 
             assertTrue(msg, gradOK);
             TestUtils.testModelSerialization(net);
@@ -767,8 +767,8 @@ public class GradientCheckTests extends BaseDL4JTest {
                         System.out.println("testGradientMLP2LayerIrisSimpleRandom() - activationFn=" + afn + ", lossFn="
                                 + lf + ", outputActivation=" + outputActivation + ", doLearningFirst="
                                 + doLearningFirst + ", layerNorm=" + layerNorm);
-                        for (int j = 0; j < mln.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                        for (int j = 0; j < mln.getnLayers(); j++)
+//                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                     }
 
                     boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,


@@ -103,13 +103,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testBasicIris()");
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testBasicIris()";
         assertTrue(msg, gradOK);
@@ -155,13 +154,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testBasicIrisWithMerging()");
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testBasicIrisWithMerging()";
         assertTrue(msg, gradOK);
@@ -213,13 +211,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println("testBasicIrisWithElementWiseVertex(op=" + op + ")");
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                    new INDArray[] {labels});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                    .labels(new INDArray[]{labels}));
 
             String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")";
             assertTrue(msg, gradOK);
@@ -274,13 +271,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println("testBasicIrisWithElementWiseVertex(op=" + op + ")");
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                    new INDArray[] {labels});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                    .labels(new INDArray[]{labels}));
 
             String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")";
             assertTrue(msg, gradOK);
@@ -328,9 +324,8 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             graph.fit(new DataSet(in, labels));
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in},
-                    new INDArray[]{labels});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in})
+                    .labels(new INDArray[]{labels}));
 
             assertTrue(msg, gradOK);
             TestUtils.testModelSerialization(graph);
         }
@@ -372,13 +367,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testCnnDepthMerge()");
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testCnnDepthMerge()";
         assertTrue(msg, gradOK);
@@ -430,13 +424,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testLSTMWithMerging()");
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testLSTMWithMerging()";
         assertTrue(msg, gradOK);
@@ -466,13 +459,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testLSTMWithSubset()");
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testLSTMWithSubset()";
         assertTrue(msg, gradOK);
@@ -504,26 +496,24 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testLSTMWithLastTimeStepVertex()");
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
         //First: test with no input mask array
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testLSTMWithLastTimeStepVertex()";
         assertTrue(msg, gradOK);
 
         //Second: test with input mask arrays.
-        INDArray inMask = Nd4j.zeros(3, 5);
-        inMask.putRow(0, Nd4j.create(new double[] {1, 1, 1, 0, 0}));
-        inMask.putRow(1, Nd4j.create(new double[] {1, 1, 1, 1, 0}));
-        inMask.putRow(2, Nd4j.create(new double[] {1, 1, 1, 1, 1}));
-        graph.setLayerMaskArrays(new INDArray[] {inMask}, null);
-        gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
-                PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, new INDArray[] {labels});
+        INDArray inMask = Nd4j.zeros(3, 4);
+        inMask.putRow(0, Nd4j.create(new double[] {1, 1, 0, 0}));
+        inMask.putRow(1, Nd4j.create(new double[] {1, 1, 1, 0}));
+        inMask.putRow(2, Nd4j.create(new double[] {1, 1, 1, 1}));
+        gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}).inputMask(new INDArray[]{inMask}));
 
         assertTrue(msg, gradOK);
         TestUtils.testModelSerialization(graph);
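
Note the shape of the change in the hunk above: masking is no longer applied by mutating the graph (setLayerMaskArrays) before the check; the mask travels with the check configuration itself. A sketch, with variables as in the test above:

    gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig()
            .net(graph)
            .inputs(new INDArray[]{input})
            .labels(new INDArray[]{labels})
            .inputMask(new INDArray[]{inMask}));   // replaces graph.setLayerMaskArrays(...)
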
@@ -566,13 +556,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testLSTMWithDuplicateToTimeSeries()");
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input1, input2},
-                new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input1, input2})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testLSTMWithDuplicateToTimeSeries()";
         assertTrue(msg, gradOK);
@@ -615,13 +604,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testLSTMWithReverseTimeSeriesVertex()");
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testLSTMWithDuplicateToTimeSeries()";
         assertTrue(msg, gradOK);
@@ -632,8 +620,8 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         inMask.putRow(1, Nd4j.create(new double[] {1, 1, 0, 1, 0}));
         inMask.putRow(2, Nd4j.create(new double[] {1, 1, 1, 1, 1}));
         graph.setLayerMaskArrays(new INDArray[] {inMask}, null);
 
-        gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
-                PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, new INDArray[] {labels});
+        gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}));
 
         assertTrue(msg, gradOK);
         TestUtils.testModelSerialization(graph);
@@ -671,13 +659,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             String msg = "testMultipleInputsLayer() - minibatchSize = " + mb;
             if (PRINT_RESULTS) {
                 System.out.println(msg);
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, inputs,
-                    new INDArray[] {out});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(inputs)
+                    .labels(new INDArray[]{out}));
 
             assertTrue(msg, gradOK);
             TestUtils.testModelSerialization(graph);
@@ -712,13 +699,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             String msg = "testMultipleOutputsLayer() - minibatchSize = " + mb;
             if (PRINT_RESULTS) {
                 System.out.println(msg);
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                    new INDArray[] {out});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                    .labels(new INDArray[]{out}));
 
             assertTrue(msg, gradOK);
             TestUtils.testModelSerialization(graph);
@@ -759,12 +745,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             String msg = "testMultipleOutputsMergeVertex() - minibatchSize = " + mb;
             if (PRINT_RESULTS) {
                 System.out.println(msg);
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, new INDArray[] {out});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(input)
+                    .labels(new INDArray[]{out}));
 
             assertTrue(msg, gradOK);
             TestUtils.testModelSerialization(graph);
@@ -810,13 +796,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         String msg = "testMultipleOutputsMergeVertex() - minibatchSize = " + mb;
         if (PRINT_RESULTS) {
             System.out.println(msg);
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
-                new INDArray[] {out});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{out}));
 
         assertTrue(msg, gradOK);
         TestUtils.testModelSerialization(graph);
@@ -873,19 +858,18 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         Map<String, INDArray> out = graph.feedForward(new INDArray[] {pos, anc, neg}, true);
 
-        for (String s : out.keySet()) {
-            System.out.println(s + "\t" + Arrays.toString(out.get(s).shape()));
-        }
+//        for (String s : out.keySet()) {
+//            System.out.println(s + "\t" + Arrays.toString(out.get(s).shape()));
+//        }
 
         if (PRINT_RESULTS) {
             System.out.println("testBasicIrisTripletStackingL2Loss()");
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {pos, anc, neg},
-                new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{pos, anc, neg})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testBasicIrisTripletStackingL2Loss()";
         assertTrue(msg, gradOK);
@@ -941,13 +925,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             String msg = "testBasicCenterLoss() - lambda = " + lambda + ", trainFirst = " + train;
             if (PRINT_RESULTS) {
                 System.out.println(msg);
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {example},
-                    new INDArray[] {labels});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{example})
+                    .labels(new INDArray[]{labels}));
 
             assertTrue(msg, gradOK);
             TestUtils.testModelSerialization(graph);
@@ -1007,8 +990,8 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             String msg = "testBasicCenterLoss() - trainFirst = " + train;
             if (PRINT_RESULTS) {
                 System.out.println(msg);
-                for (int j = 0; j < net.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
+//                for (int j = 0; j < net.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
             }
 
             boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -1056,13 +1039,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println(testName);
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2},
-                    new INDArray[] {labels});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
+                    .labels(new INDArray[]{labels}));
 
             assertTrue(testName, gradOK);
             TestUtils.testModelSerialization(graph);
@@ -1115,13 +1097,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println(testName);
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2},
-                    new INDArray[] {labels1, labels2});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
+                    .labels(new INDArray[]{labels1, labels2}));
 
             assertTrue(testName, gradOK);
             TestUtils.testModelSerialization(graph);
@@ -1174,13 +1155,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println(testName);
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2},
-                    new INDArray[] {labels1, labels2});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
+                    .labels(new INDArray[]{labels1, labels2}));
 
             assertTrue(testName, gradOK);
             TestUtils.testModelSerialization(graph);
@@ -1238,15 +1218,14 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println(testName);
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
             graph.setLayerMaskArrays(new INDArray[] {inMask1, inMask2}, null);
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2},
-                    new INDArray[] {labels1, labels2}, new INDArray[] {inMask1, inMask2}, null);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
+                    .labels(new INDArray[]{labels1, labels2}).inputMask(new INDArray[]{inMask1, inMask2}));
 
             assertTrue(testName, gradOK);
             TestUtils.testModelSerialization(graph);
@@ -1298,13 +1277,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
             if (PRINT_RESULTS) {
                 System.out.println(testName);
-                for (int j = 0; j < graph.getNumLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
             }
 
-            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1, in2},
-                    new INDArray[] {labels1, labels2});
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2})
+                    .labels(new INDArray[]{labels1, labels2}));
 
             assertTrue(testName, gradOK);
             TestUtils.testModelSerialization(graph);
         }
@@ -1341,13 +1319,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println(testName);
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
        }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1},
-                new INDArray[] {labels1});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1})
+                .labels(new INDArray[]{labels1}));
 
         assertTrue(testName, gradOK);
         TestUtils.testModelSerialization(graph);
@@ -1391,13 +1368,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println(testName);
-            for (int j = 0; j < graph.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+//            for (int j = 0; j < graph.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {in1},
-                new INDArray[] {labels1});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1})
+                .labels(new INDArray[]{labels1}));
 
        assertTrue(testName, gradOK);
        TestUtils.testModelSerialization(graph);
@@ -1430,12 +1406,12 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
         if (PRINT_RESULTS) {
             System.out.println("testGraphEmbeddingLayerSimple");
-            for (int j = 0; j < cg.getNumLayers(); j++)
-                System.out.println("Layer " + j + " # params: " + cg.getLayer(j).numParams());
+//            for (int j = 0; j < cg.getNumLayers(); j++)
+//                System.out.println("Layer " + j + " # params: " + cg.getLayer(j).numParams());
         }
 
-        boolean gradOK = GradientCheckUtil.checkGradients(cg, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
-                PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input}, new INDArray[] {labels});
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(cg).inputs(new INDArray[]{input})
+                .labels(new INDArray[]{labels}));
 
         String msg = "testGraphEmbeddingLayerSimple";
         assertTrue(msg, gradOK);


@@ -51,10 +51,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
 public class GradientCheckTestsMasking extends BaseDL4JTest {
 
     private static final boolean PRINT_RESULTS = true;
-    private static final boolean RETURN_ON_FIRST_FAILURE = false;
-    private static final double DEFAULT_EPS = 1e-6;
-    private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
-    private static final double DEFAULT_MIN_ABS_ERROR = 1e-7;
 
     static {
         Nd4j.setDataType(DataType.DOUBLE);
@@ -130,8 +126,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
             MultiLayerNetwork mln = new MultiLayerNetwork(conf);
             mln.init();
 
-            boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, maskArr);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                    .labels(labels).labelMask(maskArr));
 
             String msg = "gradientCheckMaskingOutputSimple() - timeSeriesLength=" + timeSeriesLength
                     + ", miniBatchSize=" + 1;
@ -186,12 +182,12 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
if (PRINT_RESULTS) { if (PRINT_RESULTS) {
System.out.println("testBidirectionalLSTMMasking() - testNum = " + testNum++); System.out.println("testBidirectionalLSTMMasking() - testNum = " + testNum++);
for (int j = 0; j < mln.getnLayers(); j++) // for (int j = 0; j < mln.getnLayers(); j++)
System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); // System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
} }
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, mask, mask, true, 16); .labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(16));
assertTrue(gradOK); assertTrue(gradOK);
TestUtils.testModelSerialization(mln); TestUtils.testModelSerialization(mln);
@ -271,8 +267,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
System.out.println(msg); System.out.println(msg);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(features)
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, features, labels, null, labelMask); .labels(labels).labelMask(labelMask));
assertTrue(msg, gradOK); assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
@ -366,8 +362,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
System.out.println(msg); System.out.println(msg);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(features)
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, features, labels, null, labelMask); .labels(labels).labelMask(labelMask));
assertTrue(msg, gradOK); assertTrue(msg, gradOK);
@ -387,9 +383,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
ComputationGraph graph = new ComputationGraph(cg); ComputationGraph graph = new ComputationGraph(cg);
graph.init(); graph.init();
gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{features})
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, .labels(new INDArray[]{labels}).labelMask(new INDArray[]{labelMask}));
new INDArray[] {features}, new INDArray[] {labels}, null, new INDArray[]{labelMask}, null);
assertTrue(msg + " (compgraph)", gradOK); assertTrue(msg + " (compgraph)", gradOK);
TestUtils.testModelSerialization(graph); TestUtils.testModelSerialization(graph);
@ -425,8 +420,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
assertTrue(lm.sumNumber().intValue() > 0); assertTrue(lm.sumNumber().intValue() > 0);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, lm); .labels(l).labelMask(lm));
assertTrue(gradOK); assertTrue(gradOK);
//Also ensure score doesn't depend on masked feature or label values //Also ensure score doesn't depend on masked feature or label values
@ -478,9 +473,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
assertTrue(lm.sumNumber().intValue() > 0); assertTrue(lm.sumNumber().intValue() > 0);
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{f})
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{f}, new INDArray[]{l}, .labels(new INDArray[]{l}).labelMask(new INDArray[]{lm}));
null, new INDArray[]{lm});
assertTrue(gradOK); assertTrue(gradOK);
//Also ensure score doesn't depend on masked feature or label values //Also ensure score doesn't depend on masked feature or label values
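Every hunk in this file follows the same migration: the long positional `checkGradients(net, eps, maxRelError, minAbsError, print, returnOnFirstFailure, input, labels, inputMask, labelMask, ...)` overloads are replaced by fluent `MLNConfig`/`GraphConfig` objects. A minimal sketch of the new call style, using only the setters that actually appear in these hunks (the package locations and the fallback to default tolerances for unset fields are assumptions on my part):

```java
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

class GradientCheckSketch {

    // MultiLayerNetwork variant: fields left unset (masks, tolerances) presumably
    // fall back to the utility's defaults, which is why DEFAULT_EPS etc. disappear
    static boolean checkMln(MultiLayerNetwork mln, INDArray input, INDArray labels, INDArray labelMask) {
        return GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig()
                .net(mln)
                .input(input)
                .labels(labels)
                .labelMask(labelMask));
    }

    // ComputationGraph variant: inputs/labels/masks become arrays, one entry per graph input/output
    static boolean checkGraph(ComputationGraph graph, INDArray features, INDArray labels, INDArray labelMask) {
        return GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig()
                .net(graph)
                .inputs(new INDArray[]{features})
                .labels(new INDArray[]{labels})
                .labelMask(new INDArray[]{labelMask}));
    }
}
```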


@@ -82,10 +82,10 @@ public class LRNGradientCheckTests extends BaseDL4JTest {
             MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
             mln.init();
-            if (PRINT_RESULTS) {
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
-            }
+//            if (PRINT_RESULTS) {
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//            }
             boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                     DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);


@@ -124,8 +124,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
             String testName = "testLSTMBasic(" + (graves ? "GravesLSTM" : "LSTM") + ")";
             if (PRINT_RESULTS) {
                 System.out.println(testName);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
             boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -213,12 +213,12 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
                     + outputActivation + ", l2=" + l2 + ", l1=" + l1;
             if (PRINT_RESULTS) {
                 System.out.println(testName);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
-            boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 128);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                    .labels(labels).subset(true).maxPerParam(128));
             assertTrue(testName, gradOK);
             TestUtils.testModelSerialization(mln);
@@ -341,8 +341,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
                 System.out.println("testGradientGravesBidirectionalLSTMFull() - activationFn=" + afn
                         + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", l2=" + l2
                         + ", l1=" + l1);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
             boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -394,8 +394,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
             MultiLayerNetwork mln = new MultiLayerNetwork(conf);
             mln.init();
-            boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 128);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                    .labels(labels).subset(true).maxPerParam(128));
             String msg = "testGradientGravesLSTMEdgeCases() - timeSeriesLength=" + timeSeriesLength[i]
                     + ", miniBatchSize=" + miniBatchSize[i];
@@ -452,8 +452,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
                 System.out.println("layer " + i + "\t" + mln.getLayer(i).numParams());
             }
-            boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                    .labels(labels).subset(true).maxPerParam(32));
             assertTrue(gradOK);
             TestUtils.testModelSerialization(mln);
         }
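The larger LSTM checks above also switch on `.subset(true).maxPerParam(n)`, which replaces the old trailing `..., true, 128` arguments: instead of numerically checking every gradient entry, at most n entries per parameter array are verified, which is what keeps these big recurrent tests tractable. A hedged sketch (the helper name is mine; the exact sampling policy is the utility's, not shown here):

```java
// Hypothetical helper mirroring the calls above: bound the number of checked
// entries per parameter array so large LSTM gradient checks stay fast.
static boolean subsetCheck(MultiLayerNetwork mln, INDArray input, INDArray labels, int maxPerParam) {
    return GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig()
            .net(mln)
            .input(input)
            .labels(labels)
            .subset(true)               // enable per-parameter subsampling
            .maxPerParam(maxPerParam)); // e.g. 128 or 32, as in the hunks above
}
```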


@@ -206,21 +206,19 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 } else {
                     failed.add(testName);
                 }
-                System.out.println("\n\n");
                 TestUtils.testModelSerialization(net);
             }
         }
+        if(failed.size() > 0) {
             System.out.println("---- Passed ----");
             for (String s : passed) {
                 System.out.println(s);
             }
             System.out.println("---- Failed ----");
             for (String s : failed) {
                 System.out.println(s);
             }
+        }
         assertEquals("Tests failed", 0, failed.size());
@@ -376,7 +374,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 failed.add(testName);
             }
-            System.out.println("\n\n");
             TestUtils.testModelSerialization(net);
         }
     }
@@ -684,8 +681,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
             } else {
                 failed.add(testName);
             }
-            System.out.println("\n\n");
         }
     }
 }


@@ -136,13 +136,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
                 String testName = "testRnnLossLayer(lf=" + lf + ", maskType=" + mt + ", outputActivation = " + oa + ")";
                 if (PRINT_RESULTS) {
                     System.out.println(testName);
-                    for (int j = 0; j < mln.getnLayers(); j++)
-                        System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                    for (int j = 0; j < mln.getnLayers(); j++)
+//                        System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                 }
                 System.out.println("Starting test: " + testName);
-                boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, labelMask);
+                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                        .labels(labels).labelMask(labelMask));
                 assertTrue(testName, gradOK);
                 TestUtils.testModelSerialization(mln);
@@ -243,13 +243,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
                 String testName = "testCnnLossLayer(lf=" + lf + ", maskType=" + mt + ", outputActivation = " + oa + ")";
                 if (PRINT_RESULTS) {
                     System.out.println(testName);
-                    for (int j = 0; j < mln.getnLayers(); j++)
-                        System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                    for (int j = 0; j < mln.getnLayers(); j++)
+//                        System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                 }
                 System.out.println("Starting test: " + testName);
-                boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, labelMask);
+                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                        .labels(labels).labelMask(labelMask));
                 assertTrue(testName, gradOK);
                 TestUtils.testModelSerialization(mln);
@@ -392,13 +392,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
                     String testName = "testCnn3dLossLayer(dataFormat=" + dataFormat + ",lf=" + lf + ", maskType=" + mt + ", outputActivation = " + oa + ")";
                     if (PRINT_RESULTS) {
                         System.out.println(testName);
-                        for (int j = 0; j < mln.getnLayers(); j++)
-                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                        for (int j = 0; j < mln.getnLayers(); j++)
+//                            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
                     }
                     System.out.println("Starting test: " + testName);
-                    boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                            DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, labelMask);
+                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
+                            .labels(labels).labelMask(labelMask));
                     assertTrue(testName, gradOK);
                     TestUtils.testModelSerialization(mln);


@@ -127,8 +127,8 @@ public class RnnGradientChecks extends BaseDL4JTest {
         net.init();
-        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                .labels(labels).inputMask(inMask));
         assertTrue(gradOK);
@@ -207,8 +207,8 @@ public class RnnGradientChecks extends BaseDL4JTest {
         net.init();
-        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                .labels(labels).inputMask(inMask));
         assertTrue(gradOK);
         TestUtils.testModelSerialization(net);
     }
@@ -282,8 +282,8 @@ public class RnnGradientChecks extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 16);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                .labels(labels).inputMask(inMask).subset(true).maxPerParam(16));
         assertTrue(name, gradOK);
         TestUtils.testModelSerialization(net);
     }
@@ -346,8 +346,8 @@ public class RnnGradientChecks extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, inMask, null, true, 16);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                .labels(labels).inputMask(inMask).subset(true).maxPerParam(16));
         assertTrue(name, gradOK);
         TestUtils.testModelSerialization(net);
     }


@@ -182,9 +182,9 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
             MultiLayerNetwork net = new MultiLayerNetwork(conf);
             net.init();
-            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, label, inMask, null);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                    .minAbsoluteError(1e-7)
+                    .labels(label).inputMask(inMask));
             assertTrue(gradOK);
             TestUtils.testModelSerialization(net);
@@ -223,9 +223,8 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
         Set<String> excludeParams = new HashSet<>();
         excludeParams.addAll(Arrays.asList("1_W", "1_b", "2_W", "2_b"));
-        boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, null, null,
-                false, -1, excludeParams);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                .labels(labels).excludeParams(excludeParams));
         assertTrue(gradOK);
         TestUtils.testModelSerialization(net);
@@ -234,9 +233,8 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
         //Test ComputationGraph equivalent:
         ComputationGraph g = net.toComputationGraph();
-        boolean gradOKCG = GradientCheckUtil.checkGradients(g, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in}, new INDArray[]{labels},
-                null, null, excludeParams);
+        boolean gradOKCG = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(g).inputs(new INDArray[]{in})
+                .labels(new INDArray[]{labels}).excludeParams(excludeParams));
         assertTrue(gradOKCG);
         TestUtils.testModelSerialization(g);
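Two further setters show up here: `.minAbsoluteError(1e-7)` tightens the absolute-error floor for one particular check, and `.excludeParams(...)` skips named parameters entirely, replacing the old `false, -1, excludeParams` tail arguments. Presumably the excluded `"1_W"`, `"1_b"`, `"2_W"`, `"2_b"` belong to layers this test deliberately leaves untrained; a sketch of the pattern (helper name is mine, `java.util` imports assumed):

```java
// Sketch: exclude named parameters (e.g. those behind frozen layers) from the
// numerical check, so only parameters that should actually move are verified.
static boolean checkIgnoringFrozen(MultiLayerNetwork net, INDArray in, INDArray labels) {
    Set<String> excludeParams = new HashSet<>(Arrays.asList("1_W", "1_b", "2_W", "2_b"));
    return GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig()
            .net(net)
            .input(in)
            .labels(labels)
            .minAbsoluteError(1e-7)          // per-check override of the error floor
            .excludeParams(excludeParams));  // skipped entirely by the checker
}
```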


@@ -46,7 +46,7 @@ import static org.junit.Assert.assertTrue;
  */
 public class VaeGradientCheckTests extends BaseDL4JTest {
-    private static final boolean PRINT_RESULTS = true;
+    private static final boolean PRINT_RESULTS = false;
     private static final boolean RETURN_ON_FIRST_FAILURE = false;
     private static final double DEFAULT_EPS = 1e-6;
     private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
@@ -122,8 +122,8 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
                     + Arrays.toString(decoderSizes) + ", l2=" + l2 + ", l1=" + l1;
             if (PRINT_RESULTS) {
                 System.out.println(msg);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
             boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@@ -193,8 +193,8 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
                     + l1;
             if (PRINT_RESULTS) {
                 System.out.println(msg);
-                for (int l = 0; l < mln.getnLayers(); l++)
-                    System.out.println("Layer " + l + " # params: " + mln.getLayer(l).numParams());
+//                for (int l = 0; l < mln.getnLayers(); l++)
+//                    System.out.println("Layer " + l + " # params: " + mln.getLayer(l).numParams());
             }
             boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS,
@@ -281,8 +281,8 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
             String msg = "testVaePretrainReconstructionDistributions() - " + reconstructionDistributions[i];
             if (PRINT_RESULTS) {
                 System.out.println(msg);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
             boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS,
@@ -323,8 +323,8 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
             String msg = "testVaePretrainMultipleSamples() - numSamples = " + numSamples;
             if (PRINT_RESULTS) {
                 System.out.println(msg);
-                for (int j = 0; j < mln.getnLayers(); j++)
-                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
+//                for (int j = 0; j < mln.getnLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
             }
             boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS,


@@ -120,8 +120,8 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
             String msg = "testYoloOutputLayer() - minibatch = " + mb + ", w=" + w + ", h=" + h + ", l1=" + l1[i] + ", l2=" + l2[i];
             System.out.println(msg);
-            boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 100);
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
+                    .labels(labels).subset(true).maxPerParam(100));
             assertTrue(msg, gradOK);
             TestUtils.testModelSerialization(net);
@@ -228,8 +228,8 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
         INDArray f = ds.getFeatures();
         INDArray l = ds.getLabels();
-        boolean ok = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, null, true, 64);
+        boolean ok = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
+                .labels(l).inputMask(null).subset(true).maxPerParam(64));
         assertTrue(ok);
         TestUtils.testModelSerialization(net);


@@ -130,7 +130,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
                 .setOutputs("out").build();
         String json = conf.toJson();
-        System.out.println(json);
+//        System.out.println(json);
         ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
@@ -258,7 +258,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest {
                 .addVertex("test2", new StaticInnerGraphVertex(4, 5), "in").setOutputs("test", "test2").build();
         String json = conf.toJson();
-        System.out.println(json);
+//        System.out.println(json);
         ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);


@@ -54,7 +54,7 @@ public class CustomPreprocessorTest extends BaseDL4JTest {
         String json = conf.toJson();
         String yaml = conf.toYaml();
-        System.out.println(json);
+//        System.out.println(json);
         MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
         assertEquals(conf, confFromJson);


@@ -99,6 +99,11 @@ public class DTypeTests extends BaseDL4JTest {
             Convolution1D.class //Alias for Convolution1DLayer
     ));
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L;
+    }
+
     @AfterClass
     public static void after() {
         ImmutableSet<ClassPath.ClassInfo> info;
@@ -545,6 +550,7 @@ public class DTypeTests extends BaseDL4JTest {
                     .layer(new Convolution3D.Builder().kernelSize(2, 2, 2).stride(1, 1, 1).nOut(3).activation(Activation.TANH).build())
                     .layer(new Convolution3D.Builder().kernelSize(2, 2, 2).stride(1, 1, 1).nOut(3).activation(Activation.TANH).build())
                     .layer(new Subsampling3DLayer.Builder().poolingType(PoolingType.AVG).kernelSize(2, 2, 2).stride(2, 2, 2).build())
+                    .layer(new Deconvolution3D.Builder().kernelSize(2,2,2).stride(1,1,1).nIn(3).nOut(3).activation(Activation.TANH).build())
                     .layer(new Cropping3D.Builder(1, 1, 1, 1, 1, 1).build())
                     .layer(new ZeroPadding3DLayer.Builder(1, 1, 1, 1, 1, 1).build())
                     .layer(new ActivationLayer(Activation.LEAKYRELU))
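This is the first appearance of the `getTimeoutMilliseconds()` override (BatchNormalizationTest gets the same treatment further down), which suggests BaseDL4JTest now applies a per-class test timeout that slow suites can widen. A sketch under that assumption, with a hypothetical test class name:

```java
// Hypothetical slow test suite widening the assumed BaseDL4JTest per-class timeout
public class SomeSlowSuite extends BaseDL4JTest {

    @Override
    public long getTimeoutMilliseconds() {
        return 90000L; // 90s instead of the (assumed smaller) default
    }
}
```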


@@ -531,28 +531,38 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         ComputationGraph graph = new ComputationGraph(conf1);
         graph.init();
-        System.out.println(graph.summary());
-        System.out.println(graph.summary(InputType.feedForward(5)));
+//        System.out.println(graph.summary());
+//        System.out.println(graph.summary(InputType.feedForward(5)));
+        graph.summary();
+        graph.summary(InputType.feedForward(5));
         graph = new ComputationGraph(conf2);
         graph.init();
-        System.out.println(graph.summary());
-        System.out.println(graph.summary(InputType.recurrent(5)));
+//        System.out.println(graph.summary());
+//        System.out.println(graph.summary(InputType.recurrent(5)));
+        graph.summary();
+        graph.summary(InputType.recurrent(5));
         graph = new ComputationGraph(conf3);
         graph.init();
-        System.out.println(graph.summary());
-        System.out.println(graph.summary(InputType.convolutional(28, 28, 1)));
+//        System.out.println(graph.summary());
+//        System.out.println(graph.summary(InputType.convolutional(28, 28, 1)));
+        graph.summary();
+        graph.summary(InputType.convolutional(28, 28, 1));
         graph = new ComputationGraph(conf4);
         graph.init();
-        System.out.println(graph.summary());
-        System.out.println(graph.summary(InputType.convolutional(28, 28, 1), InputType.recurrent(5)));
+//        System.out.println(graph.summary());
+//        System.out.println(graph.summary(InputType.convolutional(28, 28, 1), InputType.recurrent(5)));
+        graph.summary();
+        graph.summary(InputType.convolutional(28, 28, 1), InputType.recurrent(5));
         graph = new ComputationGraph(conf5);
         graph.init();
-        System.out.println(graph.summary());
-        System.out.println(graph.summary(InputType.convolutional(28, 28, 1)));
+//        System.out.println(graph.summary());
+//        System.out.println(graph.summary(InputType.convolutional(28, 28, 1)));
+        graph.summary();
+        graph.summary(InputType.convolutional(28, 28, 1));
     }

     @Test
@@ -753,7 +763,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         int nOut = 3;
         for(WorkspaceMode ws : WorkspaceMode.values()) {
-            System.out.println("***** WORKSPACE: " + ws);
+//            System.out.println("***** WORKSPACE: " + ws);
             ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                     .updater(new Adam(0.01))
@@ -981,7 +991,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
                 OptimizationAlgorithm.LBFGS};
         for (OptimizationAlgorithm oa : oas) {
-            System.out.println(oa);
+//            System.out.println(oa);
             ComputationGraphConfiguration conf =
                     new NeuralNetConfiguration.Builder().optimizationAlgo(oa).graphBuilder()
                             .addInputs("input")
@@ -1065,12 +1075,15 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         ComputationGraph modelToTune = new ComputationGraph(conf);
         modelToTune.init();
-        System.out.println(modelToTune.summary());
+//        System.out.println(modelToTune.summary());
+        modelToTune.summary();
         ComputationGraph modelNow =
                 new TransferLearning.GraphBuilder(modelToTune).setFeatureExtractor("denseCentre2").build();
-        System.out.println(modelNow.summary());
-        System.out.println(modelNow.summary(InputType.feedForward(10),InputType.feedForward(2)));
+//        System.out.println(modelNow.summary());
+//        System.out.println(modelNow.summary(InputType.feedForward(10),InputType.feedForward(2)));
+        modelNow.summary();
+        modelNow.summary(InputType.feedForward(10),InputType.feedForward(2));
     }

     @Test
@@ -1315,9 +1328,12 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         ComputationGraph modelExpectedArch = new ComputationGraph(confForArchitecture);
         modelExpectedArch.init();
         ComputationGraph modelMow = new TransferLearning.GraphBuilder(modelExpectedArch).setFeatureExtractor("layer2").build();
-        System.out.println(modelExpectedArch.summary());
-        System.out.println(modelMow.summary());
-        System.out.println(modelExpectedArch.summary(InputType.recurrent(V_HEIGHT* V_WIDTH* 3)));
+//        System.out.println(modelExpectedArch.summary());
+//        System.out.println(modelMow.summary());
+//        System.out.println(modelExpectedArch.summary(InputType.recurrent(V_HEIGHT* V_WIDTH* 3)));
+        modelExpectedArch.summary();
+        modelMow.summary();
+        modelExpectedArch.summary(InputType.recurrent(V_HEIGHT* V_WIDTH* 3));
     }

     @Test
@@ -2117,8 +2133,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         INDArray features = Nd4j.rand(new int[] {dataSize, inputSize});
         INDArray labels = Nd4j.rand(new int[] {dataSize, outputSize});
-        boolean gradOK = GradientCheckUtil.checkGradients(net, 1e-6, 1e-3,
-                1e-8, false, true, new INDArray[]{features}, new INDArray[]{labels}, null, null);
+        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{features})
+                .labels(new INDArray[]{labels}));
         assertTrue(gradOK);
     }


@@ -53,7 +53,7 @@ public class TestCustomActivation extends BaseDL4JTest {
         String json = conf.toJson();
         String yaml = conf.toYaml();
-        System.out.println(json);
+//        System.out.println(json);
         MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
         assertEquals(conf, confFromJson);


@@ -64,7 +64,7 @@ public class TestCustomLayers extends BaseDL4JTest {
         String json = conf.toJson();
         String yaml = conf.toYaml();
-        System.out.println(json);
+//        System.out.println(json);
         MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
         assertEquals(conf, confFromJson);
@@ -88,7 +88,7 @@ public class TestCustomLayers extends BaseDL4JTest {
         String json = conf.toJson();
         String yaml = conf.toYaml();
-        System.out.println(json);
+//        System.out.println(json);
         ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
         assertEquals(conf, confFromJson);
@@ -135,7 +135,7 @@ public class TestCustomLayers extends BaseDL4JTest {
         String json = conf.toJson();
         String yaml = conf.toYaml();
-        System.out.println(json);
+//        System.out.println(json);
         MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
         assertEquals(conf, confFromJson);
@@ -188,7 +188,7 @@ public class TestCustomLayers extends BaseDL4JTest {
         String json = conf.toJson();
         String yaml = conf.toYaml();
-        System.out.println(json);
+//        System.out.println(json);
         ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
         assertEquals(conf, confFromJson);


@@ -35,6 +35,7 @@ import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
@@ -156,7 +157,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                 .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build())
                 .build();
         MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list()
-                .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).build())
+                .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build())
                 .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build())
                 .inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
                 .build();
@@ -204,7 +205,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                 .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build())
                 .build();
         MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list()
-                .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).build())
+                .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build())
                 .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build())
                 .build();
@@ -249,8 +250,8 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                 .build();
         MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH)
                 .weightInit(WeightInit.XAVIER).list()
-                .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).build()).layer(1,
-                        new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4)
+                .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build())
+                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4)
                         .activation(Activation.SOFTMAX).build())
                 .build();
@@ -309,7 +310,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                 .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build())
                 .build();
         MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list()
-                .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).build())
+                .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).activation(Activation.IDENTITY).build())
                 .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build())
                 .setInputType(InputType.recurrent(nClassesIn))
                 .build();
@@ -344,7 +345,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net.computeGradientAndScore();
         net2.computeGradientAndScore();
-        System.out.println(net.score() + "\t" + net2.score());
+//        System.out.println(net.score() + "\t" + net2.score());
         assertEquals(net2.score(), net.score(), 1e-6);
         Map<String, INDArray> gradient = net.gradient().gradientForVariable();
@@ -375,7 +376,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
                 .weightInit(WeightInit.XAVIER)
                 .dataType(DataType.DOUBLE)
                 .list()
-                .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).build())
+                .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build())
                 .layer(1, new GravesLSTM.Builder().nIn(5).nOut(7).activation(Activation.SOFTSIGN).build())
                 .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(7).nOut(4)
                         .activation(Activation.SOFTMAX).build())
@@ -416,7 +417,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net.computeGradientAndScore();
         net2.computeGradientAndScore();
-        System.out.println(net.score() + "\t" + net2.score());
+//        System.out.println(net.score() + "\t" + net2.score());
         assertEquals(net2.score(), net.score(), 1e-5);
         Map<String, INDArray> gradient = net.gradient().gradientForVariable();
@@ -513,7 +514,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         net.computeGradientAndScore();
         net2.computeGradientAndScore();
-        System.out.println(net.score() + "\t" + net2.score());
+//        System.out.println(net.score() + "\t" + net2.score());
         assertEquals(net2.score(), net.score(), 1e-5);
         Map<String, INDArray> gradients = net.gradient().gradientForVariable();
@@ -707,4 +708,21 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
             return true;
         }
     }
+
+    @Test
+    public void testEmbeddingDefaultActivation(){
+
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .list()
+                .layer(new EmbeddingLayer.Builder().nIn(10).nOut(10).build())
+                .layer(new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build())
+                .build();
+
+        EmbeddingLayer l = (EmbeddingLayer) conf.getConf(0).getLayer();
+        assertEquals(new ActivationIdentity(), l.getActivationFn());
+
+        EmbeddingSequenceLayer l2 = (EmbeddingSequenceLayer) conf.getConf(1).getLayer();
+        assertEquals(new ActivationIdentity(), l2.getActivationFn());
+    }
 }
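The thread running through this file: `testEmbeddingDefaultActivation` asserts that embedding layers default to identity activation, and every dense layer meant to mimic an embedding lookup now sets `.activation(Activation.IDENTITY)` explicitly. A sketch of the asymmetry this implies (hypothetical sizes; assumes the embedding layers' identity default also wins over a builder-level global activation):

```java
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .activation(Activation.TANH) // global default for layers that don't specify one
        .list()
        .layer(new EmbeddingLayer.Builder().nIn(10).nOut(5).build()) // stays identity
        .layer(new DenseLayer.Builder().nIn(5).nOut(5).build())     // picks up tanh
        .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(5).nOut(3).activation(Activation.SOFTMAX).build())
        .build();
```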


@@ -90,6 +90,11 @@ public class BatchNormalizationTest extends BaseDL4JTest {
     public void doBefore() {
     }

+    @Override
+    public long getTimeoutMilliseconds() {
+        return 90000L;
+    }
+
     @Test
     public void testDnnForwardPass() {
         int nOut = 10;
@@ -102,7 +107,7 @@ public class BatchNormalizationTest extends BaseDL4JTest {
         INDArray mean = output.mean(0);
         INDArray stdev = output.std(false, 0);
-        System.out.println(Arrays.toString(mean.data().asFloat()));
+//        System.out.println(Arrays.toString(mean.data().asFloat()));
         assertArrayEquals(new float[nOut], mean.data().asFloat(), 1e-6f);
         assertEquals(Nd4j.ones(nOut), stdev);
@@ -161,8 +166,8 @@ public class BatchNormalizationTest extends BaseDL4JTest {
         INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces());
-        System.out.println(Arrays.toString(outExpected.data().asDouble()));
-        System.out.println(Arrays.toString(out.data().asDouble()));
+//        System.out.println(Arrays.toString(outExpected.data().asDouble()));
+//        System.out.println(Arrays.toString(out.data().asDouble()));
         assertEquals(outExpected, out);
@@ -190,9 +195,9 @@ public class BatchNormalizationTest extends BaseDL4JTest {
         assertEquals(dldgammaExp, dldgamma);
         assertEquals(dldbetaExp, dldbeta);
-        System.out.println("EPSILONS");
-        System.out.println(Arrays.toString(dldinExp.data().asDouble()));
-        System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble()));
+//        System.out.println("EPSILONS");
+//        System.out.println(Arrays.toString(dldinExp.data().asDouble()));
+//        System.out.println(Arrays.toString(p.getSecond().dup().data().asDouble()));
         assertEquals(dldinExp, p.getSecond());
     }
@@ -303,8 +308,8 @@ public class BatchNormalizationTest extends BaseDL4JTest {
         INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces());
-        System.out.println(Arrays.toString(outExpected.data().asDouble()));
-        System.out.println(Arrays.toString(out.data().asDouble()));
+//        System.out.println(Arrays.toString(outExpected.data().asDouble()));
+//        System.out.println(Arrays.toString(out.data().asDouble()));
         assertEquals(outExpected, out);


@@ -140,7 +140,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
         y2impl.setLabels(labels);
         double score = y2impl.computeScore(0.0, true, LayerWorkspaceMgr.noWorkspaces());
-        System.out.println("SCORE: " + score);
+//        System.out.println("SCORE: " + score);
         assertTrue(score > 0.0);


@@ -220,20 +220,20 @@ public class GravesLSTMTest extends BaseDL4JTest {
         INDArray out1 = net.output(in1);
         INDArray out2 = net.output(in2);
-        System.out.println(Arrays.toString(net.output(in1).data().asFloat()));
-        System.out.println(Arrays.toString(net.output(in2).data().asFloat()));
+//        System.out.println(Arrays.toString(net.output(in1).data().asFloat()));
+//        System.out.println(Arrays.toString(net.output(in2).data().asFloat()));
         List<INDArray> activations1 = net.feedForward(in1);
         List<INDArray> activations2 = net.feedForward(in2);
-        for (int i = 0; i < 3; i++) {
-            System.out.println("-----\n" + i);
-            System.out.println(Arrays.toString(activations1.get(i).dup().data().asDouble()));
-            System.out.println(Arrays.toString(activations2.get(i).dup().data().asDouble()));
-
-            System.out.println(activations1.get(i));
-            System.out.println(activations2.get(i));
-        }
+//        for (int i = 0; i < 3; i++) {
+//            System.out.println("-----\n" + i);
+//            System.out.println(Arrays.toString(activations1.get(i).dup().data().asDouble()));
+//            System.out.println(Arrays.toString(activations2.get(i).dup().data().asDouble()));
+//
+//            System.out.println(activations1.get(i));
+//            System.out.println(activations2.get(i));
+//        }


@@ -306,8 +306,8 @@ public class TestSameDiffConv extends BaseDL4JTest {
                     INDArray l = TestUtils.randomOneHot(minibatch, nOut);
                     log.info("Starting: " + msg);
-                    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-                            DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, null, true, 50); //Most of weights are in output layer
+                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
+                            .labels(l).subset(true).maxPerParam(50));
                     assertTrue(msg, gradOK);


@@ -135,7 +135,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
                 assertEquals(gStd.gradient(), gSD.gradient());
-                System.out.println("========================================================================");
+//                System.out.println("========================================================================");
                 //Sanity check: different minibatch size
                 in = Nd4j.rand(2 * minibatch, nIn);


@@ -317,7 +317,7 @@ public class TestReconstructionDistributions extends BaseDL4JTest {
                     INDArray gradient = rd.gradient(x, distributionParams);
                     String testName = "minibatch = " + minibatch + ", size = " + inputSize + ", Distribution = " + rd;
-                    System.out.println("\n\n***** Starting test: " + testName + "*****");
+                    System.out.println("***** Starting test: " + testName + "*****");
                     int totalFailureCount = 0;
                     for (int i = 0; i < distributionParams.size(1); i++) {
@@ -349,7 +349,7 @@ public class TestReconstructionDistributions extends BaseDL4JTest {
                                 totalFailureCount++;
                             }
                         } else {
-                            log.info("Input (" + j + "," + i + ") passed: grad= " + backpropGrad + ", numericalGrad= "
+                            log.trace("Input (" + j + "," + i + ") passed: grad= " + backpropGrad + ", numericalGrad= "
                                     + numericalGrad + ", relError= " + relError);
                         }
                     }


@@ -472,7 +472,7 @@ public class WorkspaceTests extends BaseDL4JTest {
         final ComputationGraph computationGraph = new ComputationGraph(config);
         computationGraph.init();
-        computationGraph.setListeners(new ScoreIterationListener(1));
+        computationGraph.setListeners(new ScoreIterationListener(3));
         WSTestDataSetIterator iterator = new WSTestDataSetIterator();
         computationGraph.fit(iterator);


@@ -66,7 +66,7 @@ public class BackPropMLPTest extends BaseDL4JTest {
     public void testMLP() {
         //Simple mini-batch test with multiple hidden layers
         MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 4, 3}, Activation.SIGMOID);
-        System.out.println(conf);
+//        System.out.println(conf);
         MultiLayerNetwork network = new MultiLayerNetwork(conf);
         network.init();
         DataSetIterator iter = new IrisDataSetIterator(10, 100);
@@ -80,7 +80,7 @@ public class BackPropMLPTest extends BaseDL4JTest {
     public void testMLP2() {
         //Simple mini-batch test with multiple hidden layers
         MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 15, 3}, Activation.TANH);
-        System.out.println(conf);
+//        System.out.println(conf);
         MultiLayerNetwork network = new MultiLayerNetwork(conf);
         network.init();
@@ -104,7 +104,7 @@ public class BackPropMLPTest extends BaseDL4JTest {
         Layer[] layers = network.getLayers();
-        final boolean printCalculations = true;
+        final boolean printCalculations = false;
         while (iris.hasNext()) {
             DataSet data = iris.next();
@@ -212,7 +212,7 @@ public class BackPropMLPTest extends BaseDL4JTest {
         assertEquals(l1BiasFloatAfter,expectedL1BiasAfter,eps);
         assertArrayEquals(l2BiasFloatAfter,expectedL2BiasAfter,eps);
         */
-        System.out.println("\n\n--------------");
+//        System.out.println("\n\n--------------");
     }
 }


@@ -922,9 +922,9 @@ public class MultiLayerTest extends BaseDL4JTest {
         MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(confForArchitecture);
         modelExpectedArch.init();
         MultiLayerNetwork modelMow = new TransferLearning.Builder(modelExpectedArch).setFeatureExtractor(2).build();
-        System.out.println(modelExpectedArch.summary());
-        System.out.println(modelMow.summary());
-        System.out.println(modelMow.summary(InputType.recurrent(V_HEIGHT*V_WIDTH*3)));
+//        System.out.println(modelExpectedArch.summary());
+//        System.out.println(modelMow.summary());
+//        System.out.println(modelMow.summary(InputType.recurrent(V_HEIGHT*V_WIDTH*3)));
     }

     @Test(expected = DL4JException.class)
@@ -1149,7 +1149,7 @@ public class MultiLayerTest extends BaseDL4JTest {
         int nOut = 3;
         for(WorkspaceMode ws : WorkspaceMode.values()) {
-            System.out.println("***** WORKSPACE: " + ws);
+//            System.out.println("***** WORKSPACE: " + ws);
             MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                     .updater(new Adam(0.01))


@@ -570,8 +570,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
             for (int j = 0; j < expOut.size(); j++) {
                 INDArray exp = expOut.get(j);
                 INDArray act = outSlice.get(j);
-                System.out.println(j);
-                System.out.println(exp.sub(act));
+//                System.out.println(j);
+//                System.out.println(exp.sub(act));
                 assertEquals(exp, act);
             }


@@ -219,10 +219,10 @@ public class TestVariableLengthTS extends BaseDL4JTest {
             INDArray g1s = g1map.get(s);
             INDArray g2s = g2map.get(s);
-            System.out.println("-------");
-            System.out.println("Variable: " + s);
-            System.out.println(Arrays.toString(g1s.dup().data().asFloat()));
-            System.out.println(Arrays.toString(g2s.dup().data().asFloat()));
+//            System.out.println("-------");
+//            System.out.println("Variable: " + s);
+//            System.out.println(Arrays.toString(g1s.dup().data().asFloat()));
+//            System.out.println(Arrays.toString(g2s.dup().data().asFloat()));
             assertNotEquals(s, g1s, g2s);
         }
@@ -507,7 +507,7 @@ public class TestVariableLengthTS extends BaseDL4JTest {
         for (boolean bidirectional : isBidirectional) {
             for (PoolingType pt : poolingTypes) {
-                System.out.println("Starting test: bidirectional = " + bidirectional + ", poolingType = " + pt);
+//                System.out.println("Starting test: bidirectional = " + bidirectional + ", poolingType = " + pt);
                 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
                         .activation(Activation.TANH).list().layer(0, bidirectional


@@ -51,7 +51,6 @@ public class TestFrozenLayers extends BaseDL4JTest {
         for(double l1 : new double[]{0.0, 0.3}){
             for( double l2 : new double[]{0.0, 0.4}){
-                System.out.println("--------------------");
                 String msg = "l1=" + l1 + ", l2=" + l2;

                 FineTuneConfiguration ftc = new FineTuneConfiguration.Builder()


@@ -273,8 +273,9 @@ public class TransferLearningComplex extends BaseDL4JTest {
             MultiDataSet rand = new MultiDataSet(new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 2)},
                     new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 3)});
             modelNow.fit(rand);
-            log.info(modelNow.summary());
-            log.info(modelNow.summary(InputType.feedForward(2),InputType.feedForward(2)));
+//            log.info(modelNow.summary());
+//            log.info(modelNow.summary(InputType.feedForward(2),InputType.feedForward(2)));
+            modelNow.summary();
+            modelNow.summary(InputType.feedForward(2),InputType.feedForward(2));
         }
     }


@@ -195,9 +195,10 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
         assertEquals(modelIdentical.getLayer("denseLeft0").params(), modelToTune.getLayer("denseLeft0").params());
         assertEquals(modelIdentical.getLayer("outLeft").params(), modelToTune.getLayer("outLeft").params());
-        log.info(modelIdentical.summary());
-        log.info(helper.unfrozenGraph().summary());
+//        log.info(modelIdentical.summary());
+//        log.info(helper.unfrozenGraph().summary());
+        modelIdentical.summary();
+        helper.unfrozenGraph().summary();
     }

     @Test


@@ -84,8 +84,8 @@ public class TestDataSetConsumer {
             count.incrementAndGet();

-            if (count.get() % 100 == 0)
-                logger.info("Passed {} datasets...", count.get());
+//            if (count.get() % 100 == 0)
+//                logger.info("Passed {} datasets...", count.get());

         return count.get();
     }


@@ -186,7 +186,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer));
         network.init();
-        TrainingListener listener = new ScoreIterationListener(1);
+        TrainingListener listener = new ScoreIterationListener(10);
         network.setListeners(Collections.singletonList(listener));
         double oldScore = network.score(data);
         for( int i=0; i<100; i++ ) {
@@ -204,7 +204,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         data.normalizeZeroMeanZeroUnitVariance();
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
         network.init();
-        TrainingListener listener = new ScoreIterationListener(1);
+        TrainingListener listener = new ScoreIterationListener(10);
         network.setListeners(Collections.singletonList(listener));
         double firstScore = network.score(data);
@@ -223,7 +223,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest {
         data.normalizeZeroMeanZeroUnitVariance();
         MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer));
         network.init();
-        TrainingListener listener = new ScoreIterationListener(1);
+        TrainingListener listener = new ScoreIterationListener(10);
         network.setListeners(Collections.singletonList(listener));
         double oldScore = network.score(data);


@@ -66,7 +66,7 @@ import static org.junit.Assert.assertTrue;
 public class TestOptimizers extends BaseDL4JTest {
 //For debugging.
-private static final boolean PRINT_OPT_RESULTS = true;
+private static final boolean PRINT_OPT_RESULTS = false;
 @Test
 public void testOptimizersBasicMLPBackprop() {


@@ -1,79 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.optimizer.listener;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ParamAndGradientIterationListener;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import static org.junit.Assert.assertEquals;
public class TestParamAndGradientIterationListener extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void test() throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(30, 150);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(1e-5))
.list().layer(0, new DenseLayer.Builder().nIn(4).nOut(20).build())
.layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(30).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
File f = testDir.newFile("paramAndGradTest.txt");
TrainingListener listener = ParamAndGradientIterationListener.builder().outputToFile(true)
.file(f)
.outputToConsole(true).outputToLogger(false).iterations(2).printHeader(true).printMean(false)
.printMinMax(false).printMeanAbsValue(true).delimiter("\t").build();
net.setListeners(listener);
for (int i = 0; i < 2; i++) {
net.fit(iter);
}
}
}


@@ -91,7 +91,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
 .useAdaGrad(false).build();
 b.fit(data);
-log.info("Result: {}", b.getData());
+// log.info("Result: {}", b.getData());
 val exp = Nd4j.createFromArray(new double[]{-3.5318212819287327, 35.40331834897696, 3.890809489531651, -1.291195609955519, -42.854099388207466, 7.8761368019456635, 28.798057251442877, 7.1456564000935225, 2.9518396278984786, -42.860181054199636, -34.989343304202, -108.99770355680282, 31.78123839126566, -29.322118879730205, 163.87558311206212, 2.9538984612478396, 31.419519824305546, 13.105400907817279, 25.46987139120746, -43.27317406736858, 32.455151773056144, 25.28067703547214, 0.005442008567682552, 21.005029233370358, -61.71390311950051, 5.218417653362599, 47.15762099517554, 8.834739256343404, 17.845790108867153, -54.31654219224107, -18.71285871476804, -16.446982180909007, -71.22568781913213, -12.339975548387091, 70.49096598213703, 25.022454385237456, -14.572652938207126, -5.320080866729078, 1.5874449933639676, -40.60960510287835, -31.98564381157643, -95.40875746933808, 19.196346639002364, -38.80930682421929, 135.00454225923906, 5.277879540549592, 30.79963767087089, -0.007276462027131683, 31.278796123365815, -38.47381680049993, 10.415728497075905, 36.567265019013085, -7.406587944733211, -18.376174615781114, -45.26976962854271}).reshape(-1, 5);
@@ -178,7 +178,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
 INDArray data = iter.next().getFeatures();
 INDArray perplexityOutput = b.computeGaussianPerplexity(data, 30.0);
-System.out.println(perplexityOutput);
+// System.out.println(perplexityOutput);
 }
 @Test
@@ -217,17 +217,17 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
 StopWatch watch = new StopWatch();
 watch.start();
 b.fit(data);
-System.out.println(b.getData());
+// System.out.println(b.getData());
 watch.stop();
 File outDir = testDir.newFolder();
 ClassPathResource labels = new ClassPathResource("mnist2500_labels.txt");
 List<String> labelsList = IOUtils.readLines(labels.getInputStream());
 b.saveAsFile(/*labelsList,*/ new File(outDir, "raw.txt").getAbsolutePath());
-System.out.println(b.getData());
+// System.out.println(b.getData());
 System.out.println("Fit done in " + watch);
 assertEquals(2500, b.getData().size(0));
-System.out.println(b.getData());
+// System.out.println(b.getData());
 INDArray a1 = b.getData().getRow(0);
 INDArray a2 = b.getData().getRow(1);
@@ -338,7 +338,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
 double[] dC = {-0.0618386320333619, -0.06266654959379839, 0.029998268806149204, 0.10780566335888186, -0.19449543068355346, -0.14763764361792697, 0.17493572758118422, 0.1926109839221966, -0.15176648259935419, 0.10974665709698186, 0.13102419155322598, 0.004941641352409449, 0.19159764518354974, -0.26332838053474944, -0.023631441261541583, 0.09838669432305949, 0.09709129638394683, -0.01605053000727605, 0.06566171635025217, -0.17325078066035252, -0.1090854255505605, 0.023350644966904276, 0.075192354899586, -0.08278373866517603, 0.18431338134579323, 0.2766031655578053, -0.17557907233268688, 0.10616148241800637, -0.09999024423215641, -0.017181932145255287, 0.06711331400576945, -0.01388231800826619, -0.10248189290485302, 0.20786521034824304, 0.11254913977572988, -0.289564646781519, 0.13491805919337516, -0.07504249344962562, 0.004154656287570634, -0.10516715438388784, -0.27984655075804576, 0.09811828071286613, 0.03684521473995052, -0.054645216532387256, -0.18147132772800725, 0.027588750493223044, 0.214734364419479, -0.026729138234415008, -0.28410504978879136, 0.007015481601883835, 0.04427981739424874, -0.059253265830134655, -0.05325479031206952, -0.11319889109674944, 0.1530133971867549};
 INDArray actual = gradient.getGradientFor("yIncs");
-System.out.println(actual);
+// System.out.println(actual);
 assertArrayEquals(dC, actual.reshape(1,55).toDoubleVector(), 1e-05);
 }
@@ -482,8 +482,8 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
 List<DataPoint> results = new ArrayList<>();
 List<Double> distances = new ArrayList<>();
 tree.search(target, 11, results, distances);
-System.out.println("Results:" + results);
-System.out.println("Distances:" + distances);
+// System.out.println("Results:" + results);
+// System.out.println("Distances:" + distances);
 }
 }


@@ -250,7 +250,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
 sd.evaluate(iter, "softmax", rEvalSd);
 assertEquals(rEvalDl4j, rEvalSd);
-System.out.println("---------------------------------");
+// System.out.println("---------------------------------");
 }
 }
 }


@@ -47,6 +47,11 @@ import static org.junit.Assert.*;
 public class CrashReportingUtilTest extends BaseDL4JTest {
+@Override
+public long getTimeoutMilliseconds() {
+return 120000;
+}
 @Rule
 public TemporaryFolder testDir = new TemporaryFolder();
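Note: getTimeoutMilliseconds() is the per-class timeout hook on the consolidated BaseDL4JTest. A minimal sketch of overriding it in another test class follows; the class name, timeout value, and test body are illustrative and not part of this change set:

    import org.deeplearning4j.BaseDL4JTest;
    import org.junit.Test;

    // Hypothetical test class, shown only to illustrate the hook above.
    public class MySlowIntegrationTest extends BaseDL4JTest {

        // Same hook as CrashReportingUtilTest: raises the limit for every test
        // method in this class only; other classes keep the default timeout.
        @Override
        public long getTimeoutMilliseconds() {
            return 180000L; // 3 minutes
        }

        @Test
        public void longRunningTest() {
            // slow test body here
        }
    }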


@@ -51,7 +51,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("MultiLayerNetwork", vr0.getFormatType());
 assertEquals(MultiLayerNetwork.class, vr0.getFormatClass());
 assertNull(vr0.getException());
-System.out.println(vr0.toString());
+// System.out.println(vr0.toString());
 //Test empty file
 File f1 = new File(f, "empty.bin");
@@ -63,7 +63,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("MultiLayerNetwork", vr1.getFormatType());
 assertEquals(MultiLayerNetwork.class, vr1.getFormatClass());
 assertNull(vr1.getException());
-System.out.println(vr1.toString());
+// System.out.println(vr1.toString());
 //Test invalid zip file
 File f2 = new File(f, "notReallyZip.zip");
@@ -75,7 +75,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("MultiLayerNetwork", vr2.getFormatType());
 assertEquals(MultiLayerNetwork.class, vr2.getFormatClass());
 assertNotNull(vr2.getException());
-System.out.println(vr2.toString());
+// System.out.println(vr2.toString());
 //Test valid zip, but missing configuration
 File f3 = new File(f, "modelNoConfig.zip");
@@ -92,7 +92,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("MultiLayerNetwork", vr3.getFormatType());
 assertEquals(MultiLayerNetwork.class, vr3.getFormatClass());
 assertNull(vr3.getException());
-System.out.println(vr3.toString());
+// System.out.println(vr3.toString());
 //Test valid sip, but missing params
@@ -110,7 +110,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("MultiLayerNetwork", vr4.getFormatType());
 assertEquals(MultiLayerNetwork.class, vr4.getFormatClass());
 assertNull(vr4.getException());
-System.out.println(vr4.toString());
+// System.out.println(vr4.toString());
 //Test valid model
@@ -122,7 +122,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("MultiLayerNetwork", vr5.getFormatType());
 assertEquals(MultiLayerNetwork.class, vr5.getFormatClass());
 assertNull(vr5.getException());
-System.out.println(vr5.toString());
+// System.out.println(vr5.toString());
 //Test valid model with corrupted JSON
@@ -141,7 +141,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 bytes = IOUtils.toByteArray(zis);
 }
 zo.write(bytes);
-System.out.println("WROTE: " + ze.getName());
+// System.out.println("WROTE: " + ze.getName());
 }
 }
 }
@@ -153,7 +153,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("MultiLayerNetwork", vr6.getFormatType());
 assertEquals(MultiLayerNetwork.class, vr6.getFormatClass());
 assertNotNull(vr6.getException());
-System.out.println(vr6.toString());
+// System.out.println(vr6.toString());
 }
@@ -169,7 +169,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("ComputationGraph", vr0.getFormatType());
 assertEquals(ComputationGraph.class, vr0.getFormatClass());
 assertNull(vr0.getException());
-System.out.println(vr0.toString());
+// System.out.println(vr0.toString());
 //Test empty file
 File f1 = new File(f, "empty.bin");
@@ -181,7 +181,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("ComputationGraph", vr1.getFormatType());
 assertEquals(ComputationGraph.class, vr1.getFormatClass());
 assertNull(vr1.getException());
-System.out.println(vr1.toString());
+// System.out.println(vr1.toString());
 //Test invalid zip file
 File f2 = new File(f, "notReallyZip.zip");
@@ -193,7 +193,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("ComputationGraph", vr2.getFormatType());
 assertEquals(ComputationGraph.class, vr2.getFormatClass());
 assertNotNull(vr2.getException());
-System.out.println(vr2.toString());
+// System.out.println(vr2.toString());
 //Test valid zip, but missing configuration
 File f3 = new File(f, "modelNoConfig.zip");
@@ -210,7 +210,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("ComputationGraph", vr3.getFormatType());
 assertEquals(ComputationGraph.class, vr3.getFormatClass());
 assertNull(vr3.getException());
-System.out.println(vr3.toString());
+// System.out.println(vr3.toString());
 //Test valid sip, but missing params
@@ -228,7 +228,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("ComputationGraph", vr4.getFormatType());
 assertEquals(ComputationGraph.class, vr4.getFormatClass());
 assertNull(vr4.getException());
-System.out.println(vr4.toString());
+// System.out.println(vr4.toString());
 //Test valid model
@@ -240,7 +240,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("ComputationGraph", vr5.getFormatType());
 assertEquals(ComputationGraph.class, vr5.getFormatClass());
 assertNull(vr5.getException());
-System.out.println(vr5.toString());
+// System.out.println(vr5.toString());
 //Test valid model with corrupted JSON
@@ -259,7 +259,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 bytes = IOUtils.toByteArray(zis);
 }
 zo.write(bytes);
-System.out.println("WROTE: " + ze.getName());
+// System.out.println("WROTE: " + ze.getName());
 }
 }
 }
@@ -271,7 +271,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
 assertEquals("ComputationGraph", vr6.getFormatType());
 assertEquals(ComputationGraph.class, vr6.getFormatClass());
 assertNotNull(vr6.getException());
-System.out.println(vr6.toString());
+// System.out.println(vr6.toString());
 }


@@ -83,6 +83,12 @@
 <artifactId>junit</artifactId>
 <scope>test</scope>
 </dependency>
+<dependency>
+<groupId>org.deeplearning4j</groupId>
+<artifactId>deeplearning4j-common-tests</artifactId>
+<version>${project.version}</version>
+<scope>test</scope>
+</dependency>
 </dependencies>
 <build>


@ -1,141 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import java.lang.management.ManagementFactory;
import java.util.List;
import java.util.Map;
import java.util.Properties;
@Slf4j
public class BaseDL4JTest {
@Rule
public TestName name = new TestName();
protected long startTime;
protected int threadCountBefore;
/**
* Override this to set the profiling mode for the tests defined in the child class
*/
public OpExecutioner.ProfilingMode getProfilingMode(){
return OpExecutioner.ProfilingMode.SCOPE_PANIC;
}
/**
* Override this to set the datatype of the tests defined in the child class
*/
public DataType getDataType(){
return DataType.DOUBLE;
}
public DataType getDefaultFPDataType(){
return getDataType();
}
@Before
public void beforeTest(){
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
}
@After
public void afterTest(){
//Attempt to keep workspaces isolated between tests
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
if(currWS != null){
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
System.exit(1);
}
StringBuilder sb = new StringBuilder();
long maxPhys = Pointer.maxPhysicalBytes();
long maxBytes = Pointer.maxBytes();
long currPhys = Pointer.physicalBytes();
long currBytes = Pointer.totalBytes();
long jvmTotal = Runtime.getRuntime().totalMemory();
long jvmMax = Runtime.getRuntime().maxMemory();
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
long duration = System.currentTimeMillis() - startTime;
sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName())
.append(": ").append(duration).append(" ms")
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
.append(", jvmTotal=").append(jvmTotal)
.append(", jvmMax=").append(jvmMax)
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
if(ws != null && ws.size() > 0){
long currSize = 0;
for(MemoryWorkspace w : ws){
currSize += w.getCurrentSize();
}
if(currSize > 0){
sb.append(", threadWSSize=").append(currSize)
.append(" (").append(ws.size()).append(" WSs)");
}
}
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
Object o = p.get("cuda.devicesInformation");
if(o instanceof List){
List<Map<String,Object>> l = (List<Map<String, Object>>) o;
if(l.size() > 0) {
sb.append(" [").append(l.size())
.append(" GPUs: ");
for (int i = 0; i < l.size(); i++) {
Map<String,Object> m = l.get(i);
if(i > 0)
sb.append(",");
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
.append(m.get("cuda.totalMemory")).append(" total)");
}
sb.append("]");
}
}
log.info(sb.toString());
}
}


@@ -735,9 +735,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
 + convFirst;
 System.out.println(msg);
-boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input,
-labels, null, null, true, 128);
+boolean gradOK = GradientCheckUtil.checkGradients(
+new GradientCheckUtil.MLNConfig().net(net)
+.input(input).labels(labels)
+.subset(true).maxPerParam(128));
 assertTrue(msg, gradOK);
@@ -879,8 +880,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
 + k + ", s=" + s + ", d=" + d + ", cm=" + cm;
 System.out.println(msg);
-boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 100);
+boolean gradOK = GradientCheckUtil.checkGradients(
+new GradientCheckUtil.MLNConfig().net(net)
+.input(input).labels(labels)
+.subset(true).maxPerParam(100));
 assertTrue(msg, gradOK);
@@ -948,8 +951,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
 + k + ", nIn=" + nIn + ", depthMul=" + depthMultiplier + ", s=" + s + ", cm=" + cm;
 System.out.println(msg);
-boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 256);
+boolean gradOK = GradientCheckUtil.checkGradients(
+new GradientCheckUtil.MLNConfig().net(net)
+.input(input).labels(labels)
+.subset(true).maxPerParam(256));
 assertTrue(msg, gradOK);
@@ -1021,8 +1026,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
 + k + ", s=" + s + ", d=" + d + ", cm=" + cm;
 System.out.println(msg);
-boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 50); //Most params are in output layer
+boolean gradOK = GradientCheckUtil.checkGradients(
+new GradientCheckUtil.MLNConfig().net(net)
+.input(input).labels(labels)
+.subset(true).maxPerParam(50));
 assertTrue(msg, gradOK);
@@ -1176,8 +1183,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
 System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
 }
-boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
-DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 160);
+boolean gradOK = GradientCheckUtil.checkGradients(
+new GradientCheckUtil.MLNConfig().net(net)
+.input(input).labels(labels)
+.subset(true).maxPerParam(160));
 assertTrue(msg, gradOK);
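Note: the hunks above replace the long positional-argument overload of GradientCheckUtil.checkGradients(...) with a configuration-object form. A minimal self-contained sketch of the new style follows, assuming that MLNConfig's unset options (epsilon, tolerances, print settings) default to values equivalent to the old DEFAULT_* constants; the tiny network and data here are illustrative:

    import org.deeplearning4j.gradientcheck.GradientCheckUtil;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class GradientCheckConfigSketch {
        public static void main(String[] args) {
            // Gradient checks need double precision (cf. BaseDL4JTest.getDataType()).
            Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);

            MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .list()
                    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
                    .build());
            net.init();

            INDArray input = Nd4j.rand(10, 4);
            INDArray labels = Nd4j.zeros(10, 3);
            for (int i = 0; i < 10; i++) {
                labels.putScalar(i, i % 3, 1.0); // one-hot labels
            }

            // Configuration-object form used in the diff above; subset(true) with
            // maxPerParam limits how many parameters are numerically checked.
            boolean gradOK = GradientCheckUtil.checkGradients(
                    new GradientCheckUtil.MLNConfig().net(net)
                            .input(input).labels(labels)
                            .subset(true).maxPerParam(32));
            System.out.println("Gradient check passed: " + gradOK);
        }
    }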


@@ -51,6 +51,13 @@
 <scope>test</scope>
 </dependency>
+<dependency>
+<groupId>org.deeplearning4j</groupId>
+<artifactId>deeplearning4j-common-tests</artifactId>
+<version>${project.version}</version>
+<scope>test</scope>
+</dependency>
 </dependencies>
 <profiles>


@@ -1,140 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.graph;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import java.lang.management.ManagementFactory;
import java.util.List;
import java.util.Map;
import java.util.Properties;
@Slf4j
public class BaseDL4JTest {
@Rule
public TestName name = new TestName();
protected long startTime;
protected int threadCountBefore;
/**
* Override this to set the profiling mode for the tests defined in the child class
*/
public OpExecutioner.ProfilingMode getProfilingMode(){
return OpExecutioner.ProfilingMode.SCOPE_PANIC;
}
/**
* Override this to set the datatype of the tests defined in the child class
*/
public DataType getDataType(){
return DataType.DOUBLE;
}
public DataType getDefaultFPDataType(){
return getDataType();
}
@Before
public void beforeTest(){
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
}
@After
public void afterTest(){
//Attempt to keep workspaces isolated between tests
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
if(currWS != null){
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
System.exit(1);
}
StringBuilder sb = new StringBuilder();
long maxPhys = Pointer.maxPhysicalBytes();
long maxBytes = Pointer.maxBytes();
long currPhys = Pointer.physicalBytes();
long currBytes = Pointer.totalBytes();
long jvmTotal = Runtime.getRuntime().totalMemory();
long jvmMax = Runtime.getRuntime().maxMemory();
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
long duration = System.currentTimeMillis() - startTime;
sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName())
.append(": ").append(duration).append(" ms")
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
.append(", jvmTotal=").append(jvmTotal)
.append(", jvmMax=").append(jvmMax)
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
if(ws != null && ws.size() > 0){
long currSize = 0;
for(MemoryWorkspace w : ws){
currSize += w.getCurrentSize();
}
if(currSize > 0){
sb.append(", threadWSSize=").append(currSize)
.append(" (").append(ws.size()).append(" WSs)");
}
}
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
Object o = p.get("cuda.devicesInformation");
if(o instanceof List){
List<Map<String,Object>> l = (List<Map<String, Object>>) o;
if(l.size() > 0) {
sb.append(" [").append(l.size())
.append(" GPUs: ");
for (int i = 0; i < l.size(); i++) {
Map<String,Object> m = l.get(i);
if(i > 0)
sb.append(",");
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
.append(m.get("cuda.totalMemory")).append(" total)");
}
sb.append("]");
}
}
log.info(sb.toString());
}
}


@@ -17,7 +17,7 @@
 package org.deeplearning4j.graph.data;
 import org.apache.commons.lang3.ArrayUtils;
-import org.deeplearning4j.graph.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.graph.api.Edge;
 import org.deeplearning4j.graph.api.IGraph;
 import org.deeplearning4j.graph.data.impl.DelimitedEdgeLineProcessor;


@@ -17,7 +17,7 @@
 package org.deeplearning4j.graph.data;
 import org.apache.commons.lang3.ArrayUtils;
-import org.deeplearning4j.graph.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.graph.api.Edge;
 import org.deeplearning4j.graph.api.IGraph;
 import org.deeplearning4j.graph.data.impl.WeightedEdgeLineProcessor;


@@ -17,7 +17,7 @@
 package org.deeplearning4j.graph.graph;
 import org.apache.commons.lang3.ArrayUtils;
-import org.deeplearning4j.graph.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.graph.api.*;
 import org.deeplearning4j.graph.data.GraphLoader;
 import org.deeplearning4j.graph.iterator.RandomWalkIterator;


@@ -16,7 +16,7 @@
 package org.deeplearning4j.graph.models.deepwalk;
-import org.deeplearning4j.graph.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.graph.data.GraphLoader;
 import org.deeplearning4j.graph.graph.Graph;
 import org.deeplearning4j.graph.iterator.GraphWalkIterator;


@@ -17,7 +17,7 @@
 package org.deeplearning4j.graph.models.deepwalk;
 import org.apache.commons.io.FilenameUtils;
-import org.deeplearning4j.graph.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.graph.api.Edge;
 import org.deeplearning4j.graph.api.IGraph;
 import org.deeplearning4j.graph.data.GraphLoader;


@@ -16,7 +16,7 @@
 package org.deeplearning4j.graph.models.deepwalk;
-import org.deeplearning4j.graph.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.junit.Test;
 import java.util.Arrays;


@@ -55,6 +55,13 @@
 <artifactId>nd4j-api</artifactId>
 <version>${nd4j.version}</version>
 </dependency>
+<dependency>
+<groupId>org.deeplearning4j</groupId>
+<artifactId>deeplearning4j-common-tests</artifactId>
+<version>${project.version}</version>
+<scope>test</scope>
+</dependency>
 </dependencies>
 <profiles>


@@ -17,6 +17,7 @@
 package org.deeplearning4j.plot;
 import lombok.val;
+import org.deeplearning4j.BaseDL4JTest;
 import org.junit.Test;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
@@ -25,7 +26,7 @@ import java.util.ArrayList;
 import static org.junit.Assert.assertTrue;
-public class Test6058 {
+public class Test6058 extends BaseDL4JTest {
 @Test
 public void test() throws Exception {


@@ -86,6 +86,12 @@
 <artifactId>junit</artifactId> <!-- Version set by deeplearning4j-parent dependency management -->
 <scope>test</scope>
 </dependency>
+<dependency>
+<groupId>org.deeplearning4j</groupId>
+<artifactId>deeplearning4j-common-tests</artifactId>
+<version>${project.version}</version>
+<scope>test</scope>
+</dependency>
 <dependency>
 <groupId>ch.qos.logback</groupId>


@@ -108,6 +108,9 @@ public class KerasLayerConfiguration {
 private final String LAYER_CLASS_NAME_LEAKY_RELU = "LeakyReLU";
 private final String LAYER_CLASS_NAME_PRELU = "PReLU";
 private final String LAYER_CLASS_NAME_THRESHOLDED_RELU = "ThresholdedReLU";
+private final String LAYER_CLASS_NAME_RELU = "ReLU";
+private final String LAYER_CLASS_NAME_ELU = "ELU";
+private final String LAYER_CLASS_NAME_SOFTMAX = "Softmax";
 private final String LAYER_CLASS_NAME_UPSAMPLING_1D = "UpSampling1D";
 private final String LAYER_CLASS_NAME_UPSAMPLING_2D = "UpSampling2D";
 private final String LAYER_CLASS_NAME_UPSAMPLING_3D = "UpSampling3D";


@@ -0,0 +1,95 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import java.util.Map;
/**
* Imports ELU layer from Keras
*
* @author Alex Black
*/
public class KerasELU extends KerasLayer {
/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Unsupported Invalid Keras config
*/
public KerasELU(Map<String, Object> layerConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
this(layerConfig, true);
}
/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @param enforceTrainingConfig whether to enforce training-related configuration options
* @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Invalid Keras config
*/
public KerasELU(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig);
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
double alpha = 1.0; // Set default alpha to default in nd4j
String layerFieldLeakyReluAlpha = "alpha";
if (innerConfig.containsKey(layerFieldLeakyReluAlpha)) {
alpha = (double) innerConfig.get(layerFieldLeakyReluAlpha);
}
IActivation leakyReLU = new ActivationELU(alpha);
this.layer = new ActivationLayer.Builder().name(this.layerName).activation(leakyReLU).build();
}
/**
* Get layer output type.
*
* @param inputType Array of InputTypes
* @return output type as InputType
* @throws InvalidKerasConfigurationException Invalid Keras config
*/
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras Activation layer accepts only one input (received " + inputType.length + ")");
return this.getActivationLayer().getOutputType(-1, inputType[0]);
}
/**
* Get DL4J ActivationLayer.
*
* @return ActivationLayer
*/
public ActivationLayer getActivationLayer() {
return (ActivationLayer) this.layer;
}
}
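Note: for a Keras ELU layer with alpha=0.7, the KerasELU constructor above produces the DL4J layer sketched below; when the Keras config carries no "alpha" entry, the importer falls back to 1.0. The layer name and alpha value here are illustrative:

    import org.deeplearning4j.nn.conf.layers.ActivationLayer;
    import org.nd4j.linalg.activations.impl.ActivationELU;

    public class KerasEluMappingSketch {
        public static void main(String[] args) {
            // The DL4J equivalent that KerasELU builds for ELU(alpha=0.7).
            ActivationLayer imported = new ActivationLayer.Builder()
                    .name("elu_1")
                    .activation(new ActivationELU(0.7))
                    .build();
            System.out.println(imported);
        }
    }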


@@ -0,0 +1,99 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import java.util.Map;
/**
* Imports ReLU layer from Keras
*
* @author Alex Black
*/
public class KerasReLU extends KerasLayer {
/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Unsupported Invalid Keras config
*/
public KerasReLU(Map<String, Object> layerConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
this(layerConfig, true);
}
/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @param enforceTrainingConfig whether to enforce training-related configuration options
* @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Invalid Keras config
*/
public KerasReLU(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig);
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
Double maxValue = (Double) innerConfig.get("max_value");
double negativeSlope = 0.0;
double threshold = 0.0;
if (innerConfig.containsKey("negative_slope")) {
negativeSlope = (double) innerConfig.get("negative_slope");
}
if (innerConfig.containsKey("threshold")) {
threshold = (double) innerConfig.get("threshold");
}
this.layer = new ActivationLayer.Builder().name(this.layerName)
.activation(new ActivationReLU(maxValue, threshold, negativeSlope)).build();
}
/**
* Get layer output type.
*
* @param inputType Array of InputTypes
* @return output type as InputType
* @throws InvalidKerasConfigurationException Invalid Keras config
*/
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras Activation layer accepts only one input (received " + inputType.length + ")");
return this.getActivationLayer().getOutputType(-1, inputType[0]);
}
/**
* Get DL4J ActivationLayer.
*
* @return ActivationLayer
*/
public ActivationLayer getActivationLayer() {
return (ActivationLayer) this.layer;
}
}


@@ -0,0 +1,85 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import java.util.Map;
/**
* Imports Softmax layer from Keras
*
* @author Alex Black
*/
public class KerasSoftmax extends KerasLayer {
/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Unsupported Invalid Keras config
*/
public KerasSoftmax(Map<String, Object> layerConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
this(layerConfig, true);
}
/**
* Constructor from parsed Keras layer configuration dictionary.
*
* @param layerConfig dictionary containing Keras layer configuration
* @param enforceTrainingConfig whether to enforce training-related configuration options
* @throws InvalidKerasConfigurationException Invalid Keras config
* @throws UnsupportedKerasConfigurationException Invalid Keras config
*/
public KerasSoftmax(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
super(layerConfig, enforceTrainingConfig);
this.layer = new ActivationLayer.Builder().name(this.layerName).activation(new ActivationSoftmax()).build();
}
/**
* Get layer output type.
*
* @param inputType Array of InputTypes
* @return output type as InputType
* @throws InvalidKerasConfigurationException Invalid Keras config
*/
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
if (inputType.length > 1)
throw new InvalidKerasConfigurationException(
"Keras Activation layer accepts only one input (received " + inputType.length + ")");
return this.getActivationLayer().getOutputType(-1, inputType[0]);
}
/**
* Get DL4J ActivationLayer.
*
* @return ActivationLayer
*/
public ActivationLayer getActivationLayer() {
return (ActivationLayer) this.layer;
}
}


@@ -25,9 +25,7 @@ import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
-import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasLeakyReLU;
-import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasPReLU;
-import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasThresholdedReLU;
+import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
 import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
 import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
 import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
@@ -313,6 +311,12 @@ public class KerasLayerUtils {
 if (lambdaLayer != null){
 layer = new KerasLambda(layerConfig, enforceTrainingConfig, lambdaLayer);
 }
+} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_RELU())){
+layer = new KerasReLU(layerConfig, enforceTrainingConfig);
+} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_ELU())){
+layer = new KerasELU(layerConfig, enforceTrainingConfig);
+} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
+layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
 }
 if (layer == null){
 Class<? extends KerasLayer> customConfig = customLayers.get(layerClassName);
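Note: with the dispatch above in place, Keras 2.x models using the ReLU, ELU, and Softmax advanced-activation layers import like any other supported layer. A minimal sketch follows; the model path is hypothetical and assumes a Sequential model saved from Keras via model.save(...):

    import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

    public class ImportAdvancedActivationsSketch {
        public static void main(String[] args) throws Exception {
            // Hypothetical HDF5 file containing e.g. Dense -> ReLU(max_value=6.0) -> Softmax;
            // before this change these layer class names failed import as unsupported.
            String modelPath = "model_with_relu6.h5";
            MultiLayerNetwork net = KerasModelImport.importKerasSequentialModelAndWeights(modelPath);
            System.out.println(net.summary());
        }
    }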


@@ -1,140 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import java.lang.management.ManagementFactory;
import java.util.List;
import java.util.Map;
import java.util.Properties;
@Slf4j
public class BaseDL4JTest {
@Rule
public TestName name = new TestName();
protected long startTime;
protected int threadCountBefore;
/**
* Override this to set the profiling mode for the tests defined in the child class
*/
public OpExecutioner.ProfilingMode getProfilingMode(){
return OpExecutioner.ProfilingMode.SCOPE_PANIC;
}
/**
* Override this to set the datatype of the tests defined in the child class
*/
public DataType getDataType(){
return DataType.DOUBLE;
}
public DataType getDefaultFPDataType(){
return getDataType();
}
@Before
public void beforeTest(){
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
}
@After
public void afterTest(){
//Attempt to keep workspaces isolated between tests
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
if(currWS != null){
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
System.exit(1);
}
StringBuilder sb = new StringBuilder();
long maxPhys = Pointer.maxPhysicalBytes();
long maxBytes = Pointer.maxBytes();
long currPhys = Pointer.physicalBytes();
long currBytes = Pointer.totalBytes();
long jvmTotal = Runtime.getRuntime().totalMemory();
long jvmMax = Runtime.getRuntime().maxMemory();
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
long duration = System.currentTimeMillis() - startTime;
sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName())
.append(": ").append(duration).append(" ms")
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
.append(", jvmTotal=").append(jvmTotal)
.append(", jvmMax=").append(jvmMax)
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
if(ws != null && ws.size() > 0){
long currSize = 0;
for(MemoryWorkspace w : ws){
currSize += w.getCurrentSize();
}
if(currSize > 0){
sb.append(", threadWSSize=").append(currSize)
.append(" (").append(ws.size()).append(" WSs)");
}
}
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
Object o = p.get("cuda.devicesInformation");
if(o instanceof List){
List<Map<String,Object>> l = (List<Map<String, Object>>) o;
if(l.size() > 0) {
sb.append(" [").append(l.size())
.append(" GPUs: ");
for (int i = 0; i < l.size(); i++) {
Map<String,Object> m = l.get(i);
if(i > 0)
sb.append(",");
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
.append(m.get("cuda.totalMemory")).append(" total)");
}
sb.append("]");
}
}
log.info(sb.toString());
}
}


@@ -17,6 +17,7 @@
 package org.deeplearning4j.nn.modelimport.keras;
 import org.apache.commons.io.FileUtils;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.utils.DL4JKerasModelValidator;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.junit.Rule;

@@ -23,7 +23,7 @@ import org.datavec.api.split.NumberedFileInputSplit;
 import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
 import org.deeplearning4j.nn.layers.recurrent.LSTM;
 import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasModel;
 import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;

@@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;

@@ -20,7 +20,7 @@ import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasModel;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.junit.Test;

@@ -21,7 +21,7 @@ import lombok.val;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasModel;
 import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;

@@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations;
 import org.deeplearning4j.nn.conf.distribution.*;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;

@@ -17,7 +17,7 @@
 package org.deeplearning4j.nn.modelimport.keras.configurations;

 import lombok.extern.slf4j.Slf4j;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;

@@ -20,7 +20,7 @@ import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.deeplearning4j.common.resources.DL4JResources;
 import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
 import org.deeplearning4j.nn.modelimport.keras.layers.custom.KerasLRN;

@@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.e2e;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasModel;
 import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;

@@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
 import org.deeplearning4j.nn.conf.layers.LossLayer;
 import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
 import org.deeplearning4j.nn.modelimport.keras.KerasModel;
 import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
@@ -724,6 +724,29 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         }
     }

+    @Test
+    public void testActivationLayers() throws Exception {
+        String[] names = new String[]{
+                "ELU_0_model.h5",
+                "LeakyReLU_0_model.h5",
+                "ReLU_0_model.h5",
+                "ReLU_1_model.h5",
+                "ReLU_2_model.h5",
+                "ReLU_3_model.h5",
+                "Softmax_0_model.h5",
+                "ThresholdReLU_0_model.h5",
+        };
+        for(String name : names){
+            System.out.println("Starting test: " + name);
+            String modelPath = "modelimport/keras/examples/activations/" + name;
+            String inputsOutputPath = "modelimport/keras/examples/activations/" + (name.substring(0, name.length() - "model.h5".length()) + "inputs_and_outputs.h5");
+            importEndModelTest(modelPath, inputsOutputPath, true, true,
+                    true, true, false, null, null);
+        }
+    }
+
     private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception {
         return importFunctionalModelH5Test(modelPath, null, false);
     }
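
A note on the fixture naming the new test above relies on: each "<X>_model.h5" under modelimport/keras/examples/activations/ is paired with "<X>_inputs_and_outputs.h5", derived by the substring logic in the loop. A small worked example in Java (the value is the first array entry):

String name = "ELU_0_model.h5";
String ioName = name.substring(0, name.length() - "model.h5".length()) + "inputs_and_outputs.h5";
//ioName is now "ELU_0_inputs_and_outputs.h5"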
@@ -991,8 +1014,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         }
         Nd4j.setDataType(DataType.DOUBLE);
-        boolean passed = GradientCheckUtil.checkGradients(netToTest, eps, max_rel_error, min_abs_error, true, false,
-                input, labels, null, null, true, 9);
+        boolean passed = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(netToTest).input(input)
+                .labels(labels).subset(true).maxPerParam(9));
         assertTrue("Gradient check failed", passed);
     }
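
The hunk above swaps the long positional checkGradients overload for the configuration-object style. Expanded for readability, the new call reads as below; the reading of subset/maxPerParam is inferred from the old positional arguments (true, 9), i.e. gradient-check at most 9 values per parameter rather than every parameter:

boolean passed = GradientCheckUtil.checkGradients(
        new GradientCheckUtil.MLNConfig()
                .net(netToTest)      //MultiLayerNetwork under test
                .input(input)        //input features
                .labels(labels)      //labels for the loss
                .subset(true)        //check only a subset of parameters...
                .maxPerParam(9));    //...at most 9 checks per parameter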

@@ -19,7 +19,7 @@ package org.deeplearning4j.nn.modelimport.keras.e2e;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
 import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;

@@ -18,7 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.e2e;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasModel;
 import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth;

@@ -17,7 +17,7 @@
 package org.deeplearning4j.nn.modelimport.keras.layers.advanced.activation;

 import org.deeplearning4j.nn.conf.layers.ActivationLayer;
-import org.deeplearning4j.nn.modelimport.keras.BaseDL4JTest;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;

Some files were not shown because too many files have changed in this diff.