diff --git a/deeplearning4j/deeplearning4j-common-tests/pom.xml b/deeplearning4j/deeplearning4j-common-tests/pom.xml
index 825c55ca5..23e030df3 100644
--- a/deeplearning4j/deeplearning4j-common-tests/pom.xml
+++ b/deeplearning4j/deeplearning4j-common-tests/pom.xml
@@ -37,6 +37,10 @@
nd4j-api
${project.version}
+
+ ch.qos.logback
+ logback-classic
+
diff --git a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java
index e9b609c45..d805adb1c 100644
--- a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java
+++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java
@@ -17,6 +17,7 @@
package org.deeplearning4j;
+import ch.qos.logback.classic.LoggerContext;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
import org.junit.After;
@@ -31,6 +32,8 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
+import org.slf4j.ILoggerFactory;
+import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory;
import java.util.List;
@@ -139,6 +142,15 @@ public abstract class BaseDL4JTest {
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
+ System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
+ System.out.flush();
+ //Try to flush logs also:
+ try{ Thread.sleep(1000); } catch (InterruptedException e){ }
+ ILoggerFactory lf = LoggerFactory.getILoggerFactory();
+ if( lf instanceof LoggerContext){
+ ((LoggerContext)lf).stop();
+ }
+ try{ Thread.sleep(1000); } catch (InterruptedException e){ }
System.exit(1);
}
diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml
index 90c88d4c3..496bb6b1b 100644
--- a/deeplearning4j/deeplearning4j-core/pom.xml
+++ b/deeplearning4j/deeplearning4j-core/pom.xml
@@ -164,6 +164,20 @@
oshi-core
${oshi.version}
+
+
+
+ org.reflections
+ reflections
+ ${reflections.version}
+ test
+
+
+ com.google.code.findbugs
+ *
+
+
+
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java
new file mode 100644
index 000000000..20d2967bb
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java
@@ -0,0 +1,72 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j;
+
+import lombok.extern.slf4j.Slf4j;
+import org.junit.Test;
+import org.reflections.Reflections;
+import org.reflections.scanners.MethodAnnotationsScanner;
+import org.reflections.util.ClasspathHelper;
+import org.reflections.util.ConfigurationBuilder;
+
+import java.lang.reflect.Method;
+import java.util.*;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
+ * extends BaseDl4JTest - either directly or indirectly.
+ * Other than a small set of exceptions, all tests must extend this
+ *
+ * @author Alex Black
+ */
+@Slf4j
+public class AssertTestsExtendBaseClass extends BaseDL4JTest {
+
+ //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
+ private static final Set> exclusions = new HashSet<>();
+
+ @Test
+ public void checkTestClasses(){
+
+ Reflections reflections = new Reflections(new ConfigurationBuilder()
+ .setUrls(ClasspathHelper.forPackage("org.deeplearning4j"))
+ .setScanners(new MethodAnnotationsScanner()));
+ Set methods = reflections.getMethodsAnnotatedWith(Test.class);
+ Set> s = new HashSet<>();
+ for(Method m : methods){
+ s.add(m.getDeclaringClass());
+ }
+
+ List> l = new ArrayList<>(s);
+ Collections.sort(l, new Comparator>() {
+ @Override
+ public int compare(Class> aClass, Class> t1) {
+ return aClass.getName().compareTo(t1.getName());
+ }
+ });
+
+ int count = 0;
+ for(Class> c : l){
+ if(!BaseDL4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){
+ log.error("Test {} does not extend BaseDL4JTest (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", c);
+ count++;
+ }
+ }
+ assertEquals("Number of tests not extending BaseDL4JTest", 0, count);
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java
index 8f727fdf9..b52b7cb49 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java
@@ -17,7 +17,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.concurrent.CountDownLatch;
@Ignore
-public class RandomTests {
+public class RandomTests extends BaseDL4JTest {
@Test
public void testReproduce() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java
index 730038943..bc892905c 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java
@@ -16,11 +16,12 @@
package org.deeplearning4j.datasets;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
import org.junit.Test;
-public class TestDataSets {
+public class TestDataSets extends BaseDL4JTest {
@Test
public void testTinyImageNetExists() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java
index 6b3047aa5..c20b5855f 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java
@@ -1006,9 +1006,9 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
for (RecordMetaData m : meta) {
Record r = csv.loadFromMetaData(m);
INDArray row = ds.getFeatures().getRow(i);
- if(i <= 3) {
- System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
- }
+// if(i <= 3) {
+// System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
+// }
for (int j = 0; j < 4; j++) {
double exp = r.getRecord().get(j).toDouble();
@@ -1017,7 +1017,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
}
i++;
}
- System.out.println();
+// System.out.println();
DataSet fromMeta = rrdsi.loadFromMetaData(meta);
assertEquals(ds, fromMeta);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DummyBlockDataSetIteratorTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DummyBlockDataSetIteratorTests.java
index 41cd343a1..2b9700bbe 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DummyBlockDataSetIteratorTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DummyBlockDataSetIteratorTests.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import lombok.var;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
import org.junit.Test;
import org.nd4j.linalg.dataset.api.DataSet;
@@ -31,7 +32,7 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
@Slf4j
-public class DummyBlockDataSetIteratorTests {
+public class DummyBlockDataSetIteratorTests extends BaseDL4JTest {
@Test
public void testBlock_1() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointMultiDataSetIteratorTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointMultiDataSetIteratorTests.java
index a2feb91c7..2108c9ec3 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointMultiDataSetIteratorTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/JointMultiDataSetIteratorTests.java
@@ -18,13 +18,14 @@ package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.junit.Test;
import static org.junit.Assert.*;
@Slf4j
-public class JointMultiDataSetIteratorTests {
+public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
@Test (timeout = 20000L)
public void testJMDSI_1() {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/LoaderIteratorTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/LoaderIteratorTests.java
index 76368e729..aa49c9b50 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/LoaderIteratorTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/LoaderIteratorTests.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.datasets.iterator;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator;
import org.junit.Test;
@@ -37,7 +38,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-public class LoaderIteratorTests {
+public class LoaderIteratorTests extends BaseDL4JTest {
@Test
public void testDSLoaderIter(){
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java
index 54d645259..a9816fd7c 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.graph.graphnodes;
import lombok.val;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -54,7 +55,7 @@ import java.util.Map;
import static org.junit.Assert.*;
-public class TestGraphNodes {
+public class TestGraphNodes extends BaseDL4JTest {
@Test
public void testMergeNode() {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java
index 5e8537529..f7624ec7a 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
@@ -32,7 +33,7 @@ import java.util.Arrays;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-public class RepeatVectorTest {
+public class RepeatVectorTest extends BaseDL4JTest {
private int REPEAT = 4;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java
index d4be87a2a..e9467e83a 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.convolution;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ConvolutionMode;
@@ -37,7 +38,7 @@ import static org.junit.Assert.assertTrue;
/**
* @author Max Pumperla
*/
-public class Convolution3DTest {
+public class Convolution3DTest extends BaseDL4JTest {
private int nExamples = 1;
private int nChannelsOut = 1;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
index d53522c5d..96ab25267 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
@@ -731,7 +731,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
// https://github.com/eclipse/deeplearning4j/issues/8663
//The embedding layer weight initialization should be independent of the vocabulary size (nIn setting)
- for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL, WeightInit.VAR_SCALING_NORMAL_FAN_OUT}) {
+ for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL}) {
for (boolean seq : new boolean[]{false, true}) {
@@ -771,7 +771,9 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
INDArray p1 = net.params();
INDArray p2 = net2.params();
INDArray p3 = net3.params();
- assertEquals(p1, p2);
+ boolean eq = p1.equalsWithEps(p2, 1e-4);
+ String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
+ assertTrue(str + " p1/p2 params not equal", eq);
double m1 = p1.meanNumber().doubleValue();
double s1 = p1.stdNumber().doubleValue();
@@ -779,7 +781,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
double m3 = p3.meanNumber().doubleValue();
double s3 = p3.stdNumber().doubleValue();
- String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
+
assertEquals(str, m1, m3, 0.1);
assertEquals(str, s1, s3, 0.1);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java
index 3e3c47064..bf158a863 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.ocnn;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@@ -51,7 +52,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-public class OCNNOutputLayerTest {
+public class OCNNOutputLayerTest extends BaseDL4JTest {
private static final boolean PRINT_RESULTS = true;
private static final boolean RETURN_ON_FIRST_FAILURE = false;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java
index 593687283..4c1e5374e 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.recurrent;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
@@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import static org.junit.Assert.assertTrue;
-public class TestRecurrentWeightInit {
+public class TestRecurrentWeightInit extends BaseDL4JTest {
@Test
public void testRWInit() {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java
index 94b39591b..e728e0beb 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.misc;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -35,7 +36,7 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@Ignore //Ignored due to very large memory requirements
-public class LargeNetTest {
+public class LargeNetTest extends BaseDL4JTest {
@Ignore
@Test
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java
index 128b1705d..c66fb41df 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.updater.custom;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
@@ -35,7 +36,7 @@ import static org.junit.Assert.assertTrue;
/**
* Created by Alex on 09/05/2017.
*/
-public class TestCustomUpdater {
+public class TestCustomUpdater extends BaseDL4JTest {
@Test
public void testCustomUpdater() {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/LegacyWeightInitTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/LegacyWeightInitTest.java
index a8a680584..1b1273280 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/LegacyWeightInitTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/LegacyWeightInitTest.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.weights;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.distribution.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.junit.After;
@@ -40,7 +41,7 @@ import static org.junit.Assert.*;
*
* @author Christian Skarby
*/
-public class LegacyWeightInitTest {
+public class LegacyWeightInitTest extends BaseDL4JTest {
private RandomFactory prevFactory;
private final static int SEED = 666;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java
index 06bfc9797..3ee2d64cb 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.weights;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
@@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
*
* @author Christian Skarby
*/
-public class WeightInitIdentityTest {
+public class WeightInitIdentityTest extends BaseDL4JTest {
/**
* Test identity mapping for 1d convolution
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java
index fdc86c6cc..a7f59bb8c 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java
@@ -1,5 +1,6 @@
package org.deeplearning4j.optimizer.listener;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
import org.junit.Ignore;
import org.junit.Test;
@@ -7,7 +8,7 @@ import org.junit.Test;
import java.util.List;
import static org.junit.Assert.*;
-public class ScoreStatTest {
+public class ScoreStatTest extends BaseDL4JTest {
@Test
public void testScoreStatSmall() {
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java
index 84bfcd3e8..952084713 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.regressiontest;
import org.apache.commons.io.FileUtils;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
-public class MiscRegressionTests {
+public class MiscRegressionTests extends BaseDL4JTest {
@Test
public void testFrozen() throws Exception {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
index 8e638f373..4a07d6c7a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
@@ -3304,8 +3304,9 @@ public class Nd4j {
*/
public static INDArray randomExponential(double lambda, INDArray target) {
Preconditions.checkArgument(lambda > 0, "Lambda argument must be >= 0 - got %s", lambda);
- INDArray shapeArr = Nd4j.create(ArrayUtil.toDouble(target.shape()));
- Nd4j.getExecutioner().execAndReturn(new RandomExponential(shapeArr, target, lambda));
+ INDArray shapeArr = Nd4j.createFromArray(target.shape());
+ RandomExponential r = new RandomExponential(shapeArr, target, lambda);
+ Nd4j.exec(r);
return target;
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java
new file mode 100644
index 000000000..5d8a70725
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java
@@ -0,0 +1,83 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j;
+
+import lombok.extern.slf4j.Slf4j;
+import org.junit.Test;
+import org.nd4j.imports.TFGraphs.TFGraphTestAllLibnd4j;
+import org.nd4j.imports.TFGraphs.TFGraphTestAllSameDiff;
+import org.nd4j.imports.TFGraphs.TFGraphTestList;
+import org.nd4j.imports.TFGraphs.TFGraphTestZooModels;
+import org.nd4j.imports.listeners.ImportModelDebugger;
+import org.reflections.Reflections;
+import org.reflections.scanners.MethodAnnotationsScanner;
+import org.reflections.util.ClasspathHelper;
+import org.reflections.util.ConfigurationBuilder;
+
+import java.lang.reflect.Method;
+import java.util.*;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
+ * extends BaseDl4jTest - either directly or indirectly.
+ * Other than a small set of exceptions, all tests must extend this
+ *
+ * @author Alex Black
+ */
+@Slf4j
+public class AssertTestsExtendBaseClass extends BaseND4JTest {
+
+ //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
+ private static final Set> exclusions = new HashSet<>(Arrays.asList(
+ TFGraphTestAllSameDiff.class,
+ TFGraphTestAllLibnd4j.class,
+ TFGraphTestList.class,
+ TFGraphTestZooModels.class,
+ ImportModelDebugger.class //Run manually only, otherwise ignored
+ ));
+
+ @Test
+ public void checkTestClasses(){
+
+ Reflections reflections = new Reflections(new ConfigurationBuilder()
+ .setUrls(ClasspathHelper.forPackage("org.nd4j"))
+ .setScanners(new MethodAnnotationsScanner()));
+ Set methods = reflections.getMethodsAnnotatedWith(Test.class);
+ Set> s = new HashSet<>();
+ for(Method m : methods){
+ s.add(m.getDeclaringClass());
+ }
+
+ List> l = new ArrayList<>(s);
+ l.sort(new Comparator>() {
+ @Override
+ public int compare(Class> aClass, Class> t1) {
+ return aClass.getName().compareTo(t1.getName());
+ }
+ });
+
+ int count = 0;
+ for(Class> c : l){
+ if(!BaseND4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){
+ log.error("Test {} does not extend BaseND4JTest (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", c);
+ count++;
+ }
+ }
+ assertEquals("Number of tests not extending BaseND4JTest", 0, count);
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java
index 06c85b289..88a80c52f 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java
@@ -90,6 +90,7 @@ import java.util.Map;
*
* @author Alex Black
*/
+@Ignore
public class ImportModelDebugger {
@Test
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java
index aa8037f89..b59ce3346 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java
@@ -19,9 +19,11 @@ package org.nd4j.linalg.custom;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Test;
+import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ops.compat.CompatStringSplit;
import org.nd4j.linalg.api.ops.util.PrintVariable;
import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -30,7 +32,16 @@ import static org.junit.Assert.assertNotNull;
* This is special test suit: we test operations that on C++ side modify arrays that come from Java
*/
@Slf4j
-public class ExpandableOpsTests {
+public class ExpandableOpsTests extends BaseNd4jTest {
+
+ public ExpandableOpsTests(Nd4jBackend backend) {
+ super(backend);
+ }
+
+ @Override
+ public char ordering() {
+ return 'c';
+ }
@Test
public void testCompatStringSplit_1() throws Exception {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java
index bdfcf596f..9b7b2e241 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java
@@ -40,6 +40,11 @@ public class NormalizerStandardizeTest extends BaseNd4jTest {
super(backend);
}
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 60_000L;
+ }
+
@Test
public void testBruteForce() {
/* This test creates a dataset where feature values are multiples of consecutive natural numbers
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java
index a2af67dc9..4b80b2f1e 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java
@@ -1,14 +1,26 @@
package org.nd4j.linalg.dataset.api.preprocessor;
import org.junit.Test;
+import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
-public class CompositeDataSetPreProcessorTest {
+public class CompositeDataSetPreProcessorTest extends BaseNd4jTest {
+
+ public CompositeDataSetPreProcessorTest(Nd4jBackend backend) {
+ super(backend);
+ }
+
+ @Override
+ public char ordering() {
+ return 'c';
+ }
+
@Test(expected = NullPointerException.class)
public void when_preConditionsIsNull_expect_NullPointerException() {
// Assemble
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java
index 904484d5f..96def7d37 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java
@@ -1,15 +1,26 @@
package org.nd4j.linalg.dataset.api.preprocessor;
import org.junit.Test;
+import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.*;
-public class CropAndResizeDataSetPreProcessorTest {
+public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest {
+
+ public CropAndResizeDataSetPreProcessorTest(Nd4jBackend backend) {
+ super(backend);
+ }
+
+ @Override
+ public char ordering() {
+ return 'c';
+ }
@Test(expected = IllegalArgumentException.class)
public void when_originalHeightIsZero_expect_IllegalArgumentException() {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java
index acbac85df..88654ec88 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java
@@ -1,14 +1,25 @@
package org.nd4j.linalg.dataset.api.preprocessor;
+import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.dataset.api.preprocessor.PermuteDataSetPreProcessor;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.*;
-public class PermuteDataSetPreProcessorTest {
+public class PermuteDataSetPreProcessorTest extends BaseNd4jTest {
+
+ public PermuteDataSetPreProcessorTest(Nd4jBackend backend) {
+ super(backend);
+ }
+
+ @Override
+ public char ordering() {
+ return 'c';
+ }
@Test(expected = NullPointerException.class)
public void when_dataSetIsNull_expect_NullPointerException() {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java
index b0408d8b7..8e76d8e95 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java
@@ -1,14 +1,25 @@
package org.nd4j.linalg.dataset.api.preprocessor;
import org.junit.Test;
+import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-public class RGBtoGrayscaleDataSetPreProcessorTest {
+public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTest {
+
+ public RGBtoGrayscaleDataSetPreProcessorTest(Nd4jBackend backend) {
+ super(backend);
+ }
+
+ @Override
+ public char ordering() {
+ return 'c';
+ }
@Test(expected = NullPointerException.class)
public void when_dataSetIsNull_expect_NullPointerException() {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java
index 9c50f8315..d3b49f434 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java
@@ -18,9 +18,11 @@ package org.nd4j.linalg.multithreading;
import lombok.val;
import org.junit.Test;
+import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
import java.util.ArrayList;
import java.util.HashSet;
@@ -30,7 +32,16 @@ import static org.junit.Assert.assertEquals;
/**
* @author raver119@gmail.com
*/
-public class MultithreadedTests {
+public class MultithreadedTests extends BaseNd4jTest {
+
+ public MultithreadedTests(Nd4jBackend backend) {
+ super(backend);
+ }
+
+ @Override
+ public char ordering() {
+ return 'c';
+ }
@Test
public void basicMigrationTest_1() throws Exception {
diff --git a/nd4j/nd4j-common-tests/pom.xml b/nd4j/nd4j-common-tests/pom.xml
index a242c6644..f3a2e56bb 100644
--- a/nd4j/nd4j-common-tests/pom.xml
+++ b/nd4j/nd4j-common-tests/pom.xml
@@ -22,6 +22,11 @@
nd4j-api
${project.version}
+
+ ch.qos.logback
+ logback-classic
+ ${logback.version}
+
diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java
index ae2f56273..40b331ad5 100644
--- a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java
+++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java
@@ -17,11 +17,10 @@
package org.nd4j;
+import ch.qos.logback.classic.LoggerContext;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Rule;
+import org.junit.*;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.nd4j.base.Preconditions;
@@ -31,6 +30,8 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
+import org.slf4j.ILoggerFactory;
+import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory;
import java.util.List;
@@ -142,6 +143,15 @@ public abstract class BaseND4JTest {
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
+ System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
+ System.out.flush();
+ //Try to flush logs also:
+ try{ Thread.sleep(1000); } catch (InterruptedException e){ }
+ ILoggerFactory lf = LoggerFactory.getILoggerFactory();
+ if( lf instanceof LoggerContext){
+ ((LoggerContext)lf).stop();
+ }
+ try{ Thread.sleep(1000); } catch (InterruptedException e){ }
System.exit(1);
}