Various fixes (#290)

* Add check to ensure ALL tests extend BaseND4JTest for proper timeouts + logging

Signed-off-by: Alex Black <blacka101@gmail.com>

* Add 'must extend BaseDL4JTest' check for deeplearning4j-core

Signed-off-by: Alex Black <blacka101@gmail.com>

* Flush logging on workspace exit during tests

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-03-06 00:02:32 +11:00 committed by GitHub
parent 2911da061b
commit 19d5a8d49d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 325 additions and 34 deletions

View File

@ -37,6 +37,10 @@
<artifactId>nd4j-api</artifactId> <artifactId>nd4j-api</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -17,6 +17,7 @@
package org.deeplearning4j; package org.deeplearning4j;
import ch.qos.logback.classic.LoggerContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.junit.After; import org.junit.After;
@ -31,6 +32,8 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig; import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory;
import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory; import java.lang.management.ManagementFactory;
import java.util.List; import java.util.List;
@ -139,6 +142,15 @@ public abstract class BaseDL4JTest {
//Not really safe to continue testing under this situation... other tests will likely fail with obscure //Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this // errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
System.out.flush();
//Try to flush logs also:
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
ILoggerFactory lf = LoggerFactory.getILoggerFactory();
if( lf instanceof LoggerContext){
((LoggerContext)lf).stop();
}
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
System.exit(1); System.exit(1);
} }

View File

@ -164,6 +164,20 @@
<artifactId>oshi-core</artifactId> <artifactId>oshi-core</artifactId>
<version>${oshi.version}</version> <version>${oshi.version}</version>
</dependency> </dependency>
<!-- Test scope reflections to ensure all classes extend base test class -->
<dependency>
<groupId>org.reflections</groupId>
<artifactId>reflections</artifactId>
<version>${reflections.version}</version>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -0,0 +1,72 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.reflections.Reflections;
import org.reflections.scanners.MethodAnnotationsScanner;
import org.reflections.util.ClasspathHelper;
import org.reflections.util.ConfigurationBuilder;
import java.lang.reflect.Method;
import java.util.*;
import static org.junit.Assert.assertEquals;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4JTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends BaseDL4JTest {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
private static final Set<Class<?>> exclusions = new HashSet<>();
@Test
public void checkTestClasses(){
Reflections reflections = new Reflections(new ConfigurationBuilder()
.setUrls(ClasspathHelper.forPackage("org.deeplearning4j"))
.setScanners(new MethodAnnotationsScanner()));
Set<Method> methods = reflections.getMethodsAnnotatedWith(Test.class);
Set<Class<?>> s = new HashSet<>();
for(Method m : methods){
s.add(m.getDeclaringClass());
}
List<Class<?>> l = new ArrayList<>(s);
Collections.sort(l, new Comparator<Class<?>>() {
@Override
public int compare(Class<?> aClass, Class<?> t1) {
return aClass.getName().compareTo(t1.getName());
}
});
int count = 0;
for(Class<?> c : l){
if(!BaseDL4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){
log.error("Test {} does not extend BaseDL4JTest (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", c);
count++;
}
}
assertEquals("Number of tests not extending BaseDL4JTest", 0, count);
}
}

View File

@ -17,7 +17,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@Ignore @Ignore
public class RandomTests { public class RandomTests extends BaseDL4JTest {
@Test @Test
public void testReproduce() throws Exception { public void testReproduce() throws Exception {

View File

@ -16,11 +16,12 @@
package org.deeplearning4j.datasets; package org.deeplearning4j.datasets;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher; import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher; import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
import org.junit.Test; import org.junit.Test;
public class TestDataSets { public class TestDataSets extends BaseDL4JTest {
@Test @Test
public void testTinyImageNetExists() throws Exception { public void testTinyImageNetExists() throws Exception {

View File

@ -1006,9 +1006,9 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
for (RecordMetaData m : meta) { for (RecordMetaData m : meta) {
Record r = csv.loadFromMetaData(m); Record r = csv.loadFromMetaData(m);
INDArray row = ds.getFeatures().getRow(i); INDArray row = ds.getFeatures().getRow(i);
if(i <= 3) { // if(i <= 3) {
System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row); // System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
} // }
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
double exp = r.getRecord().get(j).toDouble(); double exp = r.getRecord().get(j).toDouble();
@ -1017,7 +1017,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
i++; i++;
} }
System.out.println(); // System.out.println();
DataSet fromMeta = rrdsi.loadFromMetaData(meta); DataSet fromMeta = rrdsi.loadFromMetaData(meta);
assertEquals(ds, fromMeta); assertEquals(ds, fromMeta);

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import lombok.var; import lombok.var;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator; import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
@ -31,7 +32,7 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@Slf4j @Slf4j
public class DummyBlockDataSetIteratorTests { public class DummyBlockDataSetIteratorTests extends BaseDL4JTest {
@Test @Test
public void testBlock_1() throws Exception { public void testBlock_1() throws Exception {

View File

@ -18,13 +18,14 @@ package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@Slf4j @Slf4j
public class JointMultiDataSetIteratorTests { public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
@Test (timeout = 20000L) @Test (timeout = 20000L)
public void testJMDSI_1() { public void testJMDSI_1() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator; import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator; import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator;
import org.junit.Test; import org.junit.Test;
@ -37,7 +38,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class LoaderIteratorTests { public class LoaderIteratorTests extends BaseDL4JTest {
@Test @Test
public void testDSLoaderIter(){ public void testDSLoaderIter(){

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.graph.graphnodes; package org.deeplearning4j.nn.graph.graphnodes;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -54,7 +55,7 @@ import java.util.Map;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class TestGraphNodes { public class TestGraphNodes extends BaseDL4JTest {
@Test @Test
public void testMergeNode() { public void testMergeNode() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers; package org.deeplearning4j.nn.layers;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
@ -32,7 +33,7 @@ import java.util.Arrays;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class RepeatVectorTest { public class RepeatVectorTest extends BaseDL4JTest {
private int REPEAT = 4; private int REPEAT = 4;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.convolution; package org.deeplearning4j.nn.layers.convolution;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
@ -37,7 +38,7 @@ import static org.junit.Assert.assertTrue;
/** /**
* @author Max Pumperla * @author Max Pumperla
*/ */
public class Convolution3DTest { public class Convolution3DTest extends BaseDL4JTest {
private int nExamples = 1; private int nExamples = 1;
private int nChannelsOut = 1; private int nChannelsOut = 1;

View File

@ -731,7 +731,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
// https://github.com/eclipse/deeplearning4j/issues/8663 // https://github.com/eclipse/deeplearning4j/issues/8663
//The embedding layer weight initialization should be independent of the vocabulary size (nIn setting) //The embedding layer weight initialization should be independent of the vocabulary size (nIn setting)
for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL, WeightInit.VAR_SCALING_NORMAL_FAN_OUT}) { for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL}) {
for (boolean seq : new boolean[]{false, true}) { for (boolean seq : new boolean[]{false, true}) {
@ -771,7 +771,9 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
INDArray p1 = net.params(); INDArray p1 = net.params();
INDArray p2 = net2.params(); INDArray p2 = net2.params();
INDArray p3 = net3.params(); INDArray p3 = net3.params();
assertEquals(p1, p2); boolean eq = p1.equalsWithEps(p2, 1e-4);
String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
assertTrue(str + " p1/p2 params not equal", eq);
double m1 = p1.meanNumber().doubleValue(); double m1 = p1.meanNumber().doubleValue();
double s1 = p1.stdNumber().doubleValue(); double s1 = p1.stdNumber().doubleValue();
@ -779,7 +781,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
double m3 = p3.meanNumber().doubleValue(); double m3 = p3.meanNumber().doubleValue();
double s3 = p3.stdNumber().doubleValue(); double s3 = p3.stdNumber().doubleValue();
String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
assertEquals(str, m1, m3, 0.1); assertEquals(str, m1, m3, 0.1);
assertEquals(str, s1, s3, 0.1); assertEquals(str, s1, s3, 0.1);

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.ocnn; package org.deeplearning4j.nn.layers.ocnn;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.gradientcheck.GradientCheckUtil; import org.deeplearning4j.gradientcheck.GradientCheckUtil;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -51,7 +52,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class OCNNOutputLayerTest { public class OCNNOutputLayerTest extends BaseDL4JTest {
private static final boolean PRINT_RESULTS = true; private static final boolean PRINT_RESULTS = true;
private static final boolean RETURN_ON_FIRST_FAILURE = false; private static final boolean RETURN_ON_FIRST_FAILURE = false;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.recurrent; package org.deeplearning4j.nn.layers.recurrent;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.GravesLSTM;
@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class TestRecurrentWeightInit { public class TestRecurrentWeightInit extends BaseDL4JTest {
@Test @Test
public void testRWInit() { public void testRWInit() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.misc; package org.deeplearning4j.nn.misc;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -35,7 +36,7 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@Ignore //Ignored due to very large memory requirements @Ignore //Ignored due to very large memory requirements
public class LargeNetTest { public class LargeNetTest extends BaseDL4JTest {
@Ignore @Ignore
@Test @Test

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.updater.custom; package org.deeplearning4j.nn.updater.custom;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayer;
@ -35,7 +36,7 @@ import static org.junit.Assert.assertTrue;
/** /**
* Created by Alex on 09/05/2017. * Created by Alex on 09/05/2017.
*/ */
public class TestCustomUpdater { public class TestCustomUpdater extends BaseDL4JTest {
@Test @Test
public void testCustomUpdater() { public void testCustomUpdater() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.distribution.*; import org.deeplearning4j.nn.conf.distribution.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.junit.After; import org.junit.After;
@ -40,7 +41,7 @@ import static org.junit.Assert.*;
* *
* @author Christian Skarby * @author Christian Skarby
*/ */
public class LegacyWeightInitTest { public class LegacyWeightInitTest extends BaseDL4JTest {
private RandomFactory prevFactory; private RandomFactory prevFactory;
private final static int SEED = 666; private final static int SEED = 666;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
* *
* @author Christian Skarby * @author Christian Skarby
*/ */
public class WeightInitIdentityTest { public class WeightInitIdentityTest extends BaseDL4JTest {
/** /**
* Test identity mapping for 1d convolution * Test identity mapping for 1d convolution

View File

@ -1,5 +1,6 @@
package org.deeplearning4j.optimizer.listener; package org.deeplearning4j.optimizer.listener;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener; import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
@ -7,7 +8,7 @@ import org.junit.Test;
import java.util.List; import java.util.List;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class ScoreStatTest { public class ScoreStatTest extends BaseDL4JTest {
@Test @Test
public void testScoreStatSmall() { public void testScoreStatSmall() {
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat(); CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.regressiontest; package org.deeplearning4j.regressiontest;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class MiscRegressionTests { public class MiscRegressionTests extends BaseDL4JTest {
@Test @Test
public void testFrozen() throws Exception { public void testFrozen() throws Exception {

View File

@ -3304,8 +3304,9 @@ public class Nd4j {
*/ */
public static INDArray randomExponential(double lambda, INDArray target) { public static INDArray randomExponential(double lambda, INDArray target) {
Preconditions.checkArgument(lambda > 0, "Lambda argument must be >= 0 - got %s", lambda); Preconditions.checkArgument(lambda > 0, "Lambda argument must be >= 0 - got %s", lambda);
INDArray shapeArr = Nd4j.create(ArrayUtil.toDouble(target.shape())); INDArray shapeArr = Nd4j.createFromArray(target.shape());
Nd4j.getExecutioner().execAndReturn(new RandomExponential(shapeArr, target, lambda)); RandomExponential r = new RandomExponential(shapeArr, target, lambda);
Nd4j.exec(r);
return target; return target;
} }

View File

@ -0,0 +1,83 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.nd4j.imports.TFGraphs.TFGraphTestAllLibnd4j;
import org.nd4j.imports.TFGraphs.TFGraphTestAllSameDiff;
import org.nd4j.imports.TFGraphs.TFGraphTestList;
import org.nd4j.imports.TFGraphs.TFGraphTestZooModels;
import org.nd4j.imports.listeners.ImportModelDebugger;
import org.reflections.Reflections;
import org.reflections.scanners.MethodAnnotationsScanner;
import org.reflections.util.ClasspathHelper;
import org.reflections.util.ConfigurationBuilder;
import java.lang.reflect.Method;
import java.util.*;
import static org.junit.Assert.assertEquals;
/**
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
* extends BaseDl4jTest - either directly or indirectly.
* Other than a small set of exceptions, all tests must extend this
*
* @author Alex Black
*/
@Slf4j
public class AssertTestsExtendBaseClass extends BaseND4JTest {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
private static final Set<Class<?>> exclusions = new HashSet<>(Arrays.asList(
TFGraphTestAllSameDiff.class,
TFGraphTestAllLibnd4j.class,
TFGraphTestList.class,
TFGraphTestZooModels.class,
ImportModelDebugger.class //Run manually only, otherwise ignored
));
@Test
public void checkTestClasses(){
Reflections reflections = new Reflections(new ConfigurationBuilder()
.setUrls(ClasspathHelper.forPackage("org.nd4j"))
.setScanners(new MethodAnnotationsScanner()));
Set<Method> methods = reflections.getMethodsAnnotatedWith(Test.class);
Set<Class<?>> s = new HashSet<>();
for(Method m : methods){
s.add(m.getDeclaringClass());
}
List<Class<?>> l = new ArrayList<>(s);
l.sort(new Comparator<Class<?>>() {
@Override
public int compare(Class<?> aClass, Class<?> t1) {
return aClass.getName().compareTo(t1.getName());
}
});
int count = 0;
for(Class<?> c : l){
if(!BaseND4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){
log.error("Test {} does not extend BaseND4JTest (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", c);
count++;
}
}
assertEquals("Number of tests not extending BaseND4JTest", 0, count);
}
}

View File

@ -90,6 +90,7 @@ import java.util.Map;
* *
* @author Alex Black * @author Alex Black
*/ */
@Ignore
public class ImportModelDebugger { public class ImportModelDebugger {
@Test @Test

View File

@ -19,9 +19,11 @@ package org.nd4j.linalg.custom;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ops.compat.CompatStringSplit; import org.nd4j.linalg.api.ops.compat.CompatStringSplit;
import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.ops.util.PrintVariable;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
@ -30,7 +32,16 @@ import static org.junit.Assert.assertNotNull;
* This is special test suit: we test operations that on C++ side modify arrays that come from Java * This is special test suit: we test operations that on C++ side modify arrays that come from Java
*/ */
@Slf4j @Slf4j
public class ExpandableOpsTests { public class ExpandableOpsTests extends BaseNd4jTest {
public ExpandableOpsTests(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test @Test
public void testCompatStringSplit_1() throws Exception { public void testCompatStringSplit_1() throws Exception {

View File

@ -40,6 +40,11 @@ public class NormalizerStandardizeTest extends BaseNd4jTest {
super(backend); super(backend);
} }
@Override
public long getTimeoutMilliseconds() {
return 60_000L;
}
@Test @Test
public void testBruteForce() { public void testBruteForce() {
/* This test creates a dataset where feature values are multiples of consecutive natural numbers /* This test creates a dataset where feature values are multiples of consecutive natural numbers

View File

@ -1,14 +1,26 @@
package org.nd4j.linalg.dataset.api.preprocessor; package org.nd4j.linalg.dataset.api.preprocessor;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class CompositeDataSetPreProcessorTest { public class CompositeDataSetPreProcessorTest extends BaseNd4jTest {
public CompositeDataSetPreProcessorTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test(expected = NullPointerException.class) @Test(expected = NullPointerException.class)
public void when_preConditionsIsNull_expect_NullPointerException() { public void when_preConditionsIsNull_expect_NullPointerException() {
// Assemble // Assemble

View File

@ -1,15 +1,26 @@
package org.nd4j.linalg.dataset.api.preprocessor; package org.nd4j.linalg.dataset.api.preprocessor;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class CropAndResizeDataSetPreProcessorTest { public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest {
public CropAndResizeDataSetPreProcessorTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void when_originalHeightIsZero_expect_IllegalArgumentException() { public void when_originalHeightIsZero_expect_IllegalArgumentException() {

View File

@ -1,14 +1,25 @@
package org.nd4j.linalg.dataset.api.preprocessor; package org.nd4j.linalg.dataset.api.preprocessor;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.dataset.api.preprocessor.PermuteDataSetPreProcessor; import org.nd4j.linalg.dataset.api.preprocessor.PermuteDataSetPreProcessor;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class PermuteDataSetPreProcessorTest { public class PermuteDataSetPreProcessorTest extends BaseNd4jTest {
public PermuteDataSetPreProcessorTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test(expected = NullPointerException.class) @Test(expected = NullPointerException.class)
public void when_dataSetIsNull_expect_NullPointerException() { public void when_dataSetIsNull_expect_NullPointerException() {

View File

@ -1,14 +1,25 @@
package org.nd4j.linalg.dataset.api.preprocessor; package org.nd4j.linalg.dataset.api.preprocessor;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class RGBtoGrayscaleDataSetPreProcessorTest { public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTest {
public RGBtoGrayscaleDataSetPreProcessorTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test(expected = NullPointerException.class) @Test(expected = NullPointerException.class)
public void when_dataSetIsNull_expect_NullPointerException() { public void when_dataSetIsNull_expect_NullPointerException() {

View File

@ -18,9 +18,11 @@ package org.nd4j.linalg.multithreading;
import lombok.val; import lombok.val;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
@ -30,7 +32,16 @@ import static org.junit.Assert.assertEquals;
/** /**
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
public class MultithreadedTests { public class MultithreadedTests extends BaseNd4jTest {
public MultithreadedTests(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test @Test
public void basicMigrationTest_1() throws Exception { public void basicMigrationTest_1() throws Exception {

View File

@ -22,6 +22,11 @@
<artifactId>nd4j-api</artifactId> <artifactId>nd4j-api</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
</dependencies> </dependencies>

View File

@ -17,11 +17,10 @@
package org.nd4j; package org.nd4j;
import ch.qos.logback.classic.LoggerContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.junit.After; import org.junit.*;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName; import org.junit.rules.TestName;
import org.junit.rules.Timeout; import org.junit.rules.Timeout;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -31,6 +30,8 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig; import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory;
import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory; import java.lang.management.ManagementFactory;
import java.util.List; import java.util.List;
@ -142,6 +143,15 @@ public abstract class BaseND4JTest {
//Not really safe to continue testing under this situation... other tests will likely fail with obscure //Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this // errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
System.out.flush();
//Try to flush logs also:
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
ILoggerFactory lf = LoggerFactory.getILoggerFactory();
if( lf instanceof LoggerContext){
((LoggerContext)lf).stop();
}
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
System.exit(1); System.exit(1);
} }