Various fixes (#290)
* Add check to ensure ALL tests extend BaseND4JTest for proper timeouts + logging Signed-off-by: Alex Black <blacka101@gmail.com> * Add 'must extend BaseDL4JTest' check for deeplearning4j-core Signed-off-by: Alex Black <blacka101@gmail.com> * Flush logging on workspace exit during tests Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
2911da061b
commit
19d5a8d49d
|
@ -37,6 +37,10 @@
|
||||||
<artifactId>nd4j-api</artifactId>
|
<artifactId>nd4j-api</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>ch.qos.logback</groupId>
|
||||||
|
<artifactId>logback-classic</artifactId>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j;
|
package org.deeplearning4j;
|
||||||
|
|
||||||
|
import ch.qos.logback.classic.LoggerContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -31,6 +32,8 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
|
import org.slf4j.ILoggerFactory;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.lang.management.ManagementFactory;
|
import java.lang.management.ManagementFactory;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -139,6 +142,15 @@ public abstract class BaseDL4JTest {
|
||||||
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
|
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
|
||||||
// errors that are hard to track back to this
|
// errors that are hard to track back to this
|
||||||
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
|
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
|
||||||
|
System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
|
||||||
|
System.out.flush();
|
||||||
|
//Try to flush logs also:
|
||||||
|
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
|
||||||
|
ILoggerFactory lf = LoggerFactory.getILoggerFactory();
|
||||||
|
if( lf instanceof LoggerContext){
|
||||||
|
((LoggerContext)lf).stop();
|
||||||
|
}
|
||||||
|
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -164,6 +164,20 @@
|
||||||
<artifactId>oshi-core</artifactId>
|
<artifactId>oshi-core</artifactId>
|
||||||
<version>${oshi.version}</version>
|
<version>${oshi.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<!-- Test scope reflections to ensure all classes extend base test class -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.reflections</groupId>
|
||||||
|
<artifactId>reflections</artifactId>
|
||||||
|
<version>${reflections.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
<exclusions>
|
||||||
|
<exclusion>
|
||||||
|
<groupId>com.google.code.findbugs</groupId>
|
||||||
|
<artifactId>*</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.reflections.Reflections;
|
||||||
|
import org.reflections.scanners.MethodAnnotationsScanner;
|
||||||
|
import org.reflections.util.ClasspathHelper;
|
||||||
|
import org.reflections.util.ConfigurationBuilder;
|
||||||
|
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
|
||||||
|
* extends BaseDl4JTest - either directly or indirectly.
|
||||||
|
* Other than a small set of exceptions, all tests must extend this
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class AssertTestsExtendBaseClass extends BaseDL4JTest {
|
||||||
|
|
||||||
|
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
|
||||||
|
private static final Set<Class<?>> exclusions = new HashSet<>();
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void checkTestClasses(){
|
||||||
|
|
||||||
|
Reflections reflections = new Reflections(new ConfigurationBuilder()
|
||||||
|
.setUrls(ClasspathHelper.forPackage("org.deeplearning4j"))
|
||||||
|
.setScanners(new MethodAnnotationsScanner()));
|
||||||
|
Set<Method> methods = reflections.getMethodsAnnotatedWith(Test.class);
|
||||||
|
Set<Class<?>> s = new HashSet<>();
|
||||||
|
for(Method m : methods){
|
||||||
|
s.add(m.getDeclaringClass());
|
||||||
|
}
|
||||||
|
|
||||||
|
List<Class<?>> l = new ArrayList<>(s);
|
||||||
|
Collections.sort(l, new Comparator<Class<?>>() {
|
||||||
|
@Override
|
||||||
|
public int compare(Class<?> aClass, Class<?> t1) {
|
||||||
|
return aClass.getName().compareTo(t1.getName());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
int count = 0;
|
||||||
|
for(Class<?> c : l){
|
||||||
|
if(!BaseDL4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){
|
||||||
|
log.error("Test {} does not extend BaseDL4JTest (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", c);
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertEquals("Number of tests not extending BaseDL4JTest", 0, count);
|
||||||
|
}
|
||||||
|
}
|
|
@ -17,7 +17,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
import java.util.concurrent.CountDownLatch;
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
@Ignore
|
@Ignore
|
||||||
public class RandomTests {
|
public class RandomTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReproduce() throws Exception {
|
public void testReproduce() throws Exception {
|
||||||
|
|
|
@ -16,11 +16,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.datasets;
|
package org.deeplearning4j.datasets;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
|
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
|
||||||
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
|
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
public class TestDataSets {
|
public class TestDataSets extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTinyImageNetExists() throws Exception {
|
public void testTinyImageNetExists() throws Exception {
|
||||||
|
|
|
@ -1006,9 +1006,9 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||||
for (RecordMetaData m : meta) {
|
for (RecordMetaData m : meta) {
|
||||||
Record r = csv.loadFromMetaData(m);
|
Record r = csv.loadFromMetaData(m);
|
||||||
INDArray row = ds.getFeatures().getRow(i);
|
INDArray row = ds.getFeatures().getRow(i);
|
||||||
if(i <= 3) {
|
// if(i <= 3) {
|
||||||
System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
|
// System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
|
||||||
}
|
// }
|
||||||
|
|
||||||
for (int j = 0; j < 4; j++) {
|
for (int j = 0; j < 4; j++) {
|
||||||
double exp = r.getRecord().get(j).toDouble();
|
double exp = r.getRecord().get(j).toDouble();
|
||||||
|
@ -1017,7 +1017,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
System.out.println();
|
// System.out.println();
|
||||||
|
|
||||||
DataSet fromMeta = rrdsi.loadFromMetaData(meta);
|
DataSet fromMeta = rrdsi.loadFromMetaData(meta);
|
||||||
assertEquals(ds, fromMeta);
|
assertEquals(ds, fromMeta);
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import lombok.var;
|
import lombok.var;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
@ -31,7 +32,7 @@ import static org.junit.Assert.assertNotNull;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class DummyBlockDataSetIteratorTests {
|
public class DummyBlockDataSetIteratorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBlock_1() throws Exception {
|
public void testBlock_1() throws Exception {
|
||||||
|
|
|
@ -18,13 +18,14 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class JointMultiDataSetIteratorTests {
|
public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test (timeout = 20000L)
|
@Test (timeout = 20000L)
|
||||||
public void testJMDSI_1() {
|
public void testJMDSI_1() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.datasets.iterator;
|
package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
|
import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator;
|
import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -37,7 +38,7 @@ import java.util.Random;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class LoaderIteratorTests {
|
public class LoaderIteratorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDSLoaderIter(){
|
public void testDSLoaderIter(){
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.nn.graph.graphnodes;
|
package org.deeplearning4j.nn.graph.graphnodes;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -54,7 +55,7 @@ import java.util.Map;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class TestGraphNodes {
|
public class TestGraphNodes extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMergeNode() {
|
public void testMergeNode() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.layers;
|
package org.deeplearning4j.nn.layers;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
|
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
|
||||||
|
@ -32,7 +33,7 @@ import java.util.Arrays;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class RepeatVectorTest {
|
public class RepeatVectorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
private int REPEAT = 4;
|
private int REPEAT = 4;
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.layers.convolution;
|
package org.deeplearning4j.nn.layers.convolution;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
|
@ -37,7 +38,7 @@ import static org.junit.Assert.assertTrue;
|
||||||
/**
|
/**
|
||||||
* @author Max Pumperla
|
* @author Max Pumperla
|
||||||
*/
|
*/
|
||||||
public class Convolution3DTest {
|
public class Convolution3DTest extends BaseDL4JTest {
|
||||||
|
|
||||||
private int nExamples = 1;
|
private int nExamples = 1;
|
||||||
private int nChannelsOut = 1;
|
private int nChannelsOut = 1;
|
||||||
|
|
|
@ -731,7 +731,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
||||||
// https://github.com/eclipse/deeplearning4j/issues/8663
|
// https://github.com/eclipse/deeplearning4j/issues/8663
|
||||||
//The embedding layer weight initialization should be independent of the vocabulary size (nIn setting)
|
//The embedding layer weight initialization should be independent of the vocabulary size (nIn setting)
|
||||||
|
|
||||||
for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL, WeightInit.VAR_SCALING_NORMAL_FAN_OUT}) {
|
for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL}) {
|
||||||
|
|
||||||
for (boolean seq : new boolean[]{false, true}) {
|
for (boolean seq : new boolean[]{false, true}) {
|
||||||
|
|
||||||
|
@ -771,7 +771,9 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
||||||
INDArray p1 = net.params();
|
INDArray p1 = net.params();
|
||||||
INDArray p2 = net2.params();
|
INDArray p2 = net2.params();
|
||||||
INDArray p3 = net3.params();
|
INDArray p3 = net3.params();
|
||||||
assertEquals(p1, p2);
|
boolean eq = p1.equalsWithEps(p2, 1e-4);
|
||||||
|
String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
|
||||||
|
assertTrue(str + " p1/p2 params not equal", eq);
|
||||||
|
|
||||||
double m1 = p1.meanNumber().doubleValue();
|
double m1 = p1.meanNumber().doubleValue();
|
||||||
double s1 = p1.stdNumber().doubleValue();
|
double s1 = p1.stdNumber().doubleValue();
|
||||||
|
@ -779,7 +781,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
||||||
double m3 = p3.meanNumber().doubleValue();
|
double m3 = p3.meanNumber().doubleValue();
|
||||||
double s3 = p3.stdNumber().doubleValue();
|
double s3 = p3.stdNumber().doubleValue();
|
||||||
|
|
||||||
String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
|
|
||||||
|
|
||||||
assertEquals(str, m1, m3, 0.1);
|
assertEquals(str, m1, m3, 0.1);
|
||||||
assertEquals(str, s1, s3, 0.1);
|
assertEquals(str, s1, s3, 0.1);
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.layers.ocnn;
|
package org.deeplearning4j.nn.layers.ocnn;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
|
import org.deeplearning4j.gradientcheck.GradientCheckUtil;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
@ -51,7 +52,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
|
||||||
public class OCNNOutputLayerTest {
|
public class OCNNOutputLayerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
private static final boolean PRINT_RESULTS = true;
|
private static final boolean PRINT_RESULTS = true;
|
||||||
private static final boolean RETURN_ON_FIRST_FAILURE = false;
|
private static final boolean RETURN_ON_FIRST_FAILURE = false;
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.layers.recurrent;
|
package org.deeplearning4j.nn.layers.recurrent;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||||
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||||
|
@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class TestRecurrentWeightInit {
|
public class TestRecurrentWeightInit extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRWInit() {
|
public void testRWInit() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.misc;
|
package org.deeplearning4j.nn.misc;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -35,7 +36,7 @@ import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
@Ignore //Ignored due to very large memory requirements
|
@Ignore //Ignored due to very large memory requirements
|
||||||
public class LargeNetTest {
|
public class LargeNetTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Ignore
|
@Ignore
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.updater.custom;
|
package org.deeplearning4j.nn.updater.custom;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
||||||
|
@ -35,7 +36,7 @@ import static org.junit.Assert.assertTrue;
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 09/05/2017.
|
* Created by Alex on 09/05/2017.
|
||||||
*/
|
*/
|
||||||
public class TestCustomUpdater {
|
public class TestCustomUpdater extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCustomUpdater() {
|
public void testCustomUpdater() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.weights;
|
package org.deeplearning4j.nn.weights;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.distribution.*;
|
import org.deeplearning4j.nn.conf.distribution.*;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -40,7 +41,7 @@ import static org.junit.Assert.*;
|
||||||
*
|
*
|
||||||
* @author Christian Skarby
|
* @author Christian Skarby
|
||||||
*/
|
*/
|
||||||
public class LegacyWeightInitTest {
|
public class LegacyWeightInitTest extends BaseDL4JTest {
|
||||||
|
|
||||||
private RandomFactory prevFactory;
|
private RandomFactory prevFactory;
|
||||||
private final static int SEED = 666;
|
private final static int SEED = 666;
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.weights;
|
package org.deeplearning4j.nn.weights;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
*
|
*
|
||||||
* @author Christian Skarby
|
* @author Christian Skarby
|
||||||
*/
|
*/
|
||||||
public class WeightInitIdentityTest {
|
public class WeightInitIdentityTest extends BaseDL4JTest {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test identity mapping for 1d convolution
|
* Test identity mapping for 1d convolution
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package org.deeplearning4j.optimizer.listener;
|
package org.deeplearning4j.optimizer.listener;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
|
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -7,7 +8,7 @@ import org.junit.Test;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class ScoreStatTest {
|
public class ScoreStatTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testScoreStatSmall() {
|
public void testScoreStatSmall() {
|
||||||
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
|
CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.regressiontest;
|
package org.deeplearning4j.regressiontest;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.assertNotNull;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class MiscRegressionTests {
|
public class MiscRegressionTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFrozen() throws Exception {
|
public void testFrozen() throws Exception {
|
||||||
|
|
|
@ -3304,8 +3304,9 @@ public class Nd4j {
|
||||||
*/
|
*/
|
||||||
public static INDArray randomExponential(double lambda, INDArray target) {
|
public static INDArray randomExponential(double lambda, INDArray target) {
|
||||||
Preconditions.checkArgument(lambda > 0, "Lambda argument must be >= 0 - got %s", lambda);
|
Preconditions.checkArgument(lambda > 0, "Lambda argument must be >= 0 - got %s", lambda);
|
||||||
INDArray shapeArr = Nd4j.create(ArrayUtil.toDouble(target.shape()));
|
INDArray shapeArr = Nd4j.createFromArray(target.shape());
|
||||||
Nd4j.getExecutioner().execAndReturn(new RandomExponential(shapeArr, target, lambda));
|
RandomExponential r = new RandomExponential(shapeArr, target, lambda);
|
||||||
|
Nd4j.exec(r);
|
||||||
return target;
|
return target;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.imports.TFGraphs.TFGraphTestAllLibnd4j;
|
||||||
|
import org.nd4j.imports.TFGraphs.TFGraphTestAllSameDiff;
|
||||||
|
import org.nd4j.imports.TFGraphs.TFGraphTestList;
|
||||||
|
import org.nd4j.imports.TFGraphs.TFGraphTestZooModels;
|
||||||
|
import org.nd4j.imports.listeners.ImportModelDebugger;
|
||||||
|
import org.reflections.Reflections;
|
||||||
|
import org.reflections.scanners.MethodAnnotationsScanner;
|
||||||
|
import org.reflections.util.ClasspathHelper;
|
||||||
|
import org.reflections.util.ConfigurationBuilder;
|
||||||
|
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
|
||||||
|
* extends BaseDl4jTest - either directly or indirectly.
|
||||||
|
* Other than a small set of exceptions, all tests must extend this
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class AssertTestsExtendBaseClass extends BaseND4JTest {
|
||||||
|
|
||||||
|
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
|
||||||
|
private static final Set<Class<?>> exclusions = new HashSet<>(Arrays.asList(
|
||||||
|
TFGraphTestAllSameDiff.class,
|
||||||
|
TFGraphTestAllLibnd4j.class,
|
||||||
|
TFGraphTestList.class,
|
||||||
|
TFGraphTestZooModels.class,
|
||||||
|
ImportModelDebugger.class //Run manually only, otherwise ignored
|
||||||
|
));
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void checkTestClasses(){
|
||||||
|
|
||||||
|
Reflections reflections = new Reflections(new ConfigurationBuilder()
|
||||||
|
.setUrls(ClasspathHelper.forPackage("org.nd4j"))
|
||||||
|
.setScanners(new MethodAnnotationsScanner()));
|
||||||
|
Set<Method> methods = reflections.getMethodsAnnotatedWith(Test.class);
|
||||||
|
Set<Class<?>> s = new HashSet<>();
|
||||||
|
for(Method m : methods){
|
||||||
|
s.add(m.getDeclaringClass());
|
||||||
|
}
|
||||||
|
|
||||||
|
List<Class<?>> l = new ArrayList<>(s);
|
||||||
|
l.sort(new Comparator<Class<?>>() {
|
||||||
|
@Override
|
||||||
|
public int compare(Class<?> aClass, Class<?> t1) {
|
||||||
|
return aClass.getName().compareTo(t1.getName());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
int count = 0;
|
||||||
|
for(Class<?> c : l){
|
||||||
|
if(!BaseND4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){
|
||||||
|
log.error("Test {} does not extend BaseND4JTest (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", c);
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertEquals("Number of tests not extending BaseND4JTest", 0, count);
|
||||||
|
}
|
||||||
|
}
|
|
@ -90,6 +90,7 @@ import java.util.Map;
|
||||||
*
|
*
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
|
@Ignore
|
||||||
public class ImportModelDebugger {
|
public class ImportModelDebugger {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -19,9 +19,11 @@ package org.nd4j.linalg.custom;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.ops.compat.CompatStringSplit;
|
import org.nd4j.linalg.api.ops.compat.CompatStringSplit;
|
||||||
import org.nd4j.linalg.api.ops.util.PrintVariable;
|
import org.nd4j.linalg.api.ops.util.PrintVariable;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
@ -30,7 +32,16 @@ import static org.junit.Assert.assertNotNull;
|
||||||
* This is special test suit: we test operations that on C++ side modify arrays that come from Java
|
* This is special test suit: we test operations that on C++ side modify arrays that come from Java
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ExpandableOpsTests {
|
public class ExpandableOpsTests extends BaseNd4jTest {
|
||||||
|
|
||||||
|
public ExpandableOpsTests(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCompatStringSplit_1() throws Exception {
|
public void testCompatStringSplit_1() throws Exception {
|
||||||
|
|
|
@ -40,6 +40,11 @@ public class NormalizerStandardizeTest extends BaseNd4jTest {
|
||||||
super(backend);
|
super(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 60_000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBruteForce() {
|
public void testBruteForce() {
|
||||||
/* This test creates a dataset where feature values are multiples of consecutive natural numbers
|
/* This test creates a dataset where feature values are multiples of consecutive natural numbers
|
||||||
|
|
|
@ -1,14 +1,26 @@
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.Assert.assertFalse;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class CompositeDataSetPreProcessorTest {
|
public class CompositeDataSetPreProcessorTest extends BaseNd4jTest {
|
||||||
|
|
||||||
|
public CompositeDataSetPreProcessorTest(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
@Test(expected = NullPointerException.class)
|
@Test(expected = NullPointerException.class)
|
||||||
public void when_preConditionsIsNull_expect_NullPointerException() {
|
public void when_preConditionsIsNull_expect_NullPointerException() {
|
||||||
// Assemble
|
// Assemble
|
||||||
|
|
|
@ -1,15 +1,26 @@
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class CropAndResizeDataSetPreProcessorTest {
|
public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest {
|
||||||
|
|
||||||
|
public CropAndResizeDataSetPreProcessorTest(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void when_originalHeightIsZero_expect_IllegalArgumentException() {
|
public void when_originalHeightIsZero_expect_IllegalArgumentException() {
|
||||||
|
|
|
@ -1,14 +1,25 @@
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.PermuteDataSetPreProcessor;
|
import org.nd4j.linalg.dataset.api.preprocessor.PermuteDataSetPreProcessor;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class PermuteDataSetPreProcessorTest {
|
public class PermuteDataSetPreProcessorTest extends BaseNd4jTest {
|
||||||
|
|
||||||
|
public PermuteDataSetPreProcessorTest(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
@Test(expected = NullPointerException.class)
|
@Test(expected = NullPointerException.class)
|
||||||
public void when_dataSetIsNull_expect_NullPointerException() {
|
public void when_dataSetIsNull_expect_NullPointerException() {
|
||||||
|
|
|
@ -1,14 +1,25 @@
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class RGBtoGrayscaleDataSetPreProcessorTest {
|
public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTest {
|
||||||
|
|
||||||
|
public RGBtoGrayscaleDataSetPreProcessorTest(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
@Test(expected = NullPointerException.class)
|
@Test(expected = NullPointerException.class)
|
||||||
public void when_dataSetIsNull_expect_NullPointerException() {
|
public void when_dataSetIsNull_expect_NullPointerException() {
|
||||||
|
|
|
@ -18,9 +18,11 @@ package org.nd4j.linalg.multithreading;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
@ -30,7 +32,16 @@ import static org.junit.Assert.assertEquals;
|
||||||
/**
|
/**
|
||||||
* @author raver119@gmail.com
|
* @author raver119@gmail.com
|
||||||
*/
|
*/
|
||||||
public class MultithreadedTests {
|
public class MultithreadedTests extends BaseNd4jTest {
|
||||||
|
|
||||||
|
public MultithreadedTests(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void basicMigrationTest_1() throws Exception {
|
public void basicMigrationTest_1() throws Exception {
|
||||||
|
|
|
@ -22,6 +22,11 @@
|
||||||
<artifactId>nd4j-api</artifactId>
|
<artifactId>nd4j-api</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>ch.qos.logback</groupId>
|
||||||
|
<artifactId>logback-classic</artifactId>
|
||||||
|
<version>${logback.version}</version>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,11 +17,10 @@
|
||||||
|
|
||||||
package org.nd4j;
|
package org.nd4j;
|
||||||
|
|
||||||
|
import ch.qos.logback.classic.LoggerContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.junit.After;
|
import org.junit.*;
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.rules.TestName;
|
import org.junit.rules.TestName;
|
||||||
import org.junit.rules.Timeout;
|
import org.junit.rules.Timeout;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -31,6 +30,8 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
|
import org.slf4j.ILoggerFactory;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.lang.management.ManagementFactory;
|
import java.lang.management.ManagementFactory;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -142,6 +143,15 @@ public abstract class BaseND4JTest {
|
||||||
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
|
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
|
||||||
// errors that are hard to track back to this
|
// errors that are hard to track back to this
|
||||||
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
|
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
|
||||||
|
System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
|
||||||
|
System.out.flush();
|
||||||
|
//Try to flush logs also:
|
||||||
|
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
|
||||||
|
ILoggerFactory lf = LoggerFactory.getILoggerFactory();
|
||||||
|
if( lf instanceof LoggerContext){
|
||||||
|
((LoggerContext)lf).stop();
|
||||||
|
}
|
||||||
|
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue