diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml
index 04a1fc0f0..6ce3c9c1f 100644
--- a/arbiter/arbiter-core/pom.xml
+++ b/arbiter/arbiter-core/pom.xml
@@ -84,6 +84,13 @@
jackson
${nd4j.version}
+
+
+ org.deeplearning4j
+ deeplearning4j-common-tests
+ ${project.version}
+ test
+
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java
index e7d07f81a..abfedac5e 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
@@ -32,7 +33,7 @@ import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusLis
import org.junit.Assert;
import org.junit.Test;
-public class TestGeneticSearch {
+public class TestGeneticSearch extends BaseDL4JTest {
public class TestSelectionOperator extends SelectionOperator {
@Override
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java
index 7f9906abf..24f80bf23 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
@@ -26,7 +27,7 @@ import java.util.Map;
import static org.junit.Assert.*;
-public class TestGridSearch {
+public class TestGridSearch extends BaseDL4JTest {
@Test
public void testIndexing() {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
index 6f1b336bb..cf5b1e1c7 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize;
import org.apache.commons.math3.distribution.LogNormalDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
@@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 02/02/2017.
*/
-public class TestJson {
+public class TestJson extends BaseDL4JTest {
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
ObjectMapper om = new ObjectMapper(factory);
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java
index 305480420..34916ebdc 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
@@ -34,7 +35,7 @@ import java.util.Map;
* Test random search on the Branin Function:
* http://www.sfu.ca/~ssurjano/branin.html
*/
-public class TestRandomSearch {
+public class TestRandomSearch extends BaseDL4JTest {
@Test
public void test() throws Exception {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java
index 85b048df7..6ca842637 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java
@@ -17,12 +17,13 @@
package org.deeplearning4j.arbiter.optimize.distribution;
import org.apache.commons.math3.distribution.RealDistribution;
+import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-public class TestLogUniform {
+public class TestLogUniform extends BaseDL4JTest {
@Test
public void testSimple(){
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java
index 4342f32ef..252b4304f 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
@@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
-public class ArithmeticCrossoverTests {
+public class ArithmeticCrossoverTests extends BaseDL4JTest {
@Test
public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java
index 96e3fda7b..50ae7f729 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator;
@@ -23,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
-public class CrossoverOperatorTests {
+public class CrossoverOperatorTests extends BaseDL4JTest {
@Test
public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java
index 0b1ec5271..5c1bebb51 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
@@ -24,7 +25,7 @@ import org.junit.Test;
import java.util.Deque;
-public class CrossoverPointsGeneratorTests {
+public class CrossoverPointsGeneratorTests extends BaseDL4JTest {
@Test
public void CrossoverPointsGenerator_FixedNumberCrossovers() {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java
index fbafb37a5..2399d256a 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
@@ -25,7 +26,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
-public class KPointCrossoverTests {
+public class KPointCrossoverTests extends BaseDL4JTest {
@Test
public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java
index ed2bec0ba..6976d8dd4 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
import org.junit.Assert;
@@ -24,7 +25,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
-public class ParentSelectionTests {
+public class ParentSelectionTests extends BaseDL4JTest {
@Test
public void ParentSelection_InitializeInstance_ShouldInitPopulation() {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java
index 2e8a166a7..09b244ab3 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
@@ -26,7 +27,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
-public class RandomTwoParentSelectionTests {
+public class RandomTwoParentSelectionTests extends BaseDL4JTest {
@Test
public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() {
RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null);
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java
index 36c11e00d..32dfb136c 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
@@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
-public class SinglePointCrossoverTests {
+public class SinglePointCrossoverTests extends BaseDL4JTest {
@Test
public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0});
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java
index 170984a68..9efe89620 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
@@ -27,7 +28,7 @@ import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
-public class TwoParentsCrossoverOperatorTests {
+public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java
index fc4524fdc..76a395c28 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
@@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
-public class UniformCrossoverTests {
+public class UniformCrossoverTests extends BaseDL4JTest {
@Test
public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java
index fbd5465c4..ccdb434e8 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.culling;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@@ -27,7 +28,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
-public class LeastFitCullOperatorTests {
+public class LeastFitCullOperatorTests extends BaseDL4JTest {
@Test
public void LeastFitCullingOperation_ShouldCullLastElements() {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java
index 4e41268be..093ffd486 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.culling;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@@ -27,7 +28,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import java.util.List;
-public class RatioCullOperatorTests {
+public class RatioCullOperatorTests extends BaseDL4JTest {
class TestRatioCullOperator extends RatioCullOperator {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java
index 80773c39e..38e2ba87b 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.mutation;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
@@ -24,7 +25,7 @@ import org.junit.Test;
import java.lang.reflect.Field;
import java.util.Arrays;
-public class RandomMutationOperatorTests {
+public class RandomMutationOperatorTests extends BaseDL4JTest {
@Test
public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() {
RandomMutationOperator sut = new RandomMutationOperator.Builder().build();
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java
index 760b198b0..e185b8164 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.population;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@@ -27,7 +28,7 @@ import org.junit.Test;
import java.util.List;
-public class PopulationModelTests {
+public class PopulationModelTests extends BaseDL4JTest {
private class TestCullOperator implements CullOperator {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java
index e5df87549..1d2b74de9 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.apache.commons.math3.random.RandomGenerator;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
@@ -36,7 +37,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import static org.junit.Assert.assertArrayEquals;
-public class GeneticSelectionOperatorTests {
+public class GeneticSelectionOperatorTests extends BaseDL4JTest {
private class TestCullOperator implements CullOperator {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java
index 5cc61d744..3f64279ee 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
@@ -25,7 +26,7 @@ import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
-public class SelectionOperatorTests {
+public class SelectionOperatorTests extends BaseDL4JTest {
private class TestSelectionOperator extends SelectionOperator {
public PopulationModel getPopulationModel() {
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java
index 4a203f770..98396a941 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.parameter;
import org.apache.commons.math3.distribution.NormalDistribution;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
@@ -25,7 +26,7 @@ import org.junit.Test;
import static org.junit.Assert.assertEquals;
-public class TestParameterSpaces {
+public class TestParameterSpaces extends BaseDL4JTest {
@Test
diff --git a/arbiter/arbiter-deeplearning4j/pom.xml b/arbiter/arbiter-deeplearning4j/pom.xml
index b163e2ae4..ec7e22d3c 100644
--- a/arbiter/arbiter-deeplearning4j/pom.xml
+++ b/arbiter/arbiter-deeplearning4j/pom.xml
@@ -63,6 +63,13 @@
gson
${gson.version}
+
+
+ org.deeplearning4j
+ deeplearning4j-common-tests
+ ${project.version}
+ test
+
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java
index af5d04c0a..7c4ec38f4 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.computationgraph;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.TestUtils;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
@@ -44,7 +45,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-public class TestComputationGraphSpace {
+public class TestComputationGraphSpace extends BaseDL4JTest {
@Test
public void testBasic() {
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java
index c64a06040..1747b45f9 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.computationgraph;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
@@ -85,7 +86,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
-public class TestGraphLocalExecution {
+public class TestGraphLocalExecution extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@@ -126,7 +127,7 @@ public class TestGraphLocalExecution {
if(dataApproach == 0){
ds = TestDL4JLocalExecution.MnistDataSource.class;
dsP = new Properties();
- dsP.setProperty("minibatch", "8");
+ dsP.setProperty("minibatch", "2");
candidateGenerator = new RandomSearchGenerator(mls);
} else if(dataApproach == 1) {
//DataProvider approach
@@ -150,8 +151,8 @@ public class TestGraphLocalExecution {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
- .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
- new MaxCandidatesCondition(5))
+ .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
+ new MaxCandidatesCondition(3))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator()));
@@ -159,7 +160,7 @@ public class TestGraphLocalExecution {
runner.execute();
List results = runner.getResults();
- assertEquals(5, results.size());
+ assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
}
@@ -203,8 +204,8 @@ public class TestGraphLocalExecution {
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
- .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
- new MaxCandidatesCondition(10))
+ .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
+ new MaxCandidatesCondition(3))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,
@@ -223,7 +224,7 @@ public class TestGraphLocalExecution {
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1)))
.l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in")
- .setInputTypes(InputType.feedForward(4))
+ .setInputTypes(InputType.feedForward(784))
.addLayer("layer0",
new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10))
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
@@ -250,8 +251,8 @@ public class TestGraphLocalExecution {
.candidateGenerator(candidateGenerator)
.dataProvider(new TestMdsDataProvider(1, 32))
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
- .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
- new MaxCandidatesCondition(10))
+ .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
+ new MaxCandidatesCondition(3))
.scoreFunction(ScoreFunctions.testSetAccuracy())
.build();
@@ -279,7 +280,7 @@ public class TestGraphLocalExecution {
@Override
public Object trainData(Map dataParameters) {
try {
- DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 10 * batchSize), false, true, true, 12345);
+ DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 3 * batchSize), false, true, true, 12345);
return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying));
} catch (IOException e) {
throw new RuntimeException(e);
@@ -289,7 +290,7 @@ public class TestGraphLocalExecution {
@Override
public Object testData(Map dataParameters) {
try {
- DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 5 * batchSize), false, false, false, 12345);
+ DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 2 * batchSize), false, false, false, 12345);
return new MultiDataSetIteratorAdapter(underlying);
} catch (IOException e) {
throw new RuntimeException(e);
@@ -305,7 +306,7 @@ public class TestGraphLocalExecution {
@Test
public void testLocalExecutionEarlyStopping() throws Exception {
EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
- .epochTerminationConditions(new MaxEpochsTerminationCondition(4))
+ .epochTerminationConditions(new MaxEpochsTerminationCondition(2))
.scoreCalculator(new ScoreProvider())
.modelSaver(new InMemoryModelSaver()).build();
Map commands = new HashMap<>();
@@ -348,8 +349,8 @@ public class TestGraphLocalExecution {
.dataProvider(dataProvider)
.scoreFunction(ScoreFunctions.testSetF1())
.modelSaver(new FileModelSaver(modelSavePath))
- .terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS),
- new MaxCandidatesCondition(10))
+ .terminationConditions(new MaxTimeCondition(15, TimeUnit.SECONDS),
+ new MaxCandidatesCondition(3))
.build();
@@ -364,7 +365,7 @@ public class TestGraphLocalExecution {
@Override
public ScoreCalculator get() {
try {
- return new DataSetLossCalculatorCG(new MnistDataSetIterator(128, 1280), true);
+ return new DataSetLossCalculatorCG(new MnistDataSetIterator(4, 8), true);
} catch (Exception e){
throw new RuntimeException(e);
}
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java
index e67e854b8..2b9c5696d 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.computationgraph;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
@@ -79,11 +80,16 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
-public class TestGraphLocalExecutionGenetic {
+public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 45000L;
+ }
+
@Test
public void testLocalExecutionDataSources() throws Exception {
for (int dataApproach = 0; dataApproach < 3; dataApproach++) {
@@ -115,7 +121,7 @@ public class TestGraphLocalExecutionGenetic {
if (dataApproach == 0) {
ds = TestDL4JLocalExecution.MnistDataSource.class;
dsP = new Properties();
- dsP.setProperty("minibatch", "8");
+ dsP.setProperty("minibatch", "2");
candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction)
.populationModel(new PopulationModel.Builder().populationSize(5).build())
@@ -148,7 +154,7 @@ public class TestGraphLocalExecutionGenetic {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
- .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
+ .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.build();
@@ -157,7 +163,7 @@ public class TestGraphLocalExecutionGenetic {
runner.execute();
List results = runner.getResults();
- assertEquals(10, results.size());
+ assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
}
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java
index d6e77ccf4..12ccc71a5 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.json;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace;
@@ -71,7 +72,7 @@ import static org.junit.Assert.assertNotNull;
/**
* Created by Alex on 14/02/2017.
*/
-public class TestJson {
+public class TestJson extends BaseDL4JTest {
@Test
public void testMultiLayerSpaceJson() {
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java
index 501501e65..ea754990a 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace;
@@ -59,7 +60,7 @@ import java.util.concurrent.TimeUnit;
// import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener;
/** Not strictly a unit test. Rather: part example, part debugging on MNIST */
-public class MNISTOptimizationTest {
+public class MNISTOptimizationTest extends BaseDL4JTest {
public static void main(String[] args) throws Exception {
EarlyStoppingConfiguration esConf =
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java
index 0f4d384ba..554ef346e 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator;
@@ -72,9 +73,10 @@ import java.util.Properties;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
@Slf4j
-public class TestDL4JLocalExecution {
+public class TestDL4JLocalExecution extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@@ -112,7 +114,7 @@ public class TestDL4JLocalExecution {
if(dataApproach == 0){
ds = MnistDataSource.class;
dsP = new Properties();
- dsP.setProperty("minibatch", "8");
+ dsP.setProperty("minibatch", "2");
candidateGenerator = new RandomSearchGenerator(mls);
} else if(dataApproach == 1) {
//DataProvider approach
@@ -136,7 +138,7 @@ public class TestDL4JLocalExecution {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
- .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
+ .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(5))
.build();
@@ -146,7 +148,7 @@ public class TestDL4JLocalExecution {
runner.execute();
List results = runner.getResults();
- assertEquals(5, results.size());
+ assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
}
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java
index bb8806f23..a62997a33 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
@@ -39,7 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
-public class TestErrors {
+public class TestErrors extends BaseDL4JTest {
@Rule
public TemporaryFolder temp = new TemporaryFolder();
@@ -60,7 +61,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
- .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
+ .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
@@ -87,7 +88,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
- .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
+ .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
@@ -116,7 +117,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
- .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
+ .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxCandidatesCondition(5))
.build();
@@ -143,7 +144,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
- .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
+ .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java
index e3338efcb..6a5458e65 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import org.apache.commons.lang3.ArrayUtils;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.TestUtils;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.*;
@@ -44,7 +45,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-public class TestLayerSpace {
+public class TestLayerSpace extends BaseDL4JTest {
@Test
public void testBasic1() {
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java
index 48055ed4b..99dc79f42 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.DL4JConfiguration;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.TestUtils;
@@ -86,7 +87,7 @@ import java.util.*;
import static org.junit.Assert.*;
-public class TestMultiLayerSpace {
+public class TestMultiLayerSpace extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java
index 01adefd18..f2dc5d180 100644
--- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java
+++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.multilayernetwork;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
@@ -60,7 +61,13 @@ import java.util.Map;
import static org.junit.Assert.assertEquals;
@Slf4j
-public class TestScoreFunctions {
+public class TestScoreFunctions extends BaseDL4JTest {
+
+
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 60000L;
+ }
@Test
public void testROCScoreFunctions() throws Exception {
@@ -107,7 +114,7 @@ public class TestScoreFunctions {
List list = runner.getResults();
for (ResultReference rr : list) {
- DataSetIterator testIter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
+ DataSetIterator testIter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
testIter.setPreProcessor(new PreProc(rocType));
OptimizationResult or = rr.getResult();
@@ -141,10 +148,10 @@ public class TestScoreFunctions {
}
- DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
+ DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
iter.setPreProcessor(new PreProc(rocType));
- assertEquals(msg, expScore, or.getScore(), 1e-5);
+ assertEquals(msg, expScore, or.getScore(), 1e-4);
}
}
}
@@ -158,7 +165,7 @@ public class TestScoreFunctions {
@Override
public Object trainData(Map dataParameters) {
try {
- DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
+ DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
iter.setPreProcessor(new PreProc(rocType));
return iter;
} catch (IOException e){
@@ -169,7 +176,7 @@ public class TestScoreFunctions {
@Override
public Object testData(Map dataParameters) {
try {
- DataSetIterator iter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
+ DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
iter.setPreProcessor(new PreProc(rocType));
return iter;
} catch (IOException e){
diff --git a/arbiter/arbiter-server/pom.xml b/arbiter/arbiter-server/pom.xml
index bdea61138..c4306b967 100644
--- a/arbiter/arbiter-server/pom.xml
+++ b/arbiter/arbiter-server/pom.xml
@@ -49,6 +49,13 @@
${junit.version}
test
+
+
+ org.deeplearning4j
+ deeplearning4j-common-tests
+ ${project.version}
+ test
+
diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java
index dd2b53409..21e4e402a 100644
--- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java
+++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.server;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
+import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
@@ -52,7 +53,7 @@ import static org.junit.Assert.assertEquals;
* Created by agibsonccc on 3/12/17.
*/
@Slf4j
-public class ArbiterCLIRunnerTest {
+public class ArbiterCLIRunnerTest extends BaseDL4JTest {
@Test
diff --git a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java
index 466a68e6a..e9b609c45 100644
--- a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java
+++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java
@@ -24,6 +24,7 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
+import org.nd4j.base.Preconditions;
import org.nd4j.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@@ -36,6 +37,8 @@ import java.util.List;
import java.util.Map;
import java.util.Properties;
+import static org.junit.Assume.assumeTrue;
+
@Slf4j
public abstract class BaseDL4JTest {
@@ -47,6 +50,17 @@ public abstract class BaseDL4JTest {
protected long startTime;
protected int threadCountBefore;
+ private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
+
+ /**
+ * Override this to specify the number of threads for C++ execution, via
+ * {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
+ * @return Number of threads to use for C++ op execution
+ */
+ public int numThreads(){
+ return DEFAULT_THREADS;
+ }
+
/**
* Override this method to set the default timeout for methods in the test class
*/
@@ -72,6 +86,28 @@ public abstract class BaseDL4JTest {
return getDataType();
}
+ protected Boolean integrationTest;
+
+ /**
+ * @return True if integration tests maven profile is enabled, false otherwise.
+ */
+ public boolean isIntegrationTests(){
+ if(integrationTest == null){
+ String prop = System.getenv("DL4J_INTEGRATION_TESTS");
+ integrationTest = Boolean.parseBoolean(prop);
+ }
+ return integrationTest;
+ }
+
+ /**
+ * Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled.
+ * This can be used to dynamically skip integration tests when the integration test profile is not enabled.
+ * Note that the integration test profile is not enabled by default - "integration-tests" profile
+ */
+ public void skipUnlessIntegrationTests(){
+ assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
+ }
+
@Before
public void beforeTest(){
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
@@ -81,6 +117,14 @@ public abstract class BaseDL4JTest {
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
+ Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
+ Nd4j.getExecutioner().enableDebugMode(false);
+ Nd4j.getExecutioner().enableVerboseMode(false);
+ int numThreads = numThreads();
+ Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
+ if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
+ Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
+ }
startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java
index e3923c4ff..1b769a32d 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java
@@ -158,7 +158,7 @@ public class LayerHelperValidationUtil {
double d2 = arr2.dup('c').getDouble(idx);
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
}
- assertTrue(s + layerName + "activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
+ assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java
index 8d8db1b8d..3e5015356 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java
@@ -78,8 +78,16 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
@Test
public void hasNextWithResetAndLoad() throws Exception {
+ int[] prefetchSizes;
+ if(isIntegrationTests()){
+ prefetchSizes = new int[]{2, 3, 4, 5, 6, 7, 8};
+ } else {
+ prefetchSizes = new int[]{2, 3, 8};
+ }
+
+
for (int iter = 0; iter < ITERATIONS; iter++) {
- for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) {
+ for(int prefetchSize : prefetchSizes){
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL);
int cnt = 0;
@@ -161,8 +169,14 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
@Test
public void testVariableTimeSeries1() throws Exception {
+ int numBatches = isIntegrationTests() ? 1000 : 100;
+ int batchSize = isIntegrationTests() ? 32 : 8;
+ int timeStepsMin = 10;
+ int timeStepsMax = isIntegrationTests() ? 500 : 100;
+ int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
+
AsyncDataSetIterator adsi = new AsyncDataSetIterator(
- new VariableTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10), 2, true);
+ new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true);
for (int e = 0; e < 10; e++) {
int cnt = 0;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java
index 68381bec7..d6a0d1fe1 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java
@@ -18,21 +18,10 @@ package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
-import org.datavec.api.records.reader.RecordReader;
-import org.datavec.api.records.reader.SequenceRecordReader;
-import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
-import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
-import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.BaseDL4JTest;
-import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator;
import org.junit.Test;
import org.nd4j.linalg.dataset.api.MultiDataSet;
-import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
-import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
-import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
-
-import java.util.Arrays;
import static org.junit.Assert.assertEquals;
@@ -49,7 +38,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
*/
@Test
public void testVariableTimeSeries1() throws Exception {
- val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10);
+ int numBatches = isIntegrationTests() ? 1000 : 100;
+ int batchSize = isIntegrationTests() ? 32 : 8;
+ int timeStepsMin = 10;
+ int timeStepsMax = isIntegrationTests() ? 500 : 100;
+ int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
+
+ val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
iterator.reset();
iterator.hasNext();
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
@@ -81,7 +76,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
@Test
public void testVariableTimeSeries2() throws Exception {
- val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10);
+ int numBatches = isIntegrationTests() ? 1000 : 100;
+ int batchSize = isIntegrationTests() ? 32 : 8;
+ int timeStepsMin = 10;
+ int timeStepsMax = isIntegrationTests() ? 500 : 100;
+ int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
+
+ val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
for (int e = 0; e < 10; e++) {
iterator.reset();
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java
index b0d5e0d25..c5ac04901 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java
@@ -46,17 +46,17 @@ public class TestEmnistDataSetIterator extends BaseDL4JTest {
@Test
public void testEmnistDataSetIterator() throws Exception {
- // EmnistFetcher fetcher = new EmnistFetcher(EmnistDataSetIterator.Set.COMPLETE);
- // File baseEmnistDir = fetcher.getFILE_DIR();
- // if(baseEmnistDir.exists()){
- // FileUtils.deleteDirectory(baseEmnistDir);
- // }
- // assertFalse(baseEmnistDir.exists());
-
int batchSize = 128;
- for (EmnistDataSetIterator.Set s : EmnistDataSetIterator.Set.values()) {
+ EmnistDataSetIterator.Set[] sets;
+ if(isIntegrationTests()){
+ sets = EmnistDataSetIterator.Set.values();
+ } else {
+ sets = new EmnistDataSetIterator.Set[]{EmnistDataSetIterator.Set.MNIST, EmnistDataSetIterator.Set.LETTERS};
+ }
+
+ for (EmnistDataSetIterator.Set s : sets) {
boolean isBalanced = EmnistDataSetIterator.isBalanced(s);
int numLabels = EmnistDataSetIterator.numLabels(s);
INDArray labelCounts = null;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java
index 43370548f..812ea2b08 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java
@@ -476,7 +476,7 @@ public class EvalTest extends BaseDL4JTest {
net.setListeners(new EvaluativeListener(iterTest, 3));
- for( int i=0; i<10; i++ ){
+ for( int i=0; i<3; i++ ){
net.fit(iter);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
index cbad1adbb..9a4994d73 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
@@ -339,9 +339,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
log.info(msg);
- for (int j = 0; j < net.getnLayers(); j++) {
- log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
- }
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
@@ -623,13 +620,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
log.info(msg);
-// for (int j = 0; j < net.getnLayers(); j++) {
-// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
-// }
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
- .labels(labels).subset(true).maxPerParam(128));
+ .labels(labels).subset(true).maxPerParam(64));
assertTrue(msg, gradOK);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java
index 4ddb1ad40..329adbc9b 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java
@@ -557,60 +557,52 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] minibatchSizes = {2};
int width = 5;
int height = 5;
- int[] inputDepths = {1, 2, 4};
-
- Activation[] activations = {Activation.SIGMOID, Activation.TANH};
Nd4j.getRandom().setSeed(12345);
- for (int inputDepth : inputDepths) {
- for (Activation afn : activations) {
- for (int minibatchSize : minibatchSizes) {
- INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
- INDArray labels = Nd4j.zeros(minibatchSize, nOut);
- for (int i = 0; i < minibatchSize; i++) {
- labels.putScalar(new int[]{i, i % nOut}, 1.0);
- }
+ int[] inputDepths = new int[]{1, 2, 4};
+ Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS};
+ int[] minibatch = {2, 1, 3};
- MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
- .dataType(DataType.DOUBLE)
- .activation(afn)
- .list()
- .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1)
- .padding(0, 0).nIn(inputDepth).nOut(2).build())//output: (5-2+0)/1+1 = 4
- .layer(1, new LocallyConnected2D.Builder().nIn(2).nOut(7).kernelSize(2, 2)
- .setInputSize(4, 4).convolutionMode(ConvolutionMode.Strict).hasBias(false)
- .stride(1, 1).padding(0, 0).build()) //(4-2+0)/1+1 = 3
- .layer(2, new ConvolutionLayer.Builder().nIn(7).nOut(2).kernelSize(2, 2)
- .stride(1, 1).padding(0, 0).build()) //(3-2+0)/1+1 = 2
- .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
- .activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
- .build())
- .setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
+ for( int i=0; i 0);
+// for (val word : words) {
+// System.out.println(word);
+// }
}
@Test
@@ -755,7 +774,16 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void weightsNotUpdated_WhenLocked() throws Exception {
- SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
+ boolean isIntegration = isIntegrationTests();
+ SentenceIterator iter;
+ SentenceIterator iter2;
+ if(isIntegration){
+ iter = new BasicLineIterator(inputFile);
+ iter2 = new BasicLineIterator(inputFile2.getAbsolutePath());
+ } else {
+ iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
+ iter2 = new CollectionSentenceIterator(firstNLines(inputFile2, 300));
+ }
Word2Vec vec1 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(64).layerSize(100)
.stopWords(new ArrayList()).seed(42).learningRate(0.025).minLearningRate(0.001)
@@ -767,13 +795,12 @@ public class Word2VecTests extends BaseDL4JTest {
vec1.fit();
- iter = new BasicLineIterator(inputFile2.getAbsolutePath());
Word2Vec vec2 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(32).layerSize(100)
.stopWords(new ArrayList()).seed(32).learningRate(0.021).minLearningRate(0.001)
.sampling(0).elementsLearningAlgorithm(new SkipGram())
.epochs(1).windowSize(5).allowParallelTokenization(true)
.workers(1)
- .iterate(iter)
+ .iterate(iter2)
.intersectModel(vec1, true)
.modelUtils(new BasicModelUtils()).build();
@@ -861,6 +888,22 @@ public class Word2VecTests extends BaseDL4JTest {
}
System.out.print("\n");
}
- //
+
+ public static List firstNLines(File f, int n){
+ List lines = new ArrayList<>();
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
+ try{
+ for( int i=0; i cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
+ //STEP 1: Initialization
+ int iterations = 50;
+ //create an n-dimensional array of doubles
+ Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
+ List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
- //STEP 2: Turn text input into a list of words
- INDArray weights;
- if(syntheticData){
- weights = Nd4j.rand(1000, 200);
- } else {
- log.info("Load & Vectorize data....");
- File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
- //Get the data of all unique word vectors
- Pair vectors = WordVectorSerializer.loadTxt(wordFile);
- VocabCache cache = vectors.getSecond();
- weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
+ //STEP 2: Turn text input into a list of words
+ INDArray weights;
+ if(syntheticData){
+ weights = Nd4j.rand(250, 200);
+ } else {
+ log.info("Load & Vectorize data....");
+ File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
+ //Get the data of all unique word vectors
+ Pair vectors = WordVectorSerializer.loadTxt(wordFile);
+ VocabCache cache = vectors.getSecond();
+ weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
- for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
- cacheList.add(cache.wordAtIndex(i));
- }
-
- //STEP 3: build a dual-tree tsne to use later
- log.info("Build model....");
- BarnesHutTsne tsne = new BarnesHutTsne.Builder()
- .setMaxIter(iterations)
- .theta(0.5)
- .normalize(false)
- .learningRate(500)
- .useAdaGrad(false)
- .workspaceMode(wsm)
- .build();
-
-
- //STEP 4: establish the tsne values and save them to a file
- log.info("Store TSNE Coordinates for Plotting....");
- File outDir = testDir.newFolder();
- tsne.fit(weights);
- tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
+ for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
+ cacheList.add(cache.wordAtIndex(i));
}
+
+ //STEP 3: build a dual-tree tsne to use later
+ log.info("Build model....");
+ BarnesHutTsne tsne = new BarnesHutTsne.Builder()
+ .setMaxIter(iterations)
+ .theta(0.5)
+ .normalize(false)
+ .learningRate(500)
+ .useAdaGrad(false)
+ .workspaceMode(wsm)
+ .build();
+
+
+ //STEP 4: establish the tsne values and save them to a file
+ log.info("Store TSNE Coordinates for Plotting....");
+ File outDir = testDir.newFolder();
+ tsne.fit(weights);
+ tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
}
}
- //Elapsed time : 01:01:57.988
@Test
public void testPerformance() throws Exception {
StopWatch watch = new StopWatch();
watch.start();
- for (boolean syntheticData : new boolean[]{false, true}) {
- for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
- log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
+ for( int test=0; test <=1; test++){
+ boolean syntheticData = test == 1;
+ WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED;
+ log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
- //STEP 1: Initialization
- int iterations = 100;
- //create an n-dimensional array of doubles
- Nd4j.setDataType(DataType.DOUBLE);
- List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
+ //STEP 1: Initialization
+ int iterations = 50;
+ //create an n-dimensional array of doubles
+ Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
+ List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
- //STEP 2: Turn text input into a list of words
- INDArray weights;
- if(syntheticData){
- weights = Nd4j.rand(5000, 20);
- } else {
- log.info("Load & Vectorize data....");
- File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
- //Get the data of all unique word vectors
- Pair vectors = WordVectorSerializer.loadTxt(wordFile);
- VocabCache cache = vectors.getSecond();
- weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
+ //STEP 2: Turn text input into a list of words
+ INDArray weights;
+ if(syntheticData){
+ weights = Nd4j.rand(DataType.FLOAT, 250, 20);
+ } else {
+ log.info("Load & Vectorize data....");
+ File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
+ //Get the data of all unique word vectors
+ Pair vectors = WordVectorSerializer.loadTxt(wordFile);
+ VocabCache cache = vectors.getSecond();
+ weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
- for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
- cacheList.add(cache.wordAtIndex(i));
- }
-
- //STEP 3: build a dual-tree tsne to use later
- log.info("Build model....");
- BarnesHutTsne tsne = new BarnesHutTsne.Builder()
- .setMaxIter(iterations)
- .theta(0.5)
- .normalize(false)
- .learningRate(500)
- .useAdaGrad(false)
- .workspaceMode(wsm)
- .build();
-
-
- //STEP 4: establish the tsne values and save them to a file
- log.info("Store TSNE Coordinates for Plotting....");
- File outDir = testDir.newFolder();
- tsne.fit(weights);
- tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
+ for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
+ cacheList.add(cache.wordAtIndex(i));
}
+
+ //STEP 3: build a dual-tree tsne to use later
+ log.info("Build model....");
+ BarnesHutTsne tsne = new BarnesHutTsne.Builder()
+ .setMaxIter(iterations)
+ .theta(0.5)
+ .normalize(false)
+ .learningRate(500)
+ .useAdaGrad(false)
+ .workspaceMode(wsm)
+ .build();
+
+
+ //STEP 4: establish the tsne values and save them to a file
+ log.info("Store TSNE Coordinates for Plotting....");
+ File outDir = testDir.newFolder();
+ tsne.fit(weights);
+ tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
}
watch.stop();
System.out.println("Elapsed time : " + watch);
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java
index c3e8ac89b..95cd4e9a6 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java
@@ -20,6 +20,8 @@ package org.deeplearning4j.models.paragraphvectors;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.io.LineIterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
@@ -27,6 +29,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator;
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.ParallelTransformerIterator;
+import org.deeplearning4j.text.sentenceiterator.*;
import org.junit.Rule;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
@@ -46,10 +49,6 @@ import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.documentiterator.LabelsSource;
-import org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator;
-import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
-import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator;
-import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
@@ -66,8 +65,8 @@ import org.nd4j.resources.Resources;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.File;
-import java.io.IOException;
+import java.io.*;
+import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
@@ -372,7 +371,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
LabelsSource source = new LabelsSource("DOC_");
- ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(3)
+ ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1)
.layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter)
.trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0)
.useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true)
@@ -425,6 +424,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
@Test(timeout = 300000)
public void testParagraphVectorsDBOW() throws Exception {
+ skipUnlessIntegrationTests();
+
File file = Resources.asFile("/big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(file);
@@ -657,7 +658,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
}
}
- @Test(timeout = 300000)
+ @Test
public void testIterator() throws IOException {
val folder_labeled = testDir.newFolder();
val folder_unlabeled = testDir.newFolder();
@@ -672,7 +673,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
SentenceIterator iter = new BasicLineIterator(resource_sentences);
int i = 0;
- for (; i < 10000; ++i) {
+ for (; i < 10; ++i) {
int j = 0;
int labels = 0;
int words = 0;
@@ -721,7 +722,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
- Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
+ Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(1)
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
.elementsLearningAlgorithm(new SkipGram()).useHierarchicSoftmax(true).windowSize(5)
.allowParallelTokenization(true)
@@ -1009,7 +1010,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
- Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
+ Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(1)
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
.elementsLearningAlgorithm(new SkipGram()).useHierarchicSoftmax(true).windowSize(5)
.iterate(iter).tokenizerFactory(t).build();
@@ -1151,8 +1152,27 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
@Test(timeout = 300000)
public void testDoubleFit() throws Exception {
+ boolean isIntegration = isIntegrationTests();
File resource = Resources.asFile("/big/raw_sentences.txt");
- SentenceIterator iter = new BasicLineIterator(resource);
+ SentenceIterator iter;
+ if(isIntegration){
+ iter = new BasicLineIterator(resource);
+ } else {
+ List lines = new ArrayList<>();
+ try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){
+ LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
+ try{
+ for( int i=0; i<500 && lineIter.hasNext(); i++ ){
+ lines.add(lineIter.next());
+ }
+ } finally {
+ lineIter.close();
+ }
+ }
+
+ iter = new CollectionSentenceIterator(lines);
+ }
+
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java
index 44b098dc1..9e8e89d39 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java
@@ -49,7 +49,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 240000L;
+ return 60000L;
}
/**
@@ -57,6 +57,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
*/
@Test
public void testIterator1() throws Exception {
+
File inputFile = Resources.asFile("big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
@@ -77,10 +78,14 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1);
INDArray array = iterator.next().getFeatures();
+ int count = 0;
while (iterator.hasNext()) {
DataSet ds = iterator.next();
assertArrayEquals(array.shape(), ds.getFeatures().shape());
+
+ if(!isIntegrationTests() && count++ > 20)
+ break; //raw_sentences.txt is 2.81 MB, takes quite some time to process. We'll only first 20 minibatches when doing unit tests
}
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java
index 380807c8d..d75653604 100644
--- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java
+++ b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java
@@ -45,9 +45,15 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
*/
@Test
public void testStore1() throws Exception {
- int numParams = 100000;
-
- int workers[] = new int[] {2, 4, 8};
+ int numParams;
+ int[] workers;
+ if(isIntegrationTests()){
+ numParams = 100000;
+ workers = new int[] {2, 4, 8};
+ } else {
+ numParams = 10000;
+ workers = new int[] {2, 3};
+ }
for (int numWorkers : workers) {
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3),null, null, false);
@@ -77,7 +83,13 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
*/
@Test
public void testEncodingLimits1() throws Exception {
- int numParams = 100000;
+ int numParams;
+ if(isIntegrationTests()){
+ numParams = 100000;
+ } else {
+ numParams = 10000;
+ }
+
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, null, false);
for (int e = 10; e < numParams / 5; e++) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java
index 68de05ce5..a7884bb56 100644
--- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java
+++ b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java
@@ -242,7 +242,7 @@ public class IndexedTailTest extends BaseDL4JTest {
final long[] sums = new long[numReaders];
val readers = new ArrayList();
for (int e = 0; e < numReaders; e++) {
- val f = e;
+ final int f = e;
val t = new Thread(new Runnable() {
@Override
public void run() {
@@ -297,7 +297,7 @@ public class IndexedTailTest extends BaseDL4JTest {
final long[] sums = new long[numReaders];
val readers = new ArrayList();
for (int e = 0; e < numReaders; e++) {
- val f = e;
+ final int f = e;
val t = new Thread(new Runnable() {
@Override
public void run() {
@@ -371,7 +371,7 @@ public class IndexedTailTest extends BaseDL4JTest {
final long[] sums = new long[numReaders];
val readers = new ArrayList();
for (int e = 0; e < numReaders; e++) {
- val f = e;
+ final int f = e;
val t = new Thread(new Runnable() {
@Override
public void run() {
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java
index fd79cc780..8b060a77c 100644
--- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java
@@ -35,6 +35,7 @@ import org.deeplearning4j.remote.helpers.House;
import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
import org.deeplearning4j.remote.helpers.PredictedPrice;
import org.junit.After;
+import org.junit.Before;
import org.junit.Test;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.autodiff.samediff.SDVariable;
@@ -58,6 +59,7 @@ import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
@@ -66,7 +68,6 @@ import static org.junit.Assert.*;
@Slf4j
public class JsonModelServerTest extends BaseDL4JTest {
private static final MultiLayerNetwork model;
- private final int PORT = 18080;
static {
val conf = new NeuralNetConfiguration.Builder()
@@ -84,10 +85,18 @@ public class JsonModelServerTest extends BaseDL4JTest {
@After
public void pause() throws Exception {
- // TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
+ // Need to wait for server shutdown; without sleep, tests will fail if starting immediately after shutdown
TimeUnit.SECONDS.sleep(2);
}
+ private AtomicInteger portCount = new AtomicInteger(18080);
+ private int PORT;
+
+ @Before
+ public void setPort(){
+ PORT = portCount.getAndIncrement();
+ }
+
@Test
public void testStartStopParallel() throws Exception {
@@ -343,7 +352,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
val server = new JsonModelServer.Builder(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(null)
- .port(18080)
+ .port(PORT)
.build();
}
@@ -382,7 +391,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
return null;
}
})
- .endpointAddress("http://localhost:18080/v1/serving")
+ .endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java
index 42e8437e7..219573f0a 100644
--- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java
+++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java
@@ -485,7 +485,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
List exp = new ArrayList<>();
Random r = new Random();
- for (int i = 0; i < 500; i++) {
+ int runs = isIntegrationTests() ? 500 : 30;
+ for (int i = 0; i < runs; i++) {
int[] shape = defaultSize;
if (r.nextDouble() < 0.4) {
shape = new int[]{r.nextInt(5) + 1, 10, r.nextInt(10) + 1};
@@ -597,7 +598,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
List arrs = new ArrayList<>();
List exp = new ArrayList<>();
Random r = new Random();
- for( int i=0; i<500; i++ ){
+ int runs = isIntegrationTests() ? 500 : 20;
+ for( int i=0; i in = new ArrayList<>();
List inMasks = new ArrayList<>();
List exp = new ArrayList<>();
- for (int i = 0; i < 100; i++) {
+ int nRuns = isIntegrationTests() ? 100 : 10;
+ for (int i = 0; i < nRuns; i++) {
int currTSLength = (randomTSLength ? 1 + r.nextInt(tsLength) : tsLength);
int currNumEx = 1 + r.nextInt(3);
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn, currTSLength});
@@ -847,6 +852,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
List in = new ArrayList<>();
List exp = new ArrayList<>();
+ int runs = isIntegrationTests() ? 100 : 20;
for (int i = 0; i < 100; i++) {
int currNumEx = 1 + r.nextInt(3);
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn});
diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java
index 32c70f3ff..cea06265a 100644
--- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java
+++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java
@@ -62,8 +62,8 @@ public class ParallelWrapperTest extends BaseDL4JTest {
int seed = 123;
log.info("Load data....");
- DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 100);
- DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 10);
+ DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 15);
+ DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 4);
assertTrue(mnistTrain.hasNext());
val t0 = mnistTrain.next();
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
index 42b6e42cf..3eafbb9e2 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
@@ -47,6 +47,12 @@
org.nd4j
nd4j-parameter-server-node_2.11
${nd4j.version}
+
+
+ net.jpountz.lz4
+ lz4
+
+
junit
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties
new file mode 100755
index 000000000..5d1edb39f
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties
@@ -0,0 +1,31 @@
+################################################################################
+# Copyright (c) 2015-2019 Skymind, Inc.
+#
+# This program and the accompanying materials are made available under the
+# terms of the Apache License, Version 2.0 which is available at
+# https://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+################################################################################
+
+log4j.rootLogger=ERROR, Console
+log4j.appender.Console=org.apache.log4j.ConsoleAppender
+log4j.appender.Console.layout=org.apache.log4j.PatternLayout
+log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
+
+log4j.appender.org.springframework=DEBUG
+log4j.appender.org.deeplearning4j=DEBUG
+log4j.appender.org.nd4j=DEBUG
+
+log4j.logger.org.springframework=INFO
+log4j.logger.org.deeplearning4j=DEBUG
+log4j.logger.org.nd4j=DEBUG
+log4j.logger.org.apache.spark=WARN
+
+
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml
new file mode 100644
index 000000000..9dec22fae
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml
@@ -0,0 +1,53 @@
+
+
+
+
+
+
+
+ logs/application.log
+
+ %date - [%level] - from %logger in %thread
+ %n%message%n%xException%n
+
+
+
+
+
+ %logger{15} - %message%n%xException{5}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties
new file mode 100755
index 000000000..5d1edb39f
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties
@@ -0,0 +1,31 @@
+################################################################################
+# Copyright (c) 2015-2019 Skymind, Inc.
+#
+# This program and the accompanying materials are made available under the
+# terms of the Apache License, Version 2.0 which is available at
+# https://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+################################################################################
+
+log4j.rootLogger=ERROR, Console
+log4j.appender.Console=org.apache.log4j.ConsoleAppender
+log4j.appender.Console.layout=org.apache.log4j.PatternLayout
+log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
+
+log4j.appender.org.springframework=DEBUG
+log4j.appender.org.deeplearning4j=DEBUG
+log4j.appender.org.nd4j=DEBUG
+
+log4j.logger.org.springframework=INFO
+log4j.logger.org.deeplearning4j=DEBUG
+log4j.logger.org.nd4j=DEBUG
+log4j.logger.org.apache.spark=WARN
+
+
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml
new file mode 100644
index 000000000..9dec22fae
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml
@@ -0,0 +1,53 @@
+
+
+
+
+
+
+
+ logs/application.log
+
+ %date - [%level] - from %logger in %thread
+ %n%message%n%xException%n
+
+
+
+
+
+ %logger{15} - %message%n%xException{5}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java
index 6aa730c2f..b7a4d6d49 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java
@@ -25,6 +25,7 @@ import org.datavec.spark.transform.misc.StringToWritablesFunction;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.spark.BaseSparkTest;
import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.io.ClassPathResource;
@@ -35,6 +36,15 @@ import static org.junit.Assert.assertEquals;
public class TestIteratorUtils extends BaseSparkTest {
+ @Override
+ public DataType getDataType() {
+ return DataType.FLOAT;
+ }
+
+ @Override
+ public DataType getDefaultFPDataType() {
+ return DataType.FLOAT;
+ }
@Test
public void testIrisRRMDSI() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java
index abfd39060..6908be512 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java
@@ -453,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
tempDirF.deleteOnExit();
int dataSetObjSize = 1;
- int batchSizePerExecutor = 16;
- int numSplits = 5;
+ int batchSizePerExecutor = 4;
+ int numSplits = 3;
int averagingFrequency = 3;
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
@@ -506,7 +506,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
INDArray paramsAfter = sparkNet.getNetwork().params().dup();
assertNotEquals(paramsBefore, paramsAfter);
- Thread.sleep(2000);
+ Thread.sleep(200);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
//Expect
@@ -517,7 +517,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertEquals(numSplits * numExecutors() * averagingFrequency, list.size());
for (EventStats es : list) {
ExampleCountEventStats e = (ExampleCountEventStats) es;
- assertTrue(batchSizePerExecutor * averagingFrequency - 10 >= e.getTotalExampleCount());
+ assertTrue(batchSizePerExecutor * averagingFrequency >= e.getTotalExampleCount());
}
@@ -535,9 +535,9 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
tempDirF.deleteOnExit();
tempDirF2.deleteOnExit();
- int dataSetObjSize = 5;
- int batchSizePerExecutor = 25;
- DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 1000, false);
+ int dataSetObjSize = 4;
+ int batchSizePerExecutor = 8;
+ DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 128, false);
int i = 0;
while (iter.hasNext()) {
File nextFile = new File(tempDirF, i + ".bin");
@@ -981,7 +981,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.setOutputs("out")
.build();
- DataSetIterator iter = new IrisDataSetIterator(1, 150);
+ DataSetIterator iter = new IrisDataSetIterator(1, 50);
List l = new ArrayList<>();
while(iter.hasNext()){
@@ -992,9 +992,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
int rddDataSetNumExamples = 1;
- int averagingFrequency = 3;
+ int averagingFrequency = 2;
+ int batch = 2;
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(rddDataSetNumExamples)
- .averagingFrequency(averagingFrequency).batchSizePerWorker(rddDataSetNumExamples)
+ .averagingFrequency(averagingFrequency).batchSizePerWorker(batch)
.saveUpdater(true).workerPrefetchNumBatches(0).build();
Nd4j.getRandom().setSeed(12345);
@@ -1003,7 +1004,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkComputationGraph sn2 = new SparkComputationGraph(sc, conf2.clone(), tm);
- for(int i=0; i<4; i++ ){
+ for(int i=0; i<3; i++ ){
assertEquals(i, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount());
assertEquals(i, sn2.getNetwork().getConfiguration().getEpochCount());
sn1.fit(rdd);
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java
index b01de7e5e..e0759a549 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java
@@ -42,6 +42,11 @@ import static org.junit.Assert.assertTrue;
*/
public class TestRepartitioning extends BaseSparkTest {
+ @Override
+ public long getTimeoutMilliseconds() {
+ return isIntegrationTests() ? 240000 : 60000;
+ }
+
@Test
public void testRepartitioning() {
List list = new ArrayList<>();
@@ -66,7 +71,12 @@ public class TestRepartitioning extends BaseSparkTest {
@Test
public void testRepartitioning2() throws Exception {
- int[] ns = {320, 321, 25600, 25601, 25615};
+ int[] ns;
+ if(isIntegrationTests()){
+ ns = new int[]{320, 321, 25600, 25601, 25615};
+ } else {
+ ns = new int[]{320, 2561};
+ }
for (int n : ns) {
diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java
index af0205f00..c92f5acdc 100644
--- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java
+++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java
@@ -32,6 +32,11 @@ import java.io.File;
public class MiscTests extends BaseDL4JTest {
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 120000L;
+ }
+
@Test
public void testTransferVGG() throws Exception {
//https://github.com/deeplearning4j/deeplearning4j/issues/5167
diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java
index bb41443bb..b45afe47a 100644
--- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java
+++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java
@@ -48,6 +48,11 @@ import static org.junit.Assert.assertEquals;
@Slf4j
public class TestDownload extends BaseDL4JTest {
+ @Override
+ public long getTimeoutMilliseconds() {
+ return isIntegrationTests() ? 480000L : 60000L;
+ }
+
@ClassRule
public static TemporaryFolder testDir = new TemporaryFolder();
private static File f;
@@ -67,12 +72,20 @@ public class TestDownload extends BaseDL4JTest {
public void testDownloadAllModels() throws Exception {
// iterate through each available model
- ZooModel[] models = new ZooModel[]{
- LeNet.builder().build(),
- SimpleCNN.builder().build(),
- UNet.builder().build(),
- NASNet.builder().build()
- };
+ ZooModel[] models;
+
+ if(isIntegrationTests()){
+ models = new ZooModel[]{
+ LeNet.builder().build(),
+ SimpleCNN.builder().build(),
+ UNet.builder().build(),
+ NASNet.builder().build()};
+ } else {
+ models = new ZooModel[]{
+ LeNet.builder().build(),
+ SimpleCNN.builder().build()};
+ }
+
for (int i = 0; i < models.length; i++) {
diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java
index b1963b9b6..9106bede3 100644
--- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java
+++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java
@@ -57,6 +57,11 @@ import static org.junit.Assert.assertTrue;
@Slf4j
public class TestImageNet extends BaseDL4JTest {
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 90000L;
+ }
+
@Override
public DataType getDataType(){
return DataType.FLOAT;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java
index 682d7c230..4b104d08b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java
@@ -63,10 +63,15 @@ public class DistributionUniform extends DynamicCustomOp {
addArgs();
}
- public DistributionUniform(INDArray shape, INDArray out, double min, double max){
+ public DistributionUniform(INDArray shape, INDArray out, double min, double max) {
+ this(shape, out, min, max, null);
+ }
+
+ public DistributionUniform(INDArray shape, INDArray out, double min, double max, DataType dataType){
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List)null);
this.min = min;
this.max = max;
+ this.dataType = dataType;
}
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml
index 2d2c8d6e3..d98c7a6d1 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml
@@ -310,6 +310,13 @@
+
+
+ org.nd4j
+ nd4j-common-tests
+ ${project.version}
+ test
+
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java
index 9584d5692..a616e7aa1 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java
@@ -19,6 +19,7 @@ package org.nd4j.jita.allocator;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Test;
+import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;
@@ -29,7 +30,7 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@Slf4j
-public class DeviceLocalNDArrayTests {
+public class DeviceLocalNDArrayTests extends BaseND4JTest {
@Test
public void testDeviceLocalArray_1() throws Exception{
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java
index d952c775a..9854b797a 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java
@@ -19,13 +19,14 @@ package org.nd4j.jita.allocator.impl;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Test;
+import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.*;
@Slf4j
-public class MemoryTrackerTest {
+public class MemoryTrackerTest extends BaseND4JTest {
@Test
public void testAllocatedDelta() {
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java
index 09c1ebb04..6f62dc53a 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java
@@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Before;
import org.junit.Test;
+import org.nd4j.BaseND4JTest;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.workspace.CudaWorkspace;
import org.nd4j.linalg.api.buffer.DataType;
@@ -20,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.*;
@Slf4j
-public class BaseCudaDataBufferTest {
+public class BaseCudaDataBufferTest extends BaseND4JTest {
@Before
public void setUp() {
diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml
index bc468c874..5f5c5fa90 100644
--- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml
@@ -87,6 +87,13 @@
${logback.version}
test
+
+
+ org.nd4j
+ nd4j-common-tests
+ ${project.version}
+ test
+
diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java
index c959edbfe..81599cd5b 100644
--- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java
@@ -20,6 +20,7 @@ package org.nd4j.tensorflow.conversion;
import junit.framework.TestCase;
import org.apache.commons.io.FileUtils;
import org.bytedeco.tensorflow.TF_Tensor;
+import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.resources.Resources;
import org.nd4j.shade.protobuf.Descriptors;
@@ -46,7 +47,17 @@ import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
-public class GraphRunnerTest {
+public class GraphRunnerTest extends BaseND4JTest {
+
+ @Override
+ public DataType getDataType() {
+ return DataType.FLOAT;
+ }
+
+ @Override
+ public DataType getDefaultFPDataType() {
+ return DataType.FLOAT;
+ }
public static ConfigProto getConfig(){
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java
index eb263f119..fbf4249bd 100644
--- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java
@@ -18,6 +18,7 @@ package org.nd4j.tensorflow.conversion;
import org.apache.commons.io.IOUtils;
import org.junit.Test;
+import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
@@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
-public class TensorflowConversionTest {
+public class TensorflowConversionTest extends BaseND4JTest {
@Test
public void testView() {
diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java
index 614330813..b13ca465f 100644
--- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java
@@ -17,6 +17,7 @@
package org.nd4j.tensorflow.conversion;
+import org.nd4j.BaseND4JTest;
import org.nd4j.shade.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils;
import org.junit.Test;
@@ -37,7 +38,7 @@ import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
-public class GpuGraphRunnerTest {
+public class GpuGraphRunnerTest extends BaseND4JTest {
@Test
public void testGraphRunner() throws Exception {
diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml
index e6861d257..9d098189f 100644
--- a/nd4j/nd4j-backends/nd4j-tests/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml
@@ -127,6 +127,13 @@
+
+
+ org.nd4j
+ nd4j-common-tests
+ ${project.version}
+ test
+
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
index f7d8c6d9b..db443c548 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
@@ -16,9 +16,6 @@
package org.nd4j.autodiff.opvalidation;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
-
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -46,6 +43,8 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.ops.transforms.Transforms;
+import static org.junit.Assert.*;
+
@Slf4j
public class LayerOpValidation extends BaseOpValidation {
public LayerOpValidation(Nd4jBackend backend) {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
index 1c55e47fb..7f8da282e 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
@@ -45,7 +45,7 @@ public class LossOpValidation extends BaseOpValidation {
}
@Override
- public long testTimeoutMilliseconds() {
+ public long getTimeoutMilliseconds() {
return 90000L;
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java
index 51e8fd714..6c6ba5b83 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java
@@ -54,8 +54,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
+import static org.junit.Assert.*;
@Slf4j
@RunWith(Parameterized.class)
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java
index 15d6bd273..071551719 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java
@@ -2434,8 +2434,6 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testPermute4(){
- Nd4j.getExecutioner().enableDebugMode(true);
- Nd4j.getExecutioner().enableVerboseMode(true);
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java
index 792fce892..33ac52f16 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java
@@ -63,8 +63,12 @@ public class CheckpointListenerTest extends BaseNd4jTest {
return sd;
}
- public static DataSetIterator getIter(){
- return new IrisDataSetIterator(15, 150);
+ public static DataSetIterator getIter() {
+ return getIter(15, 150);
+ }
+
+ public static DataSetIterator getIter(int batch, int totalExamples){
+ return new IrisDataSetIterator(batch, totalExamples);
}
@@ -148,15 +152,15 @@ public class CheckpointListenerTest extends BaseNd4jTest {
CheckpointListener l = new CheckpointListener.Builder(dir)
.keepLast(2)
- .saveEvery(3, TimeUnit.SECONDS)
+ .saveEvery(1, TimeUnit.SECONDS)
.build();
sd.setListeners(l);
- DataSetIterator iter = getIter();
+ DataSetIterator iter = getIter(15, 150);
for(int i=0; i<5; i++ ){ //10 iterations total
sd.fit(iter, 1);
- Thread.sleep(4000);
+ Thread.sleep(1000);
}
//Expect models saved at iterations: 10, 20, 30, 40
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java
index 7ec752a14..fcdb04e31 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java
@@ -29,6 +29,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Random;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java
index 012ba3434..64281c7d6 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java
@@ -34,8 +34,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Random;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
/**
* Created by Alex on 05/07/2017.
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java
index c4ffe13d2..e90a4829c 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java
@@ -33,8 +33,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.*;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
/**
* Created by Alex on 04/11/2016.
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java
index c2ca2148e..5ac5c9410 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java
@@ -33,6 +33,7 @@ import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.Arrays;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@Slf4j
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
index 2ea9a8142..63571a30e 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
@@ -121,7 +121,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"fused_batch_norm/.*",
// AB 2020/01/04 - https://github.com/eclipse/deeplearning4j/issues/8592
- "emptyArrayTests/reshape/rank2_shape2-0_2-0--1"
+ "emptyArrayTests/reshape/rank2_shape2-0_2-0--1",
+
+ //AB 2020/01/07 - Known issues
+ "bitcast/from_float64_to_int64",
+ "bitcast/from_rank2_float64_to_int64",
+ "bitcast/from_float64_to_uint64"
};
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java
index d3a76bf28..2a0880906 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java
@@ -252,7 +252,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
System.out.println(Arrays.toString(shape));
// this is NHWC weights. will be changed soon.
- assertArrayEquals(new int[]{5,5,1,32}, shape);
+ assertArrayEquals(new long[]{5,5,1,32}, shape);
System.out.println(convNode);
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java
index 87a0a9aef..986103ef6 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java
@@ -17,6 +17,7 @@
package org.nd4j.linalg;
+import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.junit.After;
@@ -26,6 +27,7 @@ import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
+import org.nd4j.BaseND4JTest;
import org.nd4j.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@@ -40,30 +42,16 @@ import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory;
import java.util.*;
+import static org.junit.Assume.assumeTrue;
+
/**
* Base Nd4j test
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
-public abstract class BaseNd4jTest {
- private static Logger log = LoggerFactory.getLogger(BaseNd4jTest.class);
-
- @Rule
- public TestName testName = new TestName();
-
- @Rule
- public Timeout timeout = Timeout.seconds(testTimeoutMilliseconds());
-
- /**
- * Override this method to set the default timeout for methods in the class
- */
- public long testTimeoutMilliseconds(){
- return 30000L;
- }
-
- protected long startTime;
- protected int threadCountBefore;
+@Slf4j
+public abstract class BaseNd4jTest extends BaseND4JTest {
protected Nd4jBackend backend;
protected String name;
@@ -80,16 +68,10 @@ public abstract class BaseNd4jTest {
public BaseNd4jTest(String name, Nd4jBackend backend) {
this.backend = backend;
this.name = name;
-
- //Suppress ND4J initialization - don't need this logged for every test...
- System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
- System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
- System.gc();
}
public BaseNd4jTest(Nd4jBackend backend) {
this(backend.getClass().getName() + UUID.randomUUID().toString(), backend);
-
}
private static List backends;
@@ -104,79 +86,6 @@ public abstract class BaseNd4jTest {
if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) || backendsToRun.isEmpty())
backends.add(backend);
}
-
- }
- public static void assertArrayEquals(String string, Object[] expecteds, Object[] actuals) {
- org.junit.Assert.assertArrayEquals(string, expecteds, actuals);
- }
-
- public static void assertArrayEquals(Object[] expecteds, Object[] actuals) {
- org.junit.Assert.assertArrayEquals(expecteds, actuals);
- }
-
- public static void assertArrayEquals(String string, long[] shapeA, long[] shapeB) {
- org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
- }
-
- public static void assertArrayEquals(String string, byte[] shapeA, byte[] shapeB) {
- org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
- }
-
- public static void assertArrayEquals(byte[] shapeA, byte[] shapeB) {
- org.junit.Assert.assertArrayEquals(shapeA, shapeB);
- }
-
- public static void assertArrayEquals(long[] shapeA, long[] shapeB) {
- org.junit.Assert.assertArrayEquals(shapeA, shapeB);
- }
-
- public static void assertArrayEquals(String string, int[] shapeA, long[] shapeB) {
- org.junit.Assert.assertArrayEquals(string, ArrayUtil.toLongArray(shapeA), shapeB);
- }
-
- public static void assertArrayEquals(int[] shapeA, long[] shapeB) {
- org.junit.Assert.assertArrayEquals(ArrayUtil.toLongArray(shapeA), shapeB);
- }
-
- public static void assertArrayEquals(String string, long[] shapeA, int[] shapeB) {
- org.junit.Assert.assertArrayEquals(string, shapeA, ArrayUtil.toLongArray(shapeB));
- }
-
- public static void assertArrayEquals(long[] shapeA, int[] shapeB) {
- org.junit.Assert.assertArrayEquals(shapeA, ArrayUtil.toLongArray(shapeB));
- }
-
- public static void assertArrayEquals(String string, int[] shapeA, int[] shapeB) {
- org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
- }
-
- public static void assertArrayEquals(int[] shapeA, int[] shapeB) {
- org.junit.Assert.assertArrayEquals(shapeA, shapeB);
- }
-
- public static void assertArrayEquals(String string, boolean[] shapeA, boolean[] shapeB) {
- org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
- }
-
- public static void assertArrayEquals(boolean[] shapeA, boolean[] shapeB) {
- org.junit.Assert.assertArrayEquals(shapeA, shapeB);
- }
-
-
- public static void assertArrayEquals(float[] shapeA, float[] shapeB, float delta) {
- org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta);
- }
-
- public static void assertArrayEquals(double[] shapeA, double[] shapeB, double delta) {
- org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta);
- }
-
- public static void assertArrayEquals(String string, float[] shapeA, float[] shapeB, float delta) {
- org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta);
- }
-
- public static void assertArrayEquals(String string, double[] shapeA, double[] shapeB, double delta) {
- org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta);
}
@Parameterized.Parameters(name = "{index}: backend({0})={1}")
@@ -187,6 +96,13 @@ public abstract class BaseNd4jTest {
return ret;
}
+ @Override
+ @Before
+ public void beforeTest(){
+ super.beforeTest();
+ Nd4j.factory().setOrder(ordering());
+ }
+
/**
* Get the default backend (jblas)
* The default backend can be overridden by also passing:
@@ -207,106 +123,6 @@ public abstract class BaseNd4jTest {
}
-
- @Before
- public void before() throws Exception {
- //
- log.info("Running {}.{} on {}", getClass().getName(), testName.getMethodName(), backend.getClass().getSimpleName());
- Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
- Nd4j nd4j = new Nd4j();
- nd4j.initWithBackend(backend);
- Nd4j.factory().setOrder(ordering());
- NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
- Nd4j.getExecutioner().enableDebugMode(false);
- Nd4j.getExecutioner().enableVerboseMode(false);
- Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
- startTime = System.currentTimeMillis();
- threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
- }
-
- @After
- public void after() throws Exception {
- long totalTime = System.currentTimeMillis() - startTime;
- Nd4j.getMemoryManager().purgeCaches();
-
- logTestCompletion(totalTime);
- if (System.getProperties().getProperty("backends") != null
- && !System.getProperty("backends").contains(backend.getClass().getName()))
- return;
- Nd4j nd4j = new Nd4j();
- nd4j.initWithBackend(backend);
- Nd4j.factory().setOrder(ordering());
- NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
- Nd4j.getExecutioner().enableDebugMode(false);
- Nd4j.getExecutioner().enableVerboseMode(false);
-
- //Attempt to keep workspaces isolated between tests
- Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
- val currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
- Nd4j.getMemoryManager().setCurrentWorkspace(null);
- if(currWS != null){
- //Not really safe to continue testing under this situation... other tests will likely fail with obscure
- // errors that are hard to track back to this
- log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
- System.exit(1);
- }
- }
-
- public void logTestCompletion( long totalTime){
- StringBuilder sb = new StringBuilder();
- long maxPhys = Pointer.maxPhysicalBytes();
- long maxBytes = Pointer.maxBytes();
- long currPhys = Pointer.physicalBytes();
- long currBytes = Pointer.totalBytes();
-
- long jvmTotal = Runtime.getRuntime().totalMemory();
- long jvmMax = Runtime.getRuntime().maxMemory();
-
- int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
- sb.append(getClass().getSimpleName()).append(".").append(testName.getMethodName())
- .append(": ").append(totalTime).append(" ms")
- .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
- .append(", jvmTotal=").append(jvmTotal)
- .append(", jvmMax=").append(jvmMax)
- .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
- .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
-
- List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
- if(ws != null && ws.size() > 0){
- long currSize = 0;
- for(MemoryWorkspace w : ws){
- currSize += w.getCurrentSize();
- }
- if(currSize > 0){
- sb.append(", threadWSSize=").append(currSize)
- .append(" (").append(ws.size()).append(" WSs)");
- }
- }
-
-
- Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
- Object o = p.get("cuda.devicesInformation");
- if(o instanceof List){
- List