diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml index 04a1fc0f0..6ce3c9c1f 100644 --- a/arbiter/arbiter-core/pom.xml +++ b/arbiter/arbiter-core/pom.xml @@ -84,6 +84,13 @@ jackson ${nd4j.version} + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java index e7d07f81a..abfedac5e 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; @@ -32,7 +33,7 @@ import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusLis import org.junit.Assert; import org.junit.Test; -public class TestGeneticSearch { +public class TestGeneticSearch extends BaseDL4JTest { public class TestSelectionOperator extends SelectionOperator { @Override diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java index 7f9906abf..24f80bf23 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator; @@ -26,7 +27,7 @@ import java.util.Map; import static org.junit.Assert.*; -public class TestGridSearch { +public class TestGridSearch extends BaseDL4JTest { @Test public void testIndexing() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java index 6f1b336bb..cf5b1e1c7 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java @@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize; import org.apache.commons.math3.distribution.LogNormalDistribution; import org.apache.commons.math3.distribution.NormalDistribution; import org.apache.commons.math3.distribution.UniformIntegerDistribution; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; @@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 02/02/2017. */ -public class TestJson { +public class TestJson extends BaseDL4JTest { protected static ObjectMapper getObjectMapper(JsonFactory factory) { ObjectMapper om = new ObjectMapper(factory); diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java index 305480420..34916ebdc 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; @@ -34,7 +35,7 @@ import java.util.Map; * Test random search on the Branin Function: * http://www.sfu.ca/~ssurjano/branin.html */ -public class TestRandomSearch { +public class TestRandomSearch extends BaseDL4JTest { @Test public void test() throws Exception { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java index 85b048df7..6ca842637 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java @@ -17,12 +17,13 @@ package org.deeplearning4j.arbiter.optimize.distribution; import org.apache.commons.math3.distribution.RealDistribution; +import org.deeplearning4j.BaseDL4JTest; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class TestLogUniform { +public class TestLogUniform extends BaseDL4JTest { @Test public void testSimple(){ diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java index 4342f32ef..252b4304f 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; @@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; import org.junit.Test; -public class ArithmeticCrossoverTests { +public class ArithmeticCrossoverTests extends BaseDL4JTest { @Test public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java index 96e3fda7b..50ae7f729 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator; @@ -23,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; import org.junit.Assert; import org.junit.Test; -public class CrossoverOperatorTests { +public class CrossoverOperatorTests extends BaseDL4JTest { @Test public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java index 0b1ec5271..5c1bebb51 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator; import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; @@ -24,7 +25,7 @@ import org.junit.Test; import java.util.Deque; -public class CrossoverPointsGeneratorTests { +public class CrossoverPointsGeneratorTests extends BaseDL4JTest { @Test public void CrossoverPointsGenerator_FixedNumberCrossovers() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java index fbafb37a5..2399d256a 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; @@ -25,7 +26,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; import org.junit.Test; -public class KPointCrossoverTests { +public class KPointCrossoverTests extends BaseDL4JTest { @Test public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java index ed2bec0ba..6976d8dd4 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; import org.junit.Assert; @@ -24,7 +25,7 @@ import org.junit.Test; import java.util.ArrayList; import java.util.List; -public class ParentSelectionTests { +public class ParentSelectionTests extends BaseDL4JTest { @Test public void ParentSelection_InitializeInstance_ShouldInitPopulation() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java index 2e8a166a7..09b244ab3 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection; import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; @@ -26,7 +27,7 @@ import org.junit.Test; import java.util.ArrayList; import java.util.List; -public class RandomTwoParentSelectionTests { +public class RandomTwoParentSelectionTests extends BaseDL4JTest { @Test public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() { RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null); diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java index 36c11e00d..32dfb136c 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover; import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; @@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; import org.junit.Test; -public class SinglePointCrossoverTests { +public class SinglePointCrossoverTests extends BaseDL4JTest { @Test public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() { RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0}); diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java index 170984a68..9efe89620 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; @@ -27,7 +28,7 @@ import org.junit.Assert; import org.junit.Test; import sun.reflect.generics.reflectiveObjects.NotImplementedException; -public class TwoParentsCrossoverOperatorTests { +public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest { class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java index fc4524fdc..76a395c28 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.crossover; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover; import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; @@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; import org.junit.Test; -public class UniformCrossoverTests { +public class UniformCrossoverTests extends BaseDL4JTest { @Test public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java index fbd5465c4..ccdb434e8 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.culling; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; @@ -27,7 +28,7 @@ import org.junit.Test; import java.util.ArrayList; import java.util.List; -public class LeastFitCullOperatorTests { +public class LeastFitCullOperatorTests extends BaseDL4JTest { @Test public void LeastFitCullingOperation_ShouldCullLastElements() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java index 4e41268be..093ffd486 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.culling; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; @@ -27,7 +28,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException; import java.util.List; -public class RatioCullOperatorTests { +public class RatioCullOperatorTests extends BaseDL4JTest { class TestRatioCullOperator extends RatioCullOperator { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java index 80773c39e..38e2ba87b 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.mutation; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator; import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.junit.Assert; @@ -24,7 +25,7 @@ import org.junit.Test; import java.lang.reflect.Field; import java.util.Arrays; -public class RandomMutationOperatorTests { +public class RandomMutationOperatorTests extends BaseDL4JTest { @Test public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() { RandomMutationOperator sut = new RandomMutationOperator.Builder().build(); diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java index 760b198b0..e185b8164 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.population; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; @@ -27,7 +28,7 @@ import org.junit.Test; import java.util.List; -public class PopulationModelTests { +public class PopulationModelTests extends BaseDL4JTest { private class TestCullOperator implements CullOperator { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java index e5df87549..1d2b74de9 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.selection; import org.apache.commons.math3.random.RandomGenerator; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; @@ -36,7 +37,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException; import static org.junit.Assert.assertArrayEquals; -public class GeneticSelectionOperatorTests { +public class GeneticSelectionOperatorTests extends BaseDL4JTest { private class TestCullOperator implements CullOperator { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java index 5cc61d744..3f64279ee 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize.genetic.selection; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; @@ -25,7 +26,7 @@ import org.junit.Assert; import org.junit.Test; import sun.reflect.generics.reflectiveObjects.NotImplementedException; -public class SelectionOperatorTests { +public class SelectionOperatorTests extends BaseDL4JTest { private class TestSelectionOperator extends SelectionOperator { public PopulationModel getPopulationModel() { diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java index 4a203f770..98396a941 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.optimize.parameter; import org.apache.commons.math3.distribution.NormalDistribution; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; @@ -25,7 +26,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; -public class TestParameterSpaces { +public class TestParameterSpaces extends BaseDL4JTest { @Test diff --git a/arbiter/arbiter-deeplearning4j/pom.xml b/arbiter/arbiter-deeplearning4j/pom.xml index b163e2ae4..ec7e22d3c 100644 --- a/arbiter/arbiter-deeplearning4j/pom.xml +++ b/arbiter/arbiter-deeplearning4j/pom.xml @@ -63,6 +63,13 @@ gson ${gson.version} + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java index af5d04c0a..7c4ec38f4 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.computationgraph; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.TestUtils; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; @@ -44,7 +45,7 @@ import java.util.Random; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class TestComputationGraphSpace { +public class TestComputationGraphSpace extends BaseDL4JTest { @Test public void testBasic() { diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java index c64a06040..1747b45f9 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.computationgraph; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.conf.updater.AdamSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; @@ -85,7 +86,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @Slf4j -public class TestGraphLocalExecution { +public class TestGraphLocalExecution extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); @@ -126,7 +127,7 @@ public class TestGraphLocalExecution { if(dataApproach == 0){ ds = TestDL4JLocalExecution.MnistDataSource.class; dsP = new Properties(); - dsP.setProperty("minibatch", "8"); + dsP.setProperty("minibatch", "2"); candidateGenerator = new RandomSearchGenerator(mls); } else if(dataApproach == 1) { //DataProvider approach @@ -150,8 +151,8 @@ public class TestGraphLocalExecution { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), - new MaxCandidatesCondition(5)) + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .build(); IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator())); @@ -159,7 +160,7 @@ public class TestGraphLocalExecution { runner.execute(); List results = runner.getResults(); - assertEquals(5, results.size()); + assertTrue(results.size() > 0); System.out.println("----- COMPLETE - " + results.size() + " results -----"); } @@ -203,8 +204,8 @@ public class TestGraphLocalExecution { OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() .candidateGenerator(candidateGenerator).dataProvider(dataProvider) .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) - .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), - new MaxCandidatesCondition(10)) + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .build(); IOptimizationRunner runner = new LocalOptimizationRunner(configuration, @@ -223,7 +224,7 @@ public class TestGraphLocalExecution { .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) .l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in") - .setInputTypes(InputType.feedForward(4)) + .setInputTypes(InputType.feedForward(784)) .addLayer("layer0", new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10)) .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) @@ -250,8 +251,8 @@ public class TestGraphLocalExecution { .candidateGenerator(candidateGenerator) .dataProvider(new TestMdsDataProvider(1, 32)) .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) - .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), - new MaxCandidatesCondition(10)) + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .scoreFunction(ScoreFunctions.testSetAccuracy()) .build(); @@ -279,7 +280,7 @@ public class TestGraphLocalExecution { @Override public Object trainData(Map dataParameters) { try { - DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 10 * batchSize), false, true, true, 12345); + DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 3 * batchSize), false, true, true, 12345); return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying)); } catch (IOException e) { throw new RuntimeException(e); @@ -289,7 +290,7 @@ public class TestGraphLocalExecution { @Override public Object testData(Map dataParameters) { try { - DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 5 * batchSize), false, false, false, 12345); + DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 2 * batchSize), false, false, false, 12345); return new MultiDataSetIteratorAdapter(underlying); } catch (IOException e) { throw new RuntimeException(e); @@ -305,7 +306,7 @@ public class TestGraphLocalExecution { @Test public void testLocalExecutionEarlyStopping() throws Exception { EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder() - .epochTerminationConditions(new MaxEpochsTerminationCondition(4)) + .epochTerminationConditions(new MaxEpochsTerminationCondition(2)) .scoreCalculator(new ScoreProvider()) .modelSaver(new InMemoryModelSaver()).build(); Map commands = new HashMap<>(); @@ -348,8 +349,8 @@ public class TestGraphLocalExecution { .dataProvider(dataProvider) .scoreFunction(ScoreFunctions.testSetF1()) .modelSaver(new FileModelSaver(modelSavePath)) - .terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS), - new MaxCandidatesCondition(10)) + .terminationConditions(new MaxTimeCondition(15, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .build(); @@ -364,7 +365,7 @@ public class TestGraphLocalExecution { @Override public ScoreCalculator get() { try { - return new DataSetLossCalculatorCG(new MnistDataSetIterator(128, 1280), true); + return new DataSetLossCalculatorCG(new MnistDataSetIterator(4, 8), true); } catch (Exception e){ throw new RuntimeException(e); } diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java index e67e854b8..2b9c5696d 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.computationgraph; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.conf.updater.AdamSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; @@ -79,11 +80,16 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @Slf4j -public class TestGraphLocalExecutionGenetic { +public class TestGraphLocalExecutionGenetic extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public long getTimeoutMilliseconds() { + return 45000L; + } + @Test public void testLocalExecutionDataSources() throws Exception { for (int dataApproach = 0; dataApproach < 3; dataApproach++) { @@ -115,7 +121,7 @@ public class TestGraphLocalExecutionGenetic { if (dataApproach == 0) { ds = TestDL4JLocalExecution.MnistDataSource.class; dsP = new Properties(); - dsP.setProperty("minibatch", "8"); + dsP.setProperty("minibatch", "2"); candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction) .populationModel(new PopulationModel.Builder().populationSize(5).build()) @@ -148,7 +154,7 @@ public class TestGraphLocalExecutionGenetic { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), new MaxCandidatesCondition(10)) .build(); @@ -157,7 +163,7 @@ public class TestGraphLocalExecutionGenetic { runner.execute(); List results = runner.getResults(); - assertEquals(10, results.size()); + assertTrue(results.size() > 0); System.out.println("----- COMPLETE - " + results.size() + " results -----"); } diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java index d6e77ccf4..12ccc71a5 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.json; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace; @@ -71,7 +72,7 @@ import static org.junit.Assert.assertNotNull; /** * Created by Alex on 14/02/2017. */ -public class TestJson { +public class TestJson extends BaseDL4JTest { @Test public void testMultiLayerSpaceJson() { diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java index 501501e65..ea754990a 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace; @@ -59,7 +60,7 @@ import java.util.concurrent.TimeUnit; // import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener; /** Not strictly a unit test. Rather: part example, part debugging on MNIST */ -public class MNISTOptimizationTest { +public class MNISTOptimizationTest extends BaseDL4JTest { public static void main(String[] args) throws Exception { EarlyStoppingConfiguration esConf = diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java index 0f4d384ba..554ef346e 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator; @@ -72,9 +73,10 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; @Slf4j -public class TestDL4JLocalExecution { +public class TestDL4JLocalExecution extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); @@ -112,7 +114,7 @@ public class TestDL4JLocalExecution { if(dataApproach == 0){ ds = MnistDataSource.class; dsP = new Properties(); - dsP.setProperty("minibatch", "8"); + dsP.setProperty("minibatch", "2"); candidateGenerator = new RandomSearchGenerator(mls); } else if(dataApproach == 1) { //DataProvider approach @@ -136,7 +138,7 @@ public class TestDL4JLocalExecution { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), + .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), new MaxCandidatesCondition(5)) .build(); @@ -146,7 +148,7 @@ public class TestDL4JLocalExecution { runner.execute(); List results = runner.getResults(); - assertEquals(5, results.size()); + assertTrue(results.size() > 0); System.out.println("----- COMPLETE - " + results.size() + " results -----"); } diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java index bb8806f23..a62997a33 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.layers.DenseLayerSpace; @@ -39,7 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; -public class TestErrors { +public class TestErrors extends BaseDL4JTest { @Rule public TemporaryFolder temp = new TemporaryFolder(); @@ -60,7 +61,7 @@ public class TestErrors { CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10)) + .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions( new MaxCandidatesCondition(5)) @@ -87,7 +88,7 @@ public class TestErrors { CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10)) + .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions( new MaxCandidatesCondition(5)) @@ -116,7 +117,7 @@ public class TestErrors { CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10)) + .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions(new MaxCandidatesCondition(5)) .build(); @@ -143,7 +144,7 @@ public class TestErrors { CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10)) + .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions( new MaxCandidatesCondition(5)) diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java index e3338efcb..6a5458e65 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; import org.apache.commons.lang3.ArrayUtils; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.TestUtils; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; import org.deeplearning4j.arbiter.layers.*; @@ -44,7 +45,7 @@ import java.util.Random; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class TestLayerSpace { +public class TestLayerSpace extends BaseDL4JTest { @Test public void testBasic1() { diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java index 48055ed4b..99dc79f42 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.DL4JConfiguration; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.TestUtils; @@ -86,7 +87,7 @@ import java.util.*; import static org.junit.Assert.*; -public class TestMultiLayerSpace { +public class TestMultiLayerSpace extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java index 01adefd18..f2dc5d180 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java @@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.multilayernetwork; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.AdamSpace; import org.deeplearning4j.arbiter.layers.OutputLayerSpace; @@ -60,7 +61,13 @@ import java.util.Map; import static org.junit.Assert.assertEquals; @Slf4j -public class TestScoreFunctions { +public class TestScoreFunctions extends BaseDL4JTest { + + + @Override + public long getTimeoutMilliseconds() { + return 60000L; + } @Test public void testROCScoreFunctions() throws Exception { @@ -107,7 +114,7 @@ public class TestScoreFunctions { List list = runner.getResults(); for (ResultReference rr : list) { - DataSetIterator testIter = new MnistDataSetIterator(32, 2000, false, false, true, 12345); + DataSetIterator testIter = new MnistDataSetIterator(4, 16, false, false, false, 12345); testIter.setPreProcessor(new PreProc(rocType)); OptimizationResult or = rr.getResult(); @@ -141,10 +148,10 @@ public class TestScoreFunctions { } - DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345); + DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345); iter.setPreProcessor(new PreProc(rocType)); - assertEquals(msg, expScore, or.getScore(), 1e-5); + assertEquals(msg, expScore, or.getScore(), 1e-4); } } } @@ -158,7 +165,7 @@ public class TestScoreFunctions { @Override public Object trainData(Map dataParameters) { try { - DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345); + DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345); iter.setPreProcessor(new PreProc(rocType)); return iter; } catch (IOException e){ @@ -169,7 +176,7 @@ public class TestScoreFunctions { @Override public Object testData(Map dataParameters) { try { - DataSetIterator iter = new MnistDataSetIterator(32, 2000, false, false, true, 12345); + DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345); iter.setPreProcessor(new PreProc(rocType)); return iter; } catch (IOException e){ diff --git a/arbiter/arbiter-server/pom.xml b/arbiter/arbiter-server/pom.xml index bdea61138..c4306b967 100644 --- a/arbiter/arbiter-server/pom.xml +++ b/arbiter/arbiter-server/pom.xml @@ -49,6 +49,13 @@ ${junit.version} test + + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java index dd2b53409..21e4e402a 100644 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java @@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.server; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.arbiter.MultiLayerSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace; import org.deeplearning4j.arbiter.layers.DenseLayerSpace; @@ -52,7 +53,7 @@ import static org.junit.Assert.assertEquals; * Created by agibsonccc on 3/12/17. */ @Slf4j -public class ArbiterCLIRunnerTest { +public class ArbiterCLIRunnerTest extends BaseDL4JTest { @Test diff --git a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index 466a68e6a..e9b609c45 100644 --- a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -24,6 +24,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; import org.junit.rules.Timeout; +import org.nd4j.base.Preconditions; import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -36,6 +37,8 @@ import java.util.List; import java.util.Map; import java.util.Properties; +import static org.junit.Assume.assumeTrue; + @Slf4j public abstract class BaseDL4JTest { @@ -47,6 +50,17 @@ public abstract class BaseDL4JTest { protected long startTime; protected int threadCountBefore; + private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors(); + + /** + * Override this to specify the number of threads for C++ execution, via + * {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)} + * @return Number of threads to use for C++ op execution + */ + public int numThreads(){ + return DEFAULT_THREADS; + } + /** * Override this method to set the default timeout for methods in the test class */ @@ -72,6 +86,28 @@ public abstract class BaseDL4JTest { return getDataType(); } + protected Boolean integrationTest; + + /** + * @return True if integration tests maven profile is enabled, false otherwise. + */ + public boolean isIntegrationTests(){ + if(integrationTest == null){ + String prop = System.getenv("DL4J_INTEGRATION_TESTS"); + integrationTest = Boolean.parseBoolean(prop); + } + return integrationTest; + } + + /** + * Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled. + * This can be used to dynamically skip integration tests when the integration test profile is not enabled. + * Note that the integration test profile is not enabled by default - "integration-tests" profile + */ + public void skipUnlessIntegrationTests(){ + assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests()); + } + @Before public void beforeTest(){ log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); @@ -81,6 +117,14 @@ public abstract class BaseDL4JTest { Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); + Nd4j.getExecutioner().enableDebugMode(false); + Nd4j.getExecutioner().enableVerboseMode(false); + int numThreads = numThreads(); + Preconditions.checkState(numThreads > 0, "Number of threads must be > 0"); + if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) { + Nd4j.getEnvironment().setMaxMasterThreads(numThreads); + } startTime = System.currentTimeMillis(); threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java index e3923c4ff..1b769a32d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java @@ -158,7 +158,7 @@ public class LayerHelperValidationUtil { double d2 = arr2.dup('c').getDouble(idx); System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE); } - assertTrue(s + layerName + "activations - max RE: " + maxRE, maxRE < t.getMaxRelError()); + assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError()); log.info("Forward pass, max relative error: " + layerName + " - " + maxRE); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java index 8d8db1b8d..3e5015356 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java @@ -78,8 +78,16 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { @Test public void hasNextWithResetAndLoad() throws Exception { + int[] prefetchSizes; + if(isIntegrationTests()){ + prefetchSizes = new int[]{2, 3, 4, 5, 6, 7, 8}; + } else { + prefetchSizes = new int[]{2, 3, 8}; + } + + for (int iter = 0; iter < ITERATIONS; iter++) { - for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) { + for(int prefetchSize : prefetchSizes){ AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize); TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL); int cnt = 0; @@ -161,8 +169,14 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { @Test public void testVariableTimeSeries1() throws Exception { + int numBatches = isIntegrationTests() ? 1000 : 100; + int batchSize = isIntegrationTests() ? 32 : 8; + int timeStepsMin = 10; + int timeStepsMax = isIntegrationTests() ? 500 : 100; + int valuesPerTimestep = isIntegrationTests() ? 128 : 16; + AsyncDataSetIterator adsi = new AsyncDataSetIterator( - new VariableTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10), 2, true); + new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true); for (int e = 0; e < 10; e++) { int cnt = 0; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java index 68381bec7..d6a0d1fe1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java @@ -18,21 +18,10 @@ package org.deeplearning4j.datasets.iterator; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.records.reader.impl.csv.CSVRecordReader; -import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; -import org.datavec.api.split.NumberedFileInputSplit; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator; import org.junit.Test; import org.nd4j.linalg.dataset.api.MultiDataSet; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization; -import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize; - -import java.util.Arrays; import static org.junit.Assert.assertEquals; @@ -49,7 +38,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest { */ @Test public void testVariableTimeSeries1() throws Exception { - val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10); + int numBatches = isIntegrationTests() ? 1000 : 100; + int batchSize = isIntegrationTests() ? 32 : 8; + int timeStepsMin = 10; + int timeStepsMax = isIntegrationTests() ? 500 : 100; + int valuesPerTimestep = isIntegrationTests() ? 128 : 16; + + val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10); iterator.reset(); iterator.hasNext(); val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true); @@ -81,7 +76,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest { @Test public void testVariableTimeSeries2() throws Exception { - val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10); + int numBatches = isIntegrationTests() ? 1000 : 100; + int batchSize = isIntegrationTests() ? 32 : 8; + int timeStepsMin = 10; + int timeStepsMax = isIntegrationTests() ? 500 : 100; + int valuesPerTimestep = isIntegrationTests() ? 128 : 16; + + val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10); for (int e = 0; e < 10; e++) { iterator.reset(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java index b0d5e0d25..c5ac04901 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java @@ -46,17 +46,17 @@ public class TestEmnistDataSetIterator extends BaseDL4JTest { @Test public void testEmnistDataSetIterator() throws Exception { - // EmnistFetcher fetcher = new EmnistFetcher(EmnistDataSetIterator.Set.COMPLETE); - // File baseEmnistDir = fetcher.getFILE_DIR(); - // if(baseEmnistDir.exists()){ - // FileUtils.deleteDirectory(baseEmnistDir); - // } - // assertFalse(baseEmnistDir.exists()); - int batchSize = 128; - for (EmnistDataSetIterator.Set s : EmnistDataSetIterator.Set.values()) { + EmnistDataSetIterator.Set[] sets; + if(isIntegrationTests()){ + sets = EmnistDataSetIterator.Set.values(); + } else { + sets = new EmnistDataSetIterator.Set[]{EmnistDataSetIterator.Set.MNIST, EmnistDataSetIterator.Set.LETTERS}; + } + + for (EmnistDataSetIterator.Set s : sets) { boolean isBalanced = EmnistDataSetIterator.isBalanced(s); int numLabels = EmnistDataSetIterator.numLabels(s); INDArray labelCounts = null; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index 43370548f..812ea2b08 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -476,7 +476,7 @@ public class EvalTest extends BaseDL4JTest { net.setListeners(new EvaluativeListener(iterTest, 3)); - for( int i=0; i<10; i++ ){ + for( int i=0; i<3; i++ ){ net.fit(iter); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java index cbad1adbb..9a4994d73 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java @@ -339,9 +339,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); - for (int j = 0; j < net.getnLayers(); j++) { - log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); - } } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, @@ -623,13 +620,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); -// for (int j = 0; j < net.getnLayers(); j++) { -// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); -// } } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) - .labels(labels).subset(true).maxPerParam(128)); + .labels(labels).subset(true).maxPerParam(64)); assertTrue(msg, gradOK); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index 4ddb1ad40..329adbc9b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -557,60 +557,52 @@ public class CNNGradientCheckTest extends BaseDL4JTest { int[] minibatchSizes = {2}; int width = 5; int height = 5; - int[] inputDepths = {1, 2, 4}; - - Activation[] activations = {Activation.SIGMOID, Activation.TANH}; Nd4j.getRandom().setSeed(12345); - for (int inputDepth : inputDepths) { - for (Activation afn : activations) { - for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); - INDArray labels = Nd4j.zeros(minibatchSize, nOut); - for (int i = 0; i < minibatchSize; i++) { - labels.putScalar(new int[]{i, i % nOut}, 1.0); - } + int[] inputDepths = new int[]{1, 2, 4}; + Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS}; + int[] minibatch = {2, 1, 3}; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()) - .dataType(DataType.DOUBLE) - .activation(afn) - .list() - .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1) - .padding(0, 0).nIn(inputDepth).nOut(2).build())//output: (5-2+0)/1+1 = 4 - .layer(1, new LocallyConnected2D.Builder().nIn(2).nOut(7).kernelSize(2, 2) - .setInputSize(4, 4).convolutionMode(ConvolutionMode.Strict).hasBias(false) - .stride(1, 1).padding(0, 0).build()) //(4-2+0)/1+1 = 3 - .layer(2, new ConvolutionLayer.Builder().nIn(7).nOut(2).kernelSize(2, 2) - .stride(1, 1).padding(0, 0).build()) //(3-2+0)/1+1 = 2 - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut) - .build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build(); + for( int i=0; i 0); +// for (val word : words) { +// System.out.println(word); +// } } @Test @@ -755,7 +774,16 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void weightsNotUpdated_WhenLocked() throws Exception { - SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); + boolean isIntegration = isIntegrationTests(); + SentenceIterator iter; + SentenceIterator iter2; + if(isIntegration){ + iter = new BasicLineIterator(inputFile); + iter2 = new BasicLineIterator(inputFile2.getAbsolutePath()); + } else { + iter = new CollectionSentenceIterator(firstNLines(inputFile, 300)); + iter2 = new CollectionSentenceIterator(firstNLines(inputFile2, 300)); + } Word2Vec vec1 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(64).layerSize(100) .stopWords(new ArrayList()).seed(42).learningRate(0.025).minLearningRate(0.001) @@ -767,13 +795,12 @@ public class Word2VecTests extends BaseDL4JTest { vec1.fit(); - iter = new BasicLineIterator(inputFile2.getAbsolutePath()); Word2Vec vec2 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(32).layerSize(100) .stopWords(new ArrayList()).seed(32).learningRate(0.021).minLearningRate(0.001) .sampling(0).elementsLearningAlgorithm(new SkipGram()) .epochs(1).windowSize(5).allowParallelTokenization(true) .workers(1) - .iterate(iter) + .iterate(iter2) .intersectModel(vec1, true) .modelUtils(new BasicModelUtils()).build(); @@ -861,6 +888,22 @@ public class Word2VecTests extends BaseDL4JTest { } System.out.print("\n"); } - // + + public static List firstNLines(File f, int n){ + List lines = new ArrayList<>(); + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8); + try{ + for( int i=0; i cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words + //STEP 1: Initialization + int iterations = 50; + //create an n-dimensional array of doubles + Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); + List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words - //STEP 2: Turn text input into a list of words - INDArray weights; - if(syntheticData){ - weights = Nd4j.rand(1000, 200); - } else { - log.info("Load & Vectorize data...."); - File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file - //Get the data of all unique word vectors - Pair vectors = WordVectorSerializer.loadTxt(wordFile); - VocabCache cache = vectors.getSecond(); - weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list + //STEP 2: Turn text input into a list of words + INDArray weights; + if(syntheticData){ + weights = Nd4j.rand(250, 200); + } else { + log.info("Load & Vectorize data...."); + File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file + //Get the data of all unique word vectors + Pair vectors = WordVectorSerializer.loadTxt(wordFile); + VocabCache cache = vectors.getSecond(); + weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list - for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list - cacheList.add(cache.wordAtIndex(i)); - } - - //STEP 3: build a dual-tree tsne to use later - log.info("Build model...."); - BarnesHutTsne tsne = new BarnesHutTsne.Builder() - .setMaxIter(iterations) - .theta(0.5) - .normalize(false) - .learningRate(500) - .useAdaGrad(false) - .workspaceMode(wsm) - .build(); - - - //STEP 4: establish the tsne values and save them to a file - log.info("Store TSNE Coordinates for Plotting...."); - File outDir = testDir.newFolder(); - tsne.fit(weights); - tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); + for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list + cacheList.add(cache.wordAtIndex(i)); } + + //STEP 3: build a dual-tree tsne to use later + log.info("Build model...."); + BarnesHutTsne tsne = new BarnesHutTsne.Builder() + .setMaxIter(iterations) + .theta(0.5) + .normalize(false) + .learningRate(500) + .useAdaGrad(false) + .workspaceMode(wsm) + .build(); + + + //STEP 4: establish the tsne values and save them to a file + log.info("Store TSNE Coordinates for Plotting...."); + File outDir = testDir.newFolder(); + tsne.fit(weights); + tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); } } - //Elapsed time : 01:01:57.988 @Test public void testPerformance() throws Exception { StopWatch watch = new StopWatch(); watch.start(); - for (boolean syntheticData : new boolean[]{false, true}) { - for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) { - log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData); + for( int test=0; test <=1; test++){ + boolean syntheticData = test == 1; + WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED; + log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData); - //STEP 1: Initialization - int iterations = 100; - //create an n-dimensional array of doubles - Nd4j.setDataType(DataType.DOUBLE); - List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words + //STEP 1: Initialization + int iterations = 50; + //create an n-dimensional array of doubles + Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); + List cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words - //STEP 2: Turn text input into a list of words - INDArray weights; - if(syntheticData){ - weights = Nd4j.rand(5000, 20); - } else { - log.info("Load & Vectorize data...."); - File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file - //Get the data of all unique word vectors - Pair vectors = WordVectorSerializer.loadTxt(wordFile); - VocabCache cache = vectors.getSecond(); - weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list + //STEP 2: Turn text input into a list of words + INDArray weights; + if(syntheticData){ + weights = Nd4j.rand(DataType.FLOAT, 250, 20); + } else { + log.info("Load & Vectorize data...."); + File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file + //Get the data of all unique word vectors + Pair vectors = WordVectorSerializer.loadTxt(wordFile); + VocabCache cache = vectors.getSecond(); + weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list - for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list - cacheList.add(cache.wordAtIndex(i)); - } - - //STEP 3: build a dual-tree tsne to use later - log.info("Build model...."); - BarnesHutTsne tsne = new BarnesHutTsne.Builder() - .setMaxIter(iterations) - .theta(0.5) - .normalize(false) - .learningRate(500) - .useAdaGrad(false) - .workspaceMode(wsm) - .build(); - - - //STEP 4: establish the tsne values and save them to a file - log.info("Store TSNE Coordinates for Plotting...."); - File outDir = testDir.newFolder(); - tsne.fit(weights); - tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); + for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list + cacheList.add(cache.wordAtIndex(i)); } + + //STEP 3: build a dual-tree tsne to use later + log.info("Build model...."); + BarnesHutTsne tsne = new BarnesHutTsne.Builder() + .setMaxIter(iterations) + .theta(0.5) + .normalize(false) + .learningRate(500) + .useAdaGrad(false) + .workspaceMode(wsm) + .build(); + + + //STEP 4: establish the tsne values and save them to a file + log.info("Store TSNE Coordinates for Plotting...."); + File outDir = testDir.newFolder(); + tsne.fit(weights); + tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath()); } watch.stop(); System.out.println("Elapsed time : " + watch); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java index c3e8ac89b..95cd4e9a6 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java @@ -20,6 +20,8 @@ package org.deeplearning4j.models.paragraphvectors; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.apache.commons.io.IOUtils; +import org.apache.commons.io.LineIterator; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW; import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils; @@ -27,6 +29,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer; import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator; import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.ParallelTransformerIterator; +import org.deeplearning4j.text.sentenceiterator.*; import org.junit.Rule; import org.junit.rules.TemporaryFolder; import org.nd4j.linalg.io.ClassPathResource; @@ -46,10 +49,6 @@ import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator; import org.deeplearning4j.text.documentiterator.LabelAwareIterator; import org.deeplearning4j.text.documentiterator.LabelledDocument; import org.deeplearning4j.text.documentiterator.LabelsSource; -import org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator; -import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; -import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator; -import org.deeplearning4j.text.sentenceiterator.SentenceIterator; import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter; import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; @@ -66,8 +65,8 @@ import org.nd4j.resources.Resources; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.File; -import java.io.IOException; +import java.io.*; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.atomic.AtomicLong; @@ -372,7 +371,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { LabelsSource source = new LabelsSource("DOC_"); - ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(3) + ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1) .layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter) .trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0) .useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true) @@ -425,6 +424,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest { @Test(timeout = 300000) public void testParagraphVectorsDBOW() throws Exception { + skipUnlessIntegrationTests(); + File file = Resources.asFile("/big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(file); @@ -657,7 +658,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { } } - @Test(timeout = 300000) + @Test public void testIterator() throws IOException { val folder_labeled = testDir.newFolder(); val folder_unlabeled = testDir.newFolder(); @@ -672,7 +673,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { SentenceIterator iter = new BasicLineIterator(resource_sentences); int i = 0; - for (; i < 10000; ++i) { + for (; i < 10; ++i) { int j = 0; int labels = 0; int words = 0; @@ -721,7 +722,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); - Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(3) + Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(1) .learningRate(0.025).layerSize(150).minLearningRate(0.001) .elementsLearningAlgorithm(new SkipGram()).useHierarchicSoftmax(true).windowSize(5) .allowParallelTokenization(true) @@ -1009,7 +1010,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); - Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(3) + Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(1) .learningRate(0.025).layerSize(150).minLearningRate(0.001) .elementsLearningAlgorithm(new SkipGram()).useHierarchicSoftmax(true).windowSize(5) .iterate(iter).tokenizerFactory(t).build(); @@ -1151,8 +1152,27 @@ public class ParagraphVectorsTest extends BaseDL4JTest { @Test(timeout = 300000) public void testDoubleFit() throws Exception { + boolean isIntegration = isIntegrationTests(); File resource = Resources.asFile("/big/raw_sentences.txt"); - SentenceIterator iter = new BasicLineIterator(resource); + SentenceIterator iter; + if(isIntegration){ + iter = new BasicLineIterator(resource); + } else { + List lines = new ArrayList<>(); + try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){ + LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8); + try{ + for( int i=0; i<500 && lineIter.hasNext(); i++ ){ + lines.add(lineIter.next()); + } + } finally { + lineIter.close(); + } + } + + iter = new CollectionSentenceIterator(lines); + } + TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java index 44b098dc1..9e8e89d39 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java @@ -49,7 +49,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 240000L; + return 60000L; } /** @@ -57,6 +57,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest { */ @Test public void testIterator1() throws Exception { + File inputFile = Resources.asFile("big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); @@ -77,10 +78,14 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest { Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1); INDArray array = iterator.next().getFeatures(); + int count = 0; while (iterator.hasNext()) { DataSet ds = iterator.next(); assertArrayEquals(array.shape(), ds.getFeatures().shape()); + + if(!isIntegrationTests() && count++ > 20) + break; //raw_sentences.txt is 2.81 MB, takes quite some time to process. We'll only first 20 minibatches when doing unit tests } } diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java index 380807c8d..d75653604 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java +++ b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java @@ -45,9 +45,15 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { */ @Test public void testStore1() throws Exception { - int numParams = 100000; - - int workers[] = new int[] {2, 4, 8}; + int numParams; + int[] workers; + if(isIntegrationTests()){ + numParams = 100000; + workers = new int[] {2, 4, 8}; + } else { + numParams = 10000; + workers = new int[] {2, 3}; + } for (int numWorkers : workers) { EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3),null, null, false); @@ -77,7 +83,13 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { */ @Test public void testEncodingLimits1() throws Exception { - int numParams = 100000; + int numParams; + if(isIntegrationTests()){ + numParams = 100000; + } else { + numParams = 10000; + } + EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, null, false); for (int e = 10; e < numParams / 5; e++) { diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java index 68de05ce5..a7884bb56 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java +++ b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java @@ -242,7 +242,7 @@ public class IndexedTailTest extends BaseDL4JTest { final long[] sums = new long[numReaders]; val readers = new ArrayList(); for (int e = 0; e < numReaders; e++) { - val f = e; + final int f = e; val t = new Thread(new Runnable() { @Override public void run() { @@ -297,7 +297,7 @@ public class IndexedTailTest extends BaseDL4JTest { final long[] sums = new long[numReaders]; val readers = new ArrayList(); for (int e = 0; e < numReaders; e++) { - val f = e; + final int f = e; val t = new Thread(new Runnable() { @Override public void run() { @@ -371,7 +371,7 @@ public class IndexedTailTest extends BaseDL4JTest { final long[] sums = new long[numReaders]; val readers = new ArrayList(); for (int e = 0; e < numReaders; e++) { - val f = e; + final int f = e; val t = new Thread(new Runnable() { @Override public void run() { diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java index fd79cc780..8b060a77c 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java @@ -35,6 +35,7 @@ import org.deeplearning4j.remote.helpers.House; import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter; import org.deeplearning4j.remote.helpers.PredictedPrice; import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.nd4j.adapters.InferenceAdapter; import org.nd4j.autodiff.samediff.SDVariable; @@ -58,6 +59,7 @@ import java.util.Collections; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE; import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL; @@ -66,7 +68,6 @@ import static org.junit.Assert.*; @Slf4j public class JsonModelServerTest extends BaseDL4JTest { private static final MultiLayerNetwork model; - private final int PORT = 18080; static { val conf = new NeuralNetConfiguration.Builder() @@ -84,10 +85,18 @@ public class JsonModelServerTest extends BaseDL4JTest { @After public void pause() throws Exception { - // TODO: the same port was used in previous test and not accessible immediately. Might be better solution. + // Need to wait for server shutdown; without sleep, tests will fail if starting immediately after shutdown TimeUnit.SECONDS.sleep(2); } + private AtomicInteger portCount = new AtomicInteger(18080); + private int PORT; + + @Before + public void setPort(){ + PORT = portCount.getAndIncrement(); + } + @Test public void testStartStopParallel() throws Exception { @@ -343,7 +352,7 @@ public class JsonModelServerTest extends BaseDL4JTest { val server = new JsonModelServer.Builder(model) .outputSerializer(new PredictedPrice.PredictedPriceSerializer()) .inputDeserializer(null) - .port(18080) + .port(PORT) .build(); } @@ -382,7 +391,7 @@ public class JsonModelServerTest extends BaseDL4JTest { return null; } }) - .endpointAddress("http://localhost:18080/v1/serving") + .endpointAddress("http://localhost:" + PORT + "/v1/serving") .build(); int district = 2; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java index 42e8437e7..219573f0a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java @@ -485,7 +485,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { List exp = new ArrayList<>(); Random r = new Random(); - for (int i = 0; i < 500; i++) { + int runs = isIntegrationTests() ? 500 : 30; + for (int i = 0; i < runs; i++) { int[] shape = defaultSize; if (r.nextDouble() < 0.4) { shape = new int[]{r.nextInt(5) + 1, 10, r.nextInt(10) + 1}; @@ -597,7 +598,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { List arrs = new ArrayList<>(); List exp = new ArrayList<>(); Random r = new Random(); - for( int i=0; i<500; i++ ){ + int runs = isIntegrationTests() ? 500 : 20; + for( int i=0; i in = new ArrayList<>(); List inMasks = new ArrayList<>(); List exp = new ArrayList<>(); - for (int i = 0; i < 100; i++) { + int nRuns = isIntegrationTests() ? 100 : 10; + for (int i = 0; i < nRuns; i++) { int currTSLength = (randomTSLength ? 1 + r.nextInt(tsLength) : tsLength); int currNumEx = 1 + r.nextInt(3); INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn, currTSLength}); @@ -847,6 +852,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { List in = new ArrayList<>(); List exp = new ArrayList<>(); + int runs = isIntegrationTests() ? 100 : 20; for (int i = 0; i < 100; i++) { int currNumEx = 1 + r.nextInt(3); INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn}); diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java index 32c70f3ff..cea06265a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java @@ -62,8 +62,8 @@ public class ParallelWrapperTest extends BaseDL4JTest { int seed = 123; log.info("Load data...."); - DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 100); - DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 10); + DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 15); + DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 4); assertTrue(mnistTrain.hasNext()); val t0 = mnistTrain.next(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml index 42b6e42cf..3eafbb9e2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml @@ -47,6 +47,12 @@ org.nd4j nd4j-parameter-server-node_2.11 ${nd4j.version} + + + net.jpountz.lz4 + lz4 + + junit diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties new file mode 100755 index 000000000..5d1edb39f --- /dev/null +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +################################################################################ +# Copyright (c) 2015-2019 Skymind, Inc. +# +# This program and the accompanying materials are made available under the +# terms of the Apache License, Version 2.0 which is available at +# https://www.apache.org/licenses/LICENSE-2.0. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +log4j.rootLogger=ERROR, Console +log4j.appender.Console=org.apache.log4j.ConsoleAppender +log4j.appender.Console.layout=org.apache.log4j.PatternLayout +log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n + +log4j.appender.org.springframework=DEBUG +log4j.appender.org.deeplearning4j=DEBUG +log4j.appender.org.nd4j=DEBUG + +log4j.logger.org.springframework=INFO +log4j.logger.org.deeplearning4j=DEBUG +log4j.logger.org.nd4j=DEBUG +log4j.logger.org.apache.spark=WARN + + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml new file mode 100644 index 000000000..9dec22fae --- /dev/null +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml @@ -0,0 +1,53 @@ + + + + + + + + logs/application.log + + %date - [%level] - from %logger in %thread + %n%message%n%xException%n + + + + + + %logger{15} - %message%n%xException{5} + + + + + + + + + + + + + + + + + + + + + + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties new file mode 100755 index 000000000..5d1edb39f --- /dev/null +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +################################################################################ +# Copyright (c) 2015-2019 Skymind, Inc. +# +# This program and the accompanying materials are made available under the +# terms of the Apache License, Version 2.0 which is available at +# https://www.apache.org/licenses/LICENSE-2.0. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +log4j.rootLogger=ERROR, Console +log4j.appender.Console=org.apache.log4j.ConsoleAppender +log4j.appender.Console.layout=org.apache.log4j.PatternLayout +log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n + +log4j.appender.org.springframework=DEBUG +log4j.appender.org.deeplearning4j=DEBUG +log4j.appender.org.nd4j=DEBUG + +log4j.logger.org.springframework=INFO +log4j.logger.org.deeplearning4j=DEBUG +log4j.logger.org.nd4j=DEBUG +log4j.logger.org.apache.spark=WARN + + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml new file mode 100644 index 000000000..9dec22fae --- /dev/null +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml @@ -0,0 +1,53 @@ + + + + + + + + logs/application.log + + %date - [%level] - from %logger in %thread + %n%message%n%xException%n + + + + + + %logger{15} - %message%n%xException{5} + + + + + + + + + + + + + + + + + + + + + + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java index 6aa730c2f..b7a4d6d49 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java @@ -25,6 +25,7 @@ import org.datavec.spark.transform.misc.StringToWritablesFunction; import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; import org.deeplearning4j.spark.BaseSparkTest; import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.io.ClassPathResource; @@ -35,6 +36,15 @@ import static org.junit.Assert.assertEquals; public class TestIteratorUtils extends BaseSparkTest { + @Override + public DataType getDataType() { + return DataType.FLOAT; + } + + @Override + public DataType getDefaultFPDataType() { + return DataType.FLOAT; + } @Test public void testIrisRRMDSI() throws Exception { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index abfd39060..6908be512 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -453,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { tempDirF.deleteOnExit(); int dataSetObjSize = 1; - int batchSizePerExecutor = 16; - int numSplits = 5; + int batchSizePerExecutor = 4; + int numSplits = 3; int averagingFrequency = 3; int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency; DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false); @@ -506,7 +506,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { INDArray paramsAfter = sparkNet.getNetwork().params().dup(); assertNotEquals(paramsBefore, paramsAfter); - Thread.sleep(2000); + Thread.sleep(200); SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); //Expect @@ -517,7 +517,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { assertEquals(numSplits * numExecutors() * averagingFrequency, list.size()); for (EventStats es : list) { ExampleCountEventStats e = (ExampleCountEventStats) es; - assertTrue(batchSizePerExecutor * averagingFrequency - 10 >= e.getTotalExampleCount()); + assertTrue(batchSizePerExecutor * averagingFrequency >= e.getTotalExampleCount()); } @@ -535,9 +535,9 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { tempDirF.deleteOnExit(); tempDirF2.deleteOnExit(); - int dataSetObjSize = 5; - int batchSizePerExecutor = 25; - DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 1000, false); + int dataSetObjSize = 4; + int batchSizePerExecutor = 8; + DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 128, false); int i = 0; while (iter.hasNext()) { File nextFile = new File(tempDirF, i + ".bin"); @@ -981,7 +981,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { .setOutputs("out") .build(); - DataSetIterator iter = new IrisDataSetIterator(1, 150); + DataSetIterator iter = new IrisDataSetIterator(1, 50); List l = new ArrayList<>(); while(iter.hasNext()){ @@ -992,9 +992,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { int rddDataSetNumExamples = 1; - int averagingFrequency = 3; + int averagingFrequency = 2; + int batch = 2; ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(rddDataSetNumExamples) - .averagingFrequency(averagingFrequency).batchSizePerWorker(rddDataSetNumExamples) + .averagingFrequency(averagingFrequency).batchSizePerWorker(batch) .saveUpdater(true).workerPrefetchNumBatches(0).build(); Nd4j.getRandom().setSeed(12345); @@ -1003,7 +1004,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { SparkComputationGraph sn2 = new SparkComputationGraph(sc, conf2.clone(), tm); - for(int i=0; i<4; i++ ){ + for(int i=0; i<3; i++ ){ assertEquals(i, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount()); assertEquals(i, sn2.getNetwork().getConfiguration().getEpochCount()); sn1.fit(rdd); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java index b01de7e5e..e0759a549 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java @@ -42,6 +42,11 @@ import static org.junit.Assert.assertTrue; */ public class TestRepartitioning extends BaseSparkTest { + @Override + public long getTimeoutMilliseconds() { + return isIntegrationTests() ? 240000 : 60000; + } + @Test public void testRepartitioning() { List list = new ArrayList<>(); @@ -66,7 +71,12 @@ public class TestRepartitioning extends BaseSparkTest { @Test public void testRepartitioning2() throws Exception { - int[] ns = {320, 321, 25600, 25601, 25615}; + int[] ns; + if(isIntegrationTests()){ + ns = new int[]{320, 321, 25600, 25601, 25615}; + } else { + ns = new int[]{320, 2561}; + } for (int n : ns) { diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java index af0205f00..c92f5acdc 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java @@ -32,6 +32,11 @@ import java.io.File; public class MiscTests extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Test public void testTransferVGG() throws Exception { //https://github.com/deeplearning4j/deeplearning4j/issues/5167 diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java index bb41443bb..b45afe47a 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestDownload.java @@ -48,6 +48,11 @@ import static org.junit.Assert.assertEquals; @Slf4j public class TestDownload extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return isIntegrationTests() ? 480000L : 60000L; + } + @ClassRule public static TemporaryFolder testDir = new TemporaryFolder(); private static File f; @@ -67,12 +72,20 @@ public class TestDownload extends BaseDL4JTest { public void testDownloadAllModels() throws Exception { // iterate through each available model - ZooModel[] models = new ZooModel[]{ - LeNet.builder().build(), - SimpleCNN.builder().build(), - UNet.builder().build(), - NASNet.builder().build() - }; + ZooModel[] models; + + if(isIntegrationTests()){ + models = new ZooModel[]{ + LeNet.builder().build(), + SimpleCNN.builder().build(), + UNet.builder().build(), + NASNet.builder().build()}; + } else { + models = new ZooModel[]{ + LeNet.builder().build(), + SimpleCNN.builder().build()}; + } + for (int i = 0; i < models.length; i++) { diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java index b1963b9b6..9106bede3 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/TestImageNet.java @@ -57,6 +57,11 @@ import static org.junit.Assert.assertTrue; @Slf4j public class TestImageNet extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java index 682d7c230..4b104d08b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java @@ -63,10 +63,15 @@ public class DistributionUniform extends DynamicCustomOp { addArgs(); } - public DistributionUniform(INDArray shape, INDArray out, double min, double max){ + public DistributionUniform(INDArray shape, INDArray out, double min, double max) { + this(shape, out, min, max, null); + } + + public DistributionUniform(INDArray shape, INDArray out, double min, double max, DataType dataType){ super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List)null); this.min = min; this.max = max; + this.dataType = dataType; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml index 2d2c8d6e3..d98c7a6d1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml @@ -310,6 +310,13 @@ + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java index 9584d5692..a616e7aa1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java @@ -19,6 +19,7 @@ package org.nd4j.jita.allocator; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.DeviceLocalNDArray; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @Slf4j -public class DeviceLocalNDArrayTests { +public class DeviceLocalNDArrayTests extends BaseND4JTest { @Test public void testDeviceLocalArray_1() throws Exception{ diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java index d952c775a..9854b797a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/impl/MemoryTrackerTest.java @@ -19,13 +19,14 @@ package org.nd4j.jita.allocator.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.*; @Slf4j -public class MemoryTrackerTest { +public class MemoryTrackerTest extends BaseND4JTest { @Test public void testAllocatedDelta() { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java index 09c1ebb04..6f62dc53a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java @@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.workspace.CudaWorkspace; import org.nd4j.linalg.api.buffer.DataType; @@ -20,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.*; @Slf4j -public class BaseCudaDataBufferTest { +public class BaseCudaDataBufferTest extends BaseND4JTest { @Before public void setUp() { diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml index bc468c874..5f5c5fa90 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml @@ -87,6 +87,13 @@ ${logback.version} test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java index c959edbfe..81599cd5b 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java @@ -20,6 +20,7 @@ package org.nd4j.tensorflow.conversion; import junit.framework.TestCase; import org.apache.commons.io.FileUtils; import org.bytedeco.tensorflow.TF_Tensor; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.resources.Resources; import org.nd4j.shade.protobuf.Descriptors; @@ -46,7 +47,17 @@ import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -public class GraphRunnerTest { +public class GraphRunnerTest extends BaseND4JTest { + + @Override + public DataType getDataType() { + return DataType.FLOAT; + } + + @Override + public DataType getDefaultFPDataType() { + return DataType.FLOAT; + } public static ConfigProto getConfig(){ String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java index eb263f119..fbf4249bd 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java @@ -18,6 +18,7 @@ package org.nd4j.tensorflow.conversion; import org.apache.commons.io.IOUtils; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; -public class TensorflowConversionTest { +public class TensorflowConversionTest extends BaseND4JTest { @Test public void testView() { diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java index 614330813..b13ca465f 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java @@ -17,6 +17,7 @@ package org.nd4j.tensorflow.conversion; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.protobuf.util.JsonFormat; import org.apache.commons.io.IOUtils; import org.junit.Test; @@ -37,7 +38,7 @@ import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -public class GpuGraphRunnerTest { +public class GpuGraphRunnerTest extends BaseND4JTest { @Test public void testGraphRunner() throws Exception { diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index e6861d257..9d098189f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -127,6 +127,13 @@ + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index f7d8c6d9b..db443c548 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -16,9 +16,6 @@ package org.nd4j.autodiff.opvalidation; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -46,6 +43,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; +import static org.junit.Assert.*; + @Slf4j public class LayerOpValidation extends BaseOpValidation { public LayerOpValidation(Nd4jBackend backend) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java index 1c55e47fb..7f8da282e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java @@ -45,7 +45,7 @@ public class LossOpValidation extends BaseOpValidation { } @Override - public long testTimeoutMilliseconds() { + public long getTimeoutMilliseconds() { return 90000L; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 51e8fd714..6c6ba5b83 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -54,8 +54,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; +import static org.junit.Assert.*; @Slf4j @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 15d6bd273..071551719 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -2434,8 +2434,6 @@ public class ShapeOpValidation extends BaseOpValidation { @Test public void testPermute4(){ - Nd4j.getExecutioner().enableDebugMode(true); - Nd4j.getExecutioner().enableVerboseMode(true); INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray permute = Nd4j.createFromArray(1,0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java index 792fce892..33ac52f16 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java @@ -63,8 +63,12 @@ public class CheckpointListenerTest extends BaseNd4jTest { return sd; } - public static DataSetIterator getIter(){ - return new IrisDataSetIterator(15, 150); + public static DataSetIterator getIter() { + return getIter(15, 150); + } + + public static DataSetIterator getIter(int batch, int totalExamples){ + return new IrisDataSetIterator(batch, totalExamples); } @@ -148,15 +152,15 @@ public class CheckpointListenerTest extends BaseNd4jTest { CheckpointListener l = new CheckpointListener.Builder(dir) .keepLast(2) - .saveEvery(3, TimeUnit.SECONDS) + .saveEvery(1, TimeUnit.SECONDS) .build(); sd.setListeners(l); - DataSetIterator iter = getIter(); + DataSetIterator iter = getIter(15, 150); for(int i=0; i<5; i++ ){ //10 iterations total sd.fit(iter, 1); - Thread.sleep(4000); + Thread.sleep(1000); } //Expect models saved at iterations: 10, 20, 30, 40 diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java index 7ec752a14..fcdb04e31 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Random; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java index 012ba3434..64281c7d6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java @@ -34,8 +34,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Created by Alex on 05/07/2017. diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java index c4ffe13d2..e90a4829c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java @@ -33,8 +33,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Created by Alex on 04/11/2016. diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java index c2ca2148e..5ac5c9410 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java @@ -33,6 +33,7 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Arrays; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @Slf4j diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 2ea9a8142..63571a30e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -121,7 +121,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "fused_batch_norm/.*", // AB 2020/01/04 - https://github.com/eclipse/deeplearning4j/issues/8592 - "emptyArrayTests/reshape/rank2_shape2-0_2-0--1" + "emptyArrayTests/reshape/rank2_shape2-0_2-0--1", + + //AB 2020/01/07 - Known issues + "bitcast/from_float64_to_int64", + "bitcast/from_rank2_float64_to_int64", + "bitcast/from_float64_to_uint64" }; /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index d3a76bf28..2a0880906 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -252,7 +252,7 @@ public class TensorFlowImportTest extends BaseNd4jTest { System.out.println(Arrays.toString(shape)); // this is NHWC weights. will be changed soon. - assertArrayEquals(new int[]{5,5,1,32}, shape); + assertArrayEquals(new long[]{5,5,1,32}, shape); System.out.println(convNode); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java index 87a0a9aef..986103ef6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java @@ -17,6 +17,7 @@ package org.nd4j.linalg; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.Pointer; import org.junit.After; @@ -26,6 +27,7 @@ import org.junit.rules.TestName; import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import org.nd4j.BaseND4JTest; import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -40,30 +42,16 @@ import org.slf4j.LoggerFactory; import java.lang.management.ManagementFactory; import java.util.*; +import static org.junit.Assume.assumeTrue; + /** * Base Nd4j test * @author Adam Gibson */ @RunWith(Parameterized.class) -public abstract class BaseNd4jTest { - private static Logger log = LoggerFactory.getLogger(BaseNd4jTest.class); - - @Rule - public TestName testName = new TestName(); - - @Rule - public Timeout timeout = Timeout.seconds(testTimeoutMilliseconds()); - - /** - * Override this method to set the default timeout for methods in the class - */ - public long testTimeoutMilliseconds(){ - return 30000L; - } - - protected long startTime; - protected int threadCountBefore; +@Slf4j +public abstract class BaseNd4jTest extends BaseND4JTest { protected Nd4jBackend backend; protected String name; @@ -80,16 +68,10 @@ public abstract class BaseNd4jTest { public BaseNd4jTest(String name, Nd4jBackend backend) { this.backend = backend; this.name = name; - - //Suppress ND4J initialization - don't need this logged for every test... - System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); - System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true"); - System.gc(); } public BaseNd4jTest(Nd4jBackend backend) { this(backend.getClass().getName() + UUID.randomUUID().toString(), backend); - } private static List backends; @@ -104,79 +86,6 @@ public abstract class BaseNd4jTest { if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) || backendsToRun.isEmpty()) backends.add(backend); } - - } - public static void assertArrayEquals(String string, Object[] expecteds, Object[] actuals) { - org.junit.Assert.assertArrayEquals(string, expecteds, actuals); - } - - public static void assertArrayEquals(Object[] expecteds, Object[] actuals) { - org.junit.Assert.assertArrayEquals(expecteds, actuals); - } - - public static void assertArrayEquals(String string, long[] shapeA, long[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB); - } - - public static void assertArrayEquals(String string, byte[] shapeA, byte[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB); - } - - public static void assertArrayEquals(byte[] shapeA, byte[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB); - } - - public static void assertArrayEquals(long[] shapeA, long[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB); - } - - public static void assertArrayEquals(String string, int[] shapeA, long[] shapeB) { - org.junit.Assert.assertArrayEquals(string, ArrayUtil.toLongArray(shapeA), shapeB); - } - - public static void assertArrayEquals(int[] shapeA, long[] shapeB) { - org.junit.Assert.assertArrayEquals(ArrayUtil.toLongArray(shapeA), shapeB); - } - - public static void assertArrayEquals(String string, long[] shapeA, int[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, ArrayUtil.toLongArray(shapeB)); - } - - public static void assertArrayEquals(long[] shapeA, int[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, ArrayUtil.toLongArray(shapeB)); - } - - public static void assertArrayEquals(String string, int[] shapeA, int[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB); - } - - public static void assertArrayEquals(int[] shapeA, int[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB); - } - - public static void assertArrayEquals(String string, boolean[] shapeA, boolean[] shapeB) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB); - } - - public static void assertArrayEquals(boolean[] shapeA, boolean[] shapeB) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB); - } - - - public static void assertArrayEquals(float[] shapeA, float[] shapeB, float delta) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta); - } - - public static void assertArrayEquals(double[] shapeA, double[] shapeB, double delta) { - org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta); - } - - public static void assertArrayEquals(String string, float[] shapeA, float[] shapeB, float delta) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta); - } - - public static void assertArrayEquals(String string, double[] shapeA, double[] shapeB, double delta) { - org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta); } @Parameterized.Parameters(name = "{index}: backend({0})={1}") @@ -187,6 +96,13 @@ public abstract class BaseNd4jTest { return ret; } + @Override + @Before + public void beforeTest(){ + super.beforeTest(); + Nd4j.factory().setOrder(ordering()); + } + /** * Get the default backend (jblas) * The default backend can be overridden by also passing: @@ -207,106 +123,6 @@ public abstract class BaseNd4jTest { } - - @Before - public void before() throws Exception { - // - log.info("Running {}.{} on {}", getClass().getName(), testName.getMethodName(), backend.getClass().getSimpleName()); - Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); - Nd4j nd4j = new Nd4j(); - nd4j.initWithBackend(backend); - Nd4j.factory().setOrder(ordering()); - NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); - Nd4j.getExecutioner().enableDebugMode(false); - Nd4j.getExecutioner().enableVerboseMode(false); - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - startTime = System.currentTimeMillis(); - threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); - } - - @After - public void after() throws Exception { - long totalTime = System.currentTimeMillis() - startTime; - Nd4j.getMemoryManager().purgeCaches(); - - logTestCompletion(totalTime); - if (System.getProperties().getProperty("backends") != null - && !System.getProperty("backends").contains(backend.getClass().getName())) - return; - Nd4j nd4j = new Nd4j(); - nd4j.initWithBackend(backend); - Nd4j.factory().setOrder(ordering()); - NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); - Nd4j.getExecutioner().enableDebugMode(false); - Nd4j.getExecutioner().enableVerboseMode(false); - - //Attempt to keep workspaces isolated between tests - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - val currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - if(currWS != null){ - //Not really safe to continue testing under this situation... other tests will likely fail with obscure - // errors that are hard to track back to this - log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); - System.exit(1); - } - } - - public void logTestCompletion( long totalTime){ - StringBuilder sb = new StringBuilder(); - long maxPhys = Pointer.maxPhysicalBytes(); - long maxBytes = Pointer.maxBytes(); - long currPhys = Pointer.physicalBytes(); - long currBytes = Pointer.totalBytes(); - - long jvmTotal = Runtime.getRuntime().totalMemory(); - long jvmMax = Runtime.getRuntime().maxMemory(); - - int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - sb.append(getClass().getSimpleName()).append(".").append(testName.getMethodName()) - .append(": ").append(totalTime).append(" ms") - .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") - .append(", jvmTotal=").append(jvmTotal) - .append(", jvmMax=").append(jvmMax) - .append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes) - .append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys); - - List ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread(); - if(ws != null && ws.size() > 0){ - long currSize = 0; - for(MemoryWorkspace w : ws){ - currSize += w.getCurrentSize(); - } - if(currSize > 0){ - sb.append(", threadWSSize=").append(currSize) - .append(" (").append(ws.size()).append(" WSs)"); - } - } - - - Properties p = Nd4j.getExecutioner().getEnvironmentInformation(); - Object o = p.get("cuda.devicesInformation"); - if(o instanceof List){ - List> l = (List>) o; - if(l.size() > 0) { - - sb.append(" [").append(l.size()) - .append(" GPUs: "); - - for (int i = 0; i < l.size(); i++) { - Map m = l.get(i); - if(i > 0) - sb.append(","); - sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ") - .append(m.get("cuda.totalMemory")).append(" total)"); - } - sb.append("]"); - } - } - log.info(sb.toString()); - } - - /** * The ordering for this test * This test will only be invoked for @@ -315,15 +131,10 @@ public abstract class BaseNd4jTest { * @return the ordering for this test */ public char ordering() { - return 'a'; + return 'c'; } - - public String getFailureMessage() { return "Failed with backend " + backend.getClass().getName() + " and ordering " + ordering(); } - - - } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java index 8657d061e..659dd93bb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java @@ -65,16 +65,6 @@ public class NDArrayTestsFortran extends BaseNd4jTest { super(backend); } - @Before - public void before() throws Exception { - super.before(); - } - - @After - public void after() throws Exception { - super.after(); - } - @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 534720b29..c86246a52 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -133,13 +133,13 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Override - public long testTimeoutMilliseconds() { + public long getTimeoutMilliseconds() { return 90000; } @Before public void before() throws Exception { - super.before(); + super.beforeTest(); Nd4j.setDataType(DataType.DOUBLE); Nd4j.getRandom().setSeed(123); Nd4j.getExecutioner().enableDebugMode(false); @@ -148,7 +148,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @After public void after() throws Exception { - super.after(); + super.afterTest(); Nd4j.setDataType(initialType); } @@ -5331,7 +5331,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testNativeSort3() { - INDArray array = Nd4j.linspace(1, 1048576, 1048576, DataType.DOUBLE).reshape(1, -1); + int length = isIntegrationTests() ? 1048576 : 16484; + INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1); INDArray exp = array.dup(); Nd4j.shuffle(array, 0); @@ -7196,19 +7197,19 @@ public class Nd4jTestsC extends BaseNd4jTest { for( int i=-3; i<3; i++ ){ INDArray out = Nd4j.stack(i, in, in2); - int[] expShape; + long[] expShape; switch (i){ case -3: case 0: - expShape = new int[]{2,3,4}; + expShape = new long[]{2,3,4}; break; case -2: case 1: - expShape = new int[]{3,2,4}; + expShape = new long[]{3,2,4}; break; case -1: case 2: - expShape = new int[]{3,4,2}; + expShape = new long[]{3,4,2}; break; default: throw new RuntimeException(String.valueOf(i)); @@ -7602,6 +7603,7 @@ public class Nd4jTestsC extends BaseNd4jTest { String wsName = "testRollingMeanWs"; try { System.gc(); + int iterations1 = isIntegrationTests() ? 5 : 2; for (int e = 0; e < 5; e++) { try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) { val array = Nd4j.create(DataType.FLOAT, 32, 128, 256, 256); @@ -7609,7 +7611,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } } - int iterations = 20; + int iterations = isIntegrationTests() ? 20 : 3; val timeStart = System.nanoTime(); for (int e = 0; e < iterations; e++) { try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java index feb7131fd..946da0fb3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java @@ -57,13 +57,13 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest { @Before public void before() throws Exception { - super.before(); + super.beforeTest(); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } @After public void after() throws Exception { - super.after(); + super.afterTest(); DataTypeUtil.setDTypeForContext(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java index 39f3de324..ef32c1f99 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java @@ -37,8 +37,7 @@ import org.slf4j.LoggerFactory; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Tests comparing Nd4j ops to other libraries @@ -59,7 +58,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { @Before public void before() throws Exception { - super.before(); + super.beforeTest(); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); Nd4j.getRandom().setSeed(SEED); @@ -67,7 +66,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { @After public void after() throws Exception { - super.after(); + super.afterTest(); DataTypeUtil.setDTypeForContext(initialType); } @@ -197,7 +196,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { INDArray gemv = m.mmul(v); RealMatrix gemv2 = rm.multiply(rv); - assertArrayEquals(new int[] {rows, 1}, gemv.shape()); + assertArrayEquals(new long[] {rows, 1}, gemv.shape()); assertArrayEquals(new int[] {rows, 1}, new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java index 2dc595f54..1aba11f5e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java @@ -25,6 +25,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; +import static org.junit.Assert.assertArrayEquals; + /** * Created by Alex on 30/04/2016. */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java index 41c0e3e2d..e39760d9b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java @@ -38,8 +38,7 @@ import org.nd4j.linalg.util.SerializationUtils; import java.io.*; import java.util.Arrays; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Double data buffer tests diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java index 2360f66c0..c9bc81dc1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java @@ -37,8 +37,7 @@ import org.nd4j.linalg.util.SerializationUtils; import java.io.*; import java.nio.ByteBuffer; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Float data buffer tests diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java index 4121cf1ea..ac7a63f22 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java @@ -31,8 +31,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.io.*; import java.util.Arrays; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Tests for INT INDArrays and DataBuffers serialization diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java index 2b5af1033..bad99afb4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java @@ -33,8 +33,7 @@ import org.nd4j.linalg.util.ArrayUtil; import java.util.Arrays; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import static org.nd4j.linalg.indexing.NDArrayIndex.*; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java index 00f369756..80ebb203c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java @@ -27,8 +27,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.PointIndex; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java index a25a31425..3f18ee5cb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java @@ -25,6 +25,8 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.NDArrayIndex; +import static org.junit.Assert.assertArrayEquals; + /** * @author Adam Gibson */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java index b74d0b1b4..d106dca0a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java @@ -24,6 +24,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.NDArrayIndex; +import static org.junit.Assert.assertArrayEquals; + /** * @author Adam Gibson */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java index 65fe355fe..10a640c47 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java @@ -24,6 +24,8 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; + /** * @author Adam Gibson */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java index 729c18c77..93e0b0739 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java @@ -24,8 +24,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java index cede0543a..985715cc8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java @@ -32,6 +32,7 @@ import org.nd4j.linalg.primitives.Pair; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** @@ -76,8 +77,8 @@ public class TestTensorAlongDimension extends BaseNd4jTest { INDArray row = Nd4j.create(DataType.DOUBLE, 1, 5); INDArray col = Nd4j.create(DataType.DOUBLE, 5, 1); - assertArrayEquals(new int[] {1, 5}, row.tensorAlongDimension(0, 1).shape()); - assertArrayEquals(new int[] {1, 5}, col.tensorAlongDimension(0, 0).shape()); + assertArrayEquals(new long[] {1, 5}, row.tensorAlongDimension(0, 1).shape()); + assertArrayEquals(new long[] {1, 5}, col.tensorAlongDimension(0, 0).shape()); } @Test @@ -100,7 +101,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { INDArray tad = arr.tensorAlongDimension(i, 0); INDArray javaTad = arr.tensorAlongDimension(i, 0); assertEquals(javaTad, tad); - assertArrayEquals(new int[] {rows}, tad.shape()); + assertArrayEquals(new long[] {rows}, tad.shape()); //assertEquals(testValues.javaTensorAlongDimension(i, 0), tad); } @@ -108,7 +109,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { assertEquals(rows, arr.tensorsAlongDimension(1)); for (int i = 0; i < rows; i++) { INDArray tad = arr.tensorAlongDimension(i, 1); - assertArrayEquals(new int[] {cols}, tad.shape()); + assertArrayEquals(new long[] {cols}, tad.shape()); //assertEquals(testValues.javaTensorAlongDimension(i, 1), tad); } } @@ -127,7 +128,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { assertEquals("Failed on " + p.getValue(), cols * dim2, arr.tensorsAlongDimension(0)); for (int i = 0; i < cols * dim2; i++) { INDArray tad = arr.tensorAlongDimension(i, 0); - assertArrayEquals(new int[] {rows}, tad.shape()); + assertArrayEquals(new long[] {rows}, tad.shape()); //assertEquals(testValues.javaTensorAlongDimension(i, 0), tad); } @@ -135,7 +136,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { assertEquals(rows * dim2, arr.tensorsAlongDimension(1)); for (int i = 0; i < rows * dim2; i++) { INDArray tad = arr.tensorAlongDimension(i, 1); - assertArrayEquals(new int[] {cols}, tad.shape()); + assertArrayEquals(new long[] {cols}, tad.shape()); //assertEquals(testValues.javaTensorAlongDimension(i, 1), tad); } @@ -143,7 +144,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { assertEquals(rows * cols, arr.tensorsAlongDimension(2)); for (int i = 0; i < rows * cols; i++) { INDArray tad = arr.tensorAlongDimension(i, 2); - assertArrayEquals(new int[] {dim2}, tad.shape()); + assertArrayEquals(new long[] {dim2}, tad.shape()); //assertEquals(testValues.javaTensorAlongDimension(i, 2), tad); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java index e04a41714..569bb26a5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java @@ -39,6 +39,7 @@ import org.nd4j.linalg.util.ArrayUtil; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java index ed93fdc38..a8c211944 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java @@ -34,8 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.util.ArrayList; import java.util.Arrays; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import static org.nd4j.linalg.dimensionalityreduction.RandomProjection.johnsonLindenStraussMinDim; import static org.nd4j.linalg.dimensionalityreduction.RandomProjection.targetShape; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java index c1c64b896..8ddd47eb7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java @@ -41,6 +41,7 @@ import java.util.Arrays; import java.util.List; import java.util.UUID; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** @@ -229,7 +230,7 @@ public class Nd4jTest extends BaseNd4jTest { ByteBuffer floatBuffer = pointer1.asByteBuffer(); byte[] dataTwo = new byte[floatBuffer.capacity()]; floatBuffer.get(dataTwo); - Assert.assertArrayEquals(originalData,dataTwo); + assertArrayEquals(originalData,dataTwo); floatBuffer.position(0); DataBuffer dataBuffer = Nd4j.createBuffer(new FloatPointer(floatBuffer.asFloatBuffer()),linspace.length(), DataType.FLOAT); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java index cf8154ab4..d8c17a70c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java @@ -31,6 +31,7 @@ import org.nd4j.linalg.util.DeviceLocalNDArray; import java.util.Arrays; import java.util.concurrent.atomic.AtomicInteger; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @Slf4j diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java index efb77abef..454739496 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java @@ -60,7 +60,7 @@ public class OpsMappingTests extends BaseNd4jTest { } @Override - public long testTimeoutMilliseconds() { + public long getTimeoutMilliseconds() { return 90000L; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java index b7516a9f2..7d23a021a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import java.io.*; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @RunWith(Parameterized.class) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java index 8fa4ed44f..bf7286b91 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java @@ -38,8 +38,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; @Slf4j public class NumpyFormatTests extends BaseNd4jTest { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java index b96ebaca9..1eaac2e60 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java index 4a019c251..28e61a9eb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java @@ -34,6 +34,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java index 6c98c5afb..a54cbbfbd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java index 448d8c8ec..7efef4551 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.convolution.Convolution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java index eb8af3852..6c3f911f4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java @@ -33,8 +33,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author Adam Gibson diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java index b836cfe75..33a64c291 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java @@ -32,8 +32,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author Adam Gibson @@ -99,7 +98,7 @@ public class IndexingTestsC extends BaseNd4jTest { public void testPointIndexes() { INDArray arr = Nd4j.create(DataType.DOUBLE, 4, 3, 2); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all()); - assertArrayEquals(new int[] {4, 2}, get.shape()); + assertArrayEquals(new long[] {4, 2}, get.shape()); INDArray linspaced = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray assertion = Nd4j.create(new double[][] {{3, 4}, {9, 10}, {15, 16}, {21, 22}}); @@ -108,7 +107,7 @@ public class IndexingTestsC extends BaseNd4jTest { INDArray sliceI = linspacedGet.slice(i); assertEquals(assertion.slice(i), sliceI); } - assertArrayEquals(new int[] {6, 1}, linspacedGet.stride()); + assertArrayEquals(new long[] {6, 1}, linspacedGet.stride()); assertEquals(assertion, linspacedGet); } @@ -131,7 +130,7 @@ public class IndexingTestsC extends BaseNd4jTest { INDArray get = padded.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, sy, iLim), NDArrayIndex.interval(j, sx, jLim)); - assertArrayEquals(new int[] {81, 81, 18, 2}, get.stride()); + assertArrayEquals(new long[] {81, 81, 18, 2}, get.stride()); INDArray assertion = Nd4j.create(new double[] {1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3}, new int[] {1, 1, 4, 4}); assertEquals(assertion, get); @@ -143,7 +142,7 @@ public class IndexingTestsC extends BaseNd4jTest { INDArray assertion2 = Nd4j.create(new double[] {2, 2, 2, 2, 4, 4, 4, 4, 2, 2, 2, 2, 4, 4, 4, 4}, new int[] {1, 1, 4, 4}); - assertArrayEquals(new int[] {81, 81, 18, 2}, get3.stride()); + assertArrayEquals(new long[] {81, 81, 18, 2}, get3.stride()); assertEquals(assertion2, get3); @@ -154,7 +153,7 @@ public class IndexingTestsC extends BaseNd4jTest { j = 1; INDArray get2 = padded.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, sy, iLim), NDArrayIndex.interval(j, sx, jLim)); - assertArrayEquals(new int[] {81, 81, 18, 2}, get2.stride()); + assertArrayEquals(new long[] {81, 81, 18, 2}, get2.stride()); assertEquals(assertion, get2); @@ -171,22 +170,22 @@ public class IndexingTestsC extends BaseNd4jTest { } INDArray first10a = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 10)); - assertArrayEquals(first10a.shape(), new int[] {10}); + assertArrayEquals(first10a.shape(), new long[] {10}); for (int i = 0; i < 10; i++) assertTrue(first10a.getDouble(i) == i); INDArray first10b = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 10)); - assertArrayEquals(first10b.shape(), new int[] {10}); + assertArrayEquals(first10b.shape(), new long[] {10}); for (int i = 0; i < 10; i++) assertTrue(first10b.getDouble(i) == i); INDArray last10a = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(20, 30)); - assertArrayEquals(last10a.shape(), new int[] {10}); + assertArrayEquals(last10a.shape(), new long[] {10}); for (int i = 0; i < 10; i++) assertEquals(i+20, last10a.getDouble(i), 1e-6); INDArray last10b = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(20, 30)); - assertArrayEquals(last10b.shape(), new int[] {10}); + assertArrayEquals(last10b.shape(), new long[] {10}); for (int i = 0; i < 10; i++) assertTrue(last10b.getDouble(i) == 20 + i); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java index 4c40df4e3..bd698b9fd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assume.assumeNotNull; @@ -66,7 +67,7 @@ public class ReshapeTests extends BaseNd4jTest { double delta = 1e-1; INDArray arr = Nd4j.create(1, 3); INDArray reshaped = arr.reshape('f', 3, 1); - assertArrayEquals(new int[] {3, 1}, reshaped.shape()); + assertArrayEquals(new long[] {3, 1}, reshaped.shape()); assertEquals(0.0, reshaped.getDouble(1), delta); assertEquals(0.0, reshaped.getDouble(2), delta); log.info("Reshaped: {}", reshaped.shapeInfoDataBuffer().asInt()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java index 4fd06dced..b54203af2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java index fb85535f6..39f5c10a9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java @@ -36,6 +36,8 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.util.Arrays; import java.util.stream.LongStream; +import static org.junit.Assert.assertArrayEquals; + /** * @author raver119@gmail.com */ @@ -92,7 +94,7 @@ public class SortCooTests extends BaseNd4jTest { // log.info("New indices: {}", Arrays.toString(idx.asInt())); - assertArrayEquals(expIndices, idx.asInt()); + assertArrayEquals(expIndices, idx.asLong()); assertArrayEquals(expValues, val.asDouble(), 1e-5); } @@ -121,7 +123,7 @@ public class SortCooTests extends BaseNd4jTest { NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), val.addressPointer(), 3, 3); - assertArrayEquals(expIndices, idx.asInt()); + assertArrayEquals(expIndices, idx.asLong()); assertArrayEquals(expValues, val.asDouble(), 1e-5); } @@ -277,7 +279,7 @@ public class SortCooTests extends BaseNd4jTest { // just check the indices. sortSparseCooIndicesSort1 and sortSparseCooIndicesSort2 checks that // indices and values are both swapped. This test just makes sure index sort works for larger arrays. - assertArrayEquals(expIndices, idx.asInt()); + assertArrayEquals(expIndices, idx.asLong()); } @Override public char ordering() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java index baef37c13..daa80295f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java index a2d1cf979..e5669595f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java @@ -30,6 +30,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index d5686bba5..e9b3e100f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -303,8 +303,8 @@ public class BasicWorkspaceTests extends BaseNd4jTest { INDArray delta = Nd4j.create(new int[] {50, 64, 8, 8}, new int[] {64, 3200, 8, 1}, 'c'); delta = delta.permute(1, 0, 2, 3); - assertArrayEquals(new int[] {64, 50, 8, 8}, delta.shape()); - assertArrayEquals(new int[] {3200, 64, 8, 1}, delta.stride()); + assertArrayEquals(new long[] {64, 50, 8, 8}, delta.shape()); + assertArrayEquals(new long[] {3200, 64, 8, 1}, delta.stride()); INDArray delta2d = Shape.newShapeNoCopy(delta, new int[] {outDepth, miniBatch * outH * outW}, false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml b/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml index 63b2ca84e..9d7518dfb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml @@ -34,7 +34,7 @@ - + diff --git a/nd4j/nd4j-common-tests/pom.xml b/nd4j/nd4j-common-tests/pom.xml new file mode 100644 index 000000000..a242c6644 --- /dev/null +++ b/nd4j/nd4j-common-tests/pom.xml @@ -0,0 +1,42 @@ + + + + nd4j + org.nd4j + 1.0.0-SNAPSHOT + + 4.0.0 + + nd4j-common-tests + + + + junit + junit + provided + + + org.nd4j + nd4j-api + ${project.version} + + + + + + + nd4j-tests-cpu + + + nd4j-tests-cuda + + + testresources + + + nd4j-testresources + + + \ No newline at end of file diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/BaseNd4jTest.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java similarity index 57% rename from nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/BaseNd4jTest.java rename to nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java index 36958198d..ae2f56273 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/BaseNd4jTest.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,19 +15,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.parameterserver; - +package org.nd4j; import lombok.extern.slf4j.Slf4j; -import lombok.val; import org.bytedeco.javacpp.Pointer; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; +import org.junit.rules.Timeout; +import org.nd4j.base.Preconditions; import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.profiler.ProfilerConfig; @@ -35,50 +37,106 @@ import java.util.List; import java.util.Map; import java.util.Properties; +import static org.junit.Assume.assumeTrue; -/** - * Base Nd4j test - * @author Adam Gibson - */ @Slf4j -public abstract class BaseNd4jTest { +public abstract class BaseND4JTest { @Rule - public TestName testName = new TestName(); + public TestName name = new TestName(); + @Rule + public Timeout timeout = Timeout.millis(getTimeoutMilliseconds()); protected long startTime; protected int threadCountBefore; - public BaseNd4jTest(){ - //Suppress ND4J initialization - don't need this logged for every test... - System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); - System.gc(); + /** + * Override this method to set the default timeout for methods in the test class + */ + public long getTimeoutMilliseconds(){ + return 30000; } + /** + * Override this to set the profiling mode for the tests defined in the child class + */ + public OpExecutioner.ProfilingMode getProfilingMode(){ + return OpExecutioner.ProfilingMode.SCOPE_PANIC; + } + + /** + * Override this to set the datatype of the tests defined in the child class + */ + public DataType getDataType(){ + return DataType.DOUBLE; + } + + /** + * Override this to set the datatype of the tests defined in the child class + */ + public DataType getDefaultFPDataType(){ + return getDataType(); + } + + private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors(); + + /** + * Override this to specify the number of threads for C++ execution, via + * {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)} + * @return Number of threads to use for C++ op execution + */ + public int numThreads(){ + return DEFAULT_THREADS; + } + + protected Boolean integrationTest; + + /** + * @return True if integration tests maven profile is enabled, false otherwise. + */ + public boolean isIntegrationTests(){ + if(integrationTest == null){ + String prop = System.getenv("DL4J_INTEGRATION_TESTS"); + integrationTest = Boolean.parseBoolean(prop); + } + return integrationTest; + } + + /** + * Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled. + * This can be used to dynamically skip integration tests when the integration test profile is not enabled. + * Note that the integration test profile is not enabled by default - "integration-tests" profile + */ + public void skipUnlessIntegrationTests(){ + assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests()); + } @Before - public void before() throws Exception { - log.info("Running " + getClass().getName() + "." + testName.getMethodName()); + public void beforeTest(){ + log.info("{}.{}", getClass().getSimpleName(), name.getMethodName()); + //Suppress ND4J initialization - don't need this logged for every test... + System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false"); + System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true"); + Nd4j.getExecutioner().setProfilingMode(getProfilingMode()); + Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); + Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType()); Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build()); Nd4j.getExecutioner().enableDebugMode(false); Nd4j.getExecutioner().enableVerboseMode(false); - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); + int numThreads = numThreads(); + Preconditions.checkState(numThreads > 0, "Number of threads must be > 0"); + if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) { + Nd4j.getEnvironment().setMaxMasterThreads(numThreads); + } startTime = System.currentTimeMillis(); threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount(); } @After - public void after() throws Exception { - long totalTime = System.currentTimeMillis() - startTime; - Nd4j.getMemoryManager().purgeCaches(); - - logTestCompletion(totalTime); - Nd4j.getExecutioner().enableDebugMode(false); - Nd4j.getExecutioner().enableVerboseMode(false); - + public void afterTest(){ //Attempt to keep workspaces isolated between tests Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - val currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); + MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace(); Nd4j.getMemoryManager().setCurrentWorkspace(null); if(currWS != null){ //Not really safe to continue testing under this situation... other tests will likely fail with obscure @@ -86,9 +144,7 @@ public abstract class BaseNd4jTest { log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS); System.exit(1); } - } - public void logTestCompletion( long totalTime){ StringBuilder sb = new StringBuilder(); long maxPhys = Pointer.maxPhysicalBytes(); long maxBytes = Pointer.maxBytes(); @@ -99,8 +155,10 @@ public abstract class BaseNd4jTest { long jvmMax = Runtime.getRuntime().maxMemory(); int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); - sb.append(getClass().getSimpleName()).append(".").append(testName.getMethodName()) - .append(": ").append(totalTime).append(" ms") + + long duration = System.currentTimeMillis() - startTime; + sb.append(getClass().getSimpleName()).append(".").append(name.getMethodName()) + .append(": ").append(duration).append(" ms") .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") .append(", jvmTotal=").append(jvmTotal) .append(", jvmMax=").append(jvmMax) diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml index 7b96eb072..b0941300a 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml @@ -55,6 +55,13 @@ hsqldb ${hsqldb.version} + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java index 68846ae74..4279b7a78 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java @@ -20,6 +20,7 @@ import org.hsqldb.jdbc.JDBCDataSource; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,7 +33,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; -public class HSqlLoaderTest { +public class HSqlLoaderTest extends BaseND4JTest { private static HsqlLoader hsqlLoader; private static DataSource dataSource; @@ -125,7 +126,7 @@ public class HSqlLoaderTest { INDArray load = hsqlLoader.load(hsqlLoader.loadForID("1")); assertNotNull(load); - assertEquals(Nd4j.linspace(1,4,4, Nd4j.dataType()),load); + assertEquals(Nd4j.linspace(1,4,4, load.dataType()),load); } diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml index 44f146f38..0ece4c8b0 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml @@ -38,6 +38,13 @@ junit test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java index db554ea43..59aea4ac0 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java @@ -19,6 +19,7 @@ package org.nd4j.jdbc.mysql; import com.mchange.v2.c3p0.ComboPooledDataSource; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -26,7 +27,7 @@ import java.sql.Blob; import static org.junit.Assert.assertEquals; -public class MysqlLoaderTest { +public class MysqlLoaderTest extends BaseND4JTest { //simple litmus test, unfortunately relies on an external database diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml index 21b3f6b65..ccabf1b76 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml @@ -68,6 +68,13 @@ ${logback.version} test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java index 530aa33ff..6398cbab1 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java @@ -25,10 +25,10 @@ import org.agrona.concurrent.BusySpinIdleStrategy; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.parameterserver.BaseNd4jTest; import org.nd4j.parameterserver.client.ParameterServerClient; import java.util.concurrent.atomic.AtomicInteger; @@ -39,7 +39,7 @@ import static org.junit.Assert.assertEquals; * Created by agibsonccc on 10/5/16. */ @Slf4j -public class RemoteParameterServerClientTests extends BaseNd4jTest { +public class RemoteParameterServerClientTests extends BaseND4JTest { private int parameterLength = 1000; private Aeron.Context ctx; private MediaDriver mediaDriver; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java index d69455a57..3846d6c2d 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java @@ -24,11 +24,11 @@ import org.agrona.concurrent.BusySpinIdleStrategy; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.parameterserver.BaseNd4jTest; import org.nd4j.parameterserver.ParameterServerListener; import org.nd4j.parameterserver.ParameterServerSubscriber; @@ -40,7 +40,7 @@ import static org.junit.Assert.assertTrue; * Created by agibsonccc on 10/3/16. */ @Slf4j -public class ParameterServerClientPartialTest extends BaseNd4jTest { +public class ParameterServerClientPartialTest extends BaseND4JTest { private static MediaDriver mediaDriver; private static Aeron.Context ctx; private static ParameterServerSubscriber masterNode, slaveNode; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java index 1ec08a83c..6aeaadb90 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java @@ -21,10 +21,10 @@ import io.aeron.driver.MediaDriver; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.parameterserver.BaseNd4jTest; import org.nd4j.parameterserver.ParameterServerListener; import org.nd4j.parameterserver.ParameterServerSubscriber; import org.slf4j.Logger; @@ -37,7 +37,7 @@ import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 10/3/16. */ -public class ParameterServerClientTest extends BaseNd4jTest { +public class ParameterServerClientTest extends BaseND4JTest { private static MediaDriver mediaDriver; private static Logger log = LoggerFactory.getLogger(ParameterServerClientTest.class); private static Aeron aeron; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml index f0d44beb6..bfab849da 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml @@ -79,6 +79,13 @@ rxjava 2.2.0 + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java index 8b27c87ea..b1d159c1d 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java @@ -22,6 +22,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; @@ -60,7 +61,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class VoidParameterServerStressTest { +public class VoidParameterServerStressTest extends BaseND4JTest { private static final int NUM_WORDS = 100000; @Before diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java index f85e23d91..57c79e3ba 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java @@ -21,6 +21,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; @@ -54,7 +55,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class VoidParameterServerTest { +public class VoidParameterServerTest extends BaseND4JTest { private static List localIPs; private static List badIPs; private static final Transport transport = new MulticastTransport(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java index 6253d776b..0cb4b254a 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/conf/VoidConfigurationTest.java @@ -20,6 +20,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.exception.ND4JIllegalStateException; import static org.junit.Assert.*; @@ -27,7 +28,7 @@ import static org.junit.Assert.*; /** * @author raver119@gmail.com */ -public class VoidConfigurationTest { +public class VoidConfigurationTest extends BaseND4JTest { @Rule public Timeout globalTimeout = Timeout.seconds(30); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java index 4041bb1e8..26b9da8eb 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; import org.junit.*; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.ArrayUtil; @@ -40,7 +41,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class ClipboardTest { +public class ClipboardTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java index ad3f23c1e..5c9694995 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java @@ -21,6 +21,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.logic.completion.FrameCompletionHandler; import static org.junit.Assert.*; @@ -30,7 +31,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class FrameCompletionHandlerTest { +public class FrameCompletionHandlerTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java index d2211f9fa..607ee96e9 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java @@ -21,6 +21,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.StringUtils; import org.nd4j.linalg.util.HashUtil; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; @@ -39,7 +40,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class InterleavedRouterTest { +public class InterleavedRouterTest extends BaseND4JTest { VoidConfiguration configuration; Transport transport; long originator; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java index 9c4ecec51..0e46bb179 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/FrameTest.java @@ -20,6 +20,7 @@ import org.agrona.concurrent.UnsafeBuffer; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.enums.NodeRole; import org.nd4j.parameterserver.distributed.logic.completion.Clipboard; @@ -37,7 +38,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class FrameTest { +public class FrameTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java index eab7aa468..d23ecd047 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java @@ -20,6 +20,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage; import static org.junit.Assert.*; @@ -29,7 +30,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class VoidMessageTest { +public class VoidMessageTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java index f5914cde1..698ab6dc1 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.messages.aggregations; import lombok.extern.slf4j.Slf4j; import org.junit.*; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -33,7 +34,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class VoidAggregationTest { +public class VoidAggregationTest extends BaseND4JTest { private static final short NODES = 100; private static final int ELEMENTS_PER_NODE = 3; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java index 7d9092591..661c74196 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java @@ -20,6 +20,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.enums.NodeRole; import org.nd4j.parameterserver.distributed.logic.ClientRouter; @@ -39,7 +40,7 @@ import static org.junit.Assert.*; */ @Ignore @Deprecated -public class RoutedTransportTest { +public class RoutedTransportTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java index 1be838edf..c7ace4db5 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; import org.junit.*; import org.junit.rules.Timeout; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -29,7 +30,7 @@ import static org.junit.Assert.*; * @author raver119@gmail.com */ @Slf4j -public class NetworkOrganizerTest { +public class NetworkOrganizerTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java index a75d00909..20831d47c 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/DelayedModelParameterServerTest.java @@ -23,6 +23,7 @@ import org.apache.commons.lang3.RandomUtils; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.AtomicBoolean; @@ -48,7 +49,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; @Slf4j -public class DelayedModelParameterServerTest { +public class DelayedModelParameterServerTest extends BaseND4JTest { private static final String rootId = "ROOT_NODE"; @Before diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java index 54acd8adf..c6be1b9c9 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServerTest.java @@ -20,6 +20,7 @@ import io.reactivex.functions.Consumer; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.AtomicBoolean; @@ -46,7 +47,7 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.*; @Slf4j -public class ModelParameterServerTest { +public class ModelParameterServerTest extends BaseND4JTest { private static final String rootId = "ROOT_NODE"; @Test(timeout = 20000L) diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java index 86e4b8fa4..0d48d5358 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTrackerTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.v2.chunks.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk; import org.nd4j.parameterserver.distributed.v2.chunks.impl.FileChunksTracker; @@ -30,7 +31,7 @@ import java.util.ArrayList; import static org.junit.Assert.*; @Slf4j -public class FileChunksTrackerTest { +public class FileChunksTrackerTest extends BaseND4JTest { @Test public void testTracker_1() throws Exception { val array = Nd4j.linspace(1, 100000, 100000).reshape(-1, 1000); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java index aa78d721f..3a2430ed1 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTrackerTest.java @@ -18,6 +18,7 @@ package org.nd4j.parameterserver.distributed.v2.chunks.impl; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk; import org.nd4j.parameterserver.distributed.v2.messages.impl.GradientsUpdateMessage; @@ -27,7 +28,7 @@ import java.util.ArrayList; import static org.junit.Assert.*; -public class InmemoryChunksTrackerTest { +public class InmemoryChunksTrackerTest extends BaseND4JTest { @Test public void testTracker_1() throws Exception { val array = Nd4j.linspace(1, 100000, 10000).reshape(-1, 1000); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java index 8380f157c..2b08b0c62 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/VoidMessageTest.java @@ -19,13 +19,14 @@ package org.nd4j.parameterserver.distributed.v2.messages; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.util.SerializationUtils; import org.nd4j.parameterserver.distributed.v2.messages.pairs.handshake.HandshakeRequest; import static org.junit.Assert.*; @Slf4j -public class VoidMessageTest { +public class VoidMessageTest extends BaseND4JTest { @Test public void testHandshakeSerialization_1() throws Exception { val req = new HandshakeRequest(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java index 901d4b197..c6999e469 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/messages/history/HashHistoryHolderTest.java @@ -19,11 +19,12 @@ package org.nd4j.parameterserver.distributed.v2.messages.history; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.*; @Slf4j -public class HashHistoryHolderTest { +public class HashHistoryHolderTest extends BaseND4JTest { @Test public void testBasicStuff_1() { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java index 990acbcd7..bd780e0ef 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransportTest.java @@ -20,13 +20,14 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.v2.messages.pairs.ping.PingMessage; import static org.junit.Assert.*; @Slf4j -public class AeronUdpTransportTest { +public class AeronUdpTransportTest extends BaseND4JTest { private static final String IP = "127.0.0.1"; private static final int ROOT_PORT = 40781; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java index b2f530df4..ba34f974f 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransportTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.v2.transport.impl; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.distributed.v2.enums.PropagationMode; import org.nd4j.parameterserver.distributed.v2.messages.VoidMessage; @@ -33,7 +34,7 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.*; @Slf4j -public class DummyTransportTest { +public class DummyTransportTest extends BaseND4JTest { @Test public void testBasicConnection_1() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java index 63bdb9fc6..72487bf2a 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.v2.util; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.util.SerializationUtils; import org.nd4j.parameterserver.distributed.v2.enums.MeshBuildMode; @@ -32,7 +33,7 @@ import static org.junit.Assert.*; * @author raver119@gmail.com */ @Slf4j -public class MeshOrganizerTest { +public class MeshOrganizerTest extends BaseND4JTest { @Test(timeout = 1000L) public void testDescendantsCount_1() { @@ -165,8 +166,8 @@ public class MeshOrganizerTest { nodes.add(node); } - for (val n:nodes) - log.info("Number of downstreams: [{}]", n.numberOfDownstreams()); +// for (val n:nodes) +// log.info("Number of downstreams: [{}]", n.numberOfDownstreams()); log.info("Going for first clone"); val clone1 = mesh.clone(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java index 20038d1a6..71761291e 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed.v2.util; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Atomic; import org.nd4j.linalg.primitives.Optional; @@ -29,7 +30,7 @@ import java.util.ArrayList; import static org.junit.Assert.*; @Slf4j -public class MessageSplitterTest { +public class MessageSplitterTest extends BaseND4JTest { @Test public void testMessageSplit_1() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java index 2f0623791..52e84efd3 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java @@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; @@ -38,7 +39,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore @Deprecated -public class ParameterServerNodeTest { +public class ParameterServerNodeTest extends BaseND4JTest { private static MediaDriver mediaDriver; private static Aeron aeron; private static ParameterServerNode parameterServerNode; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml index 6996672ca..619af0d7b 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml @@ -46,6 +46,13 @@ junit test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java index 3a7ac9ea3..170e1583f 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java @@ -17,6 +17,7 @@ package org.nd4j.parameterserver.updater.storage; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; @@ -25,7 +26,7 @@ import static junit.framework.TestCase.assertEquals; /** * Created by agibsonccc on 12/2/16. */ -public class UpdaterStorageTests { +public class UpdaterStorageTests extends BaseND4JTest { @Test(timeout = 30000L) public void testInMemory() { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml index 62bb98c1c..4ed7c8b7b 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml @@ -95,6 +95,13 @@ + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java index b821b2571..23ee67d78 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StatusServerTests.java @@ -17,12 +17,13 @@ package org.nd4j.parameterserver.status.play; import org.junit.Test; +import org.nd4j.BaseND4JTest; import play.server.Server; /** * Created by agibsonccc on 12/1/16. */ -public class StatusServerTests { +public class StatusServerTests extends BaseND4JTest { @Test(timeout = 20000L) public void runStatusServer() { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java index cd9aab453..1ad1e478e 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java @@ -17,6 +17,7 @@ package org.nd4j.parameterserver.status.play; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.parameterserver.model.SubscriberState; import static junit.framework.TestCase.assertEquals; @@ -25,7 +26,7 @@ import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 12/1/16. */ -public class StorageTests { +public class StorageTests extends BaseND4JTest { @Test(timeout = 20000L) public void testMapStorage() throws Exception { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml index af7316a37..4c60d83e2 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml @@ -58,6 +58,13 @@ unirest-java ${unirest.version} + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java index f2d47e95b..58627405b 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/ParameterServerUpdaterTests.java @@ -17,6 +17,7 @@ package org.nd4j.parameterserver.updater; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.aeron.ndarrayholder.InMemoryNDArrayHolder; import org.nd4j.linalg.factory.Nd4j; @@ -29,7 +30,7 @@ import static org.junit.Assume.assumeNotNull; /** * Created by agibsonccc on 12/2/16. */ -public class ParameterServerUpdaterTests { +public class ParameterServerUpdaterTests extends BaseND4JTest { @Test(timeout = 30000L) public void synchronousTest() { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java index 5186f1a4f..d666f7cb4 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/java/org/nd4j/parameterserver/updater/storage/UpdaterStorageTests.java @@ -17,6 +17,7 @@ package org.nd4j.parameterserver.updater.storage; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; @@ -25,7 +26,7 @@ import static junit.framework.TestCase.assertEquals; /** * Created by agibsonccc on 12/2/16. */ -public class UpdaterStorageTests { +public class UpdaterStorageTests extends BaseND4JTest { @Test(expected = UnsupportedOperationException.class) diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml index 93d171535..88af3c166 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml +++ b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml @@ -83,6 +83,13 @@ ${logback.version} test + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java index 1c84da010..bbd6e8cf7 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java +++ b/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java @@ -21,6 +21,7 @@ import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.execution.input.Operands; @@ -32,7 +33,7 @@ import static org.junit.Assert.*; @Slf4j @Ignore -public class GraphInferenceGrpcClientTest { +public class GraphInferenceGrpcClientTest extends BaseND4JTest { @Test public void testSimpleGraph_1() throws Exception { val exp = Nd4j.create(new double[] {-0.95938617, -1.20301781, 1.22260064, 0.50172403, 0.59972949, 0.78568028, 0.31609724, 1.51674747, 0.68013491, -0.05227458, 0.25903158,1.13243439}, new long[]{3, 1, 4}); diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml index 6c307f71c..d7729f179 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/pom.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -130,6 +130,14 @@ ${gson.version} test + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java index 290d1c4a1..1abbff6b3 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java @@ -19,6 +19,7 @@ package org.nd4j.remote; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.remote.clients.JsonRemoteInference; @@ -35,7 +36,7 @@ import java.util.concurrent.Future; import static org.junit.Assert.*; @Slf4j -public class SameDiffJsonModelServerTest { +public class SameDiffJsonModelServerTest extends BaseND4JTest { @Test public void basicServingTest_1() throws Exception { diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java index a26809efc..9a7b94272 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java @@ -7,6 +7,7 @@ import org.apache.http.impl.client.HttpClientBuilder; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.MultiDataSet; @@ -18,7 +19,7 @@ import java.io.IOException; import static org.junit.Assert.assertEquals; -public class SameDiffServletTest { +public class SameDiffServletTest extends BaseND4JTest { private SameDiffJsonModelServer server; diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java index 3aca94b8a..8332918b6 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java @@ -2,12 +2,13 @@ package org.nd4j.remote.serde; import lombok.val; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.remote.clients.serde.impl.*; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -public class BasicSerdeTests { +public class BasicSerdeTests extends BaseND4JTest { private final static DoubleArraySerde doubleArraySerde = new DoubleArraySerde(); private final static FloatArraySerde floatArraySerde = new FloatArraySerde(); private final static StringSerde stringSerde = new StringSerde(); diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml index fcf74c518..87e9347dd 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml +++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml @@ -49,15 +49,105 @@ testresources - + + + + nd4j-tests-cpu + + false + org.nd4j nd4j-native ${project.version} - test + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + -Ddtype=float -Xmx8g + + + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + org.apache.maven.surefire + surefire-junit47 + 2.19.1 + + + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.jcublas.JCublasBackend + org.nd4j.linalg.jcublas.JCublasBackend + + + -Ddtype=float -Xmx6g + + + + @@ -78,5 +168,27 @@ junit test + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + ch.qos.logback + logback-core + ${logback.version} + test + + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java index e05a5b6f9..9b7874bae 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java @@ -19,6 +19,7 @@ package org.nd4j.aeron.ipc; import org.agrona.concurrent.UnsafeBuffer; import org.apache.commons.lang3.time.StopWatch; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,7 +33,7 @@ import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 9/23/16. */ -public class AeronNDArraySerdeTest { +public class AeronNDArraySerdeTest extends BaseND4JTest { @Test public void testToAndFrom() { diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java index 55e4e13ff..bd3ac9b9c 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java @@ -23,6 +23,7 @@ import org.agrona.CloseHelper; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -34,29 +35,40 @@ import static org.junit.Assert.assertFalse; * Created by agibsonccc on 9/22/16. */ @Slf4j -public class LargeNdArrayIpcTest { +public class LargeNdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; private Aeron.Context ctx; private String channel = "aeron:udp?endpoint=localhost:" + (40123 + new java.util.Random().nextInt(130)); private int streamId = 10; private int length = (int) 1e7; + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Before public void before() { - //MediaDriver.loadPropertiesFile("aeron.properties"); - MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); - mediaDriver = MediaDriver.launchEmbedded(ctx); - System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); - System.out.println("Launched media driver"); + if(isIntegrationTests()) { + //MediaDriver.loadPropertiesFile("aeron.properties"); + MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); + mediaDriver = MediaDriver.launchEmbedded(ctx); + System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); + System.out.println("Launched media driver"); + } } @After public void after() { - CloseHelper.quietClose(mediaDriver); + if(isIntegrationTests()) { + CloseHelper.quietClose(mediaDriver); + } } @Test public void testMultiThreadedIpcBig() throws Exception { + skipUnlessIntegrationTests(); //Long-running test - don't run as part of unit tests by default + int length = (int) 1e7; INDArray arr = Nd4j.ones(length); AeronNDArrayPublisher publisher; diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java index 0cddea16b..be0572c93 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java @@ -18,6 +18,7 @@ package org.nd4j.aeron.ipc; import org.agrona.DirectBuffer; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -26,7 +27,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 11/6/16. */ -public class NDArrayMessageTest { +public class NDArrayMessageTest extends BaseND4JTest { @Test public void testNDArrayMessageToAndFrom() { diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java index 0412b1bed..d1a95ffde 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java @@ -22,6 +22,7 @@ import org.agrona.CloseHelper; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertFalse; /** * Created by agibsonccc on 9/22/16. */ -public class NdArrayIpcTest { +public class NdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; private static Logger log = LoggerFactory.getLogger(NdArrayIpcTest.class); private Aeron.Context ctx; @@ -44,21 +45,32 @@ public class NdArrayIpcTest { private int streamId = 10; private int length = (int) 1e7; + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Before public void before() { - MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); - mediaDriver = MediaDriver.launchEmbedded(ctx); - System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); - System.out.println("Launched media driver"); + if(isIntegrationTests()) { + MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); + mediaDriver = MediaDriver.launchEmbedded(ctx); + System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); + System.out.println("Launched media driver"); + } } @After public void after() { - CloseHelper.quietClose(mediaDriver); + if(isIntegrationTests()) { + CloseHelper.quietClose(mediaDriver); + } } @Test public void testMultiThreadedIpc() throws Exception { + skipUnlessIntegrationTests(); //Long-running test - don't run as part of unit tests by default + ExecutorService executorService = Executors.newFixedThreadPool(4); INDArray arr = Nd4j.scalar(1.0); @@ -139,6 +151,8 @@ public class NdArrayIpcTest { @Test public void testIpc() throws Exception { + skipUnlessIntegrationTests(); //Long-running test - don't run as part of unit tests by default + INDArray arr = Nd4j.scalar(1.0); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java index 8890d3836..d856b8f95 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java @@ -17,6 +17,7 @@ package org.nd4j.aeron.ipc.chunk; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; @@ -25,7 +26,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 11/20/16. */ -public class ChunkAccumulatorTests { +public class ChunkAccumulatorTests extends BaseND4JTest { @Test public void testAccumulator() { diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java index 3aeb37ee0..314ed0243 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java @@ -18,6 +18,7 @@ package org.nd4j.aeron.ipc.chunk; import org.agrona.DirectBuffer; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.aeron.util.BufferUtil; import org.nd4j.linalg.factory.Nd4j; @@ -30,7 +31,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 11/20/16. */ -public class NDArrayMessageChunkTests { +public class NDArrayMessageChunkTests extends BaseND4JTest { @Test public void testChunkSerialization() { diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java index f843f3839..b1df7d193 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java @@ -24,6 +24,7 @@ import org.agrona.CloseHelper; import org.agrona.concurrent.BusySpinIdleStrategy; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.aeron.ipc.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -37,24 +38,33 @@ import static org.junit.Assert.assertEquals; * Created by agibsonccc on 10/3/16. */ @Slf4j -public class AeronNDArrayResponseTest { +public class AeronNDArrayResponseTest extends BaseND4JTest { private MediaDriver mediaDriver; + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Before public void before() { - final MediaDriver.Context ctx = - new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true) - .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) - .receiverIdleStrategy(new BusySpinIdleStrategy()) - .senderIdleStrategy(new BusySpinIdleStrategy()); - mediaDriver = MediaDriver.launchEmbedded(ctx); - System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); - System.out.println("Launched media driver"); + if(isIntegrationTests()) { + final MediaDriver.Context ctx = + new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true) + .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) + .receiverIdleStrategy(new BusySpinIdleStrategy()) + .senderIdleStrategy(new BusySpinIdleStrategy()); + mediaDriver = MediaDriver.launchEmbedded(ctx); + System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); + System.out.println("Launched media driver"); + } } @Test public void testResponse() throws Exception { + skipUnlessIntegrationTests(); //Long-running test - don't run as part of unit tests by default + int streamId = 10; int responderStreamId = 11; String host = "127.0.0.1"; diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml index f16583745..502666a76 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml +++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml @@ -54,6 +54,14 @@ arrow-format ${arrow.version} + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java b/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java index 6c3328ea8..eee1521bd 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java +++ b/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java @@ -18,12 +18,13 @@ package org.nd4j.arrow; import org.apache.arrow.flatbuf.Tensor; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; -public class ArrowSerdeTest { +public class ArrowSerdeTest extends BaseND4JTest { @Test public void testBackAndForth() { diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml index 8ccd6b574..d7f3d0887 100644 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml +++ b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml @@ -76,6 +76,14 @@ + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java index 14c6d5875..3b40661bb 100644 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java +++ b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java @@ -21,6 +21,7 @@ import org.apache.camel.impl.DefaultCamelContext; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.camel.kafka.KafkaConnectionInformation; import org.nd4j.camel.kafka.Nd4jKafkaConsumer; import org.nd4j.camel.kafka.Nd4jKafkaProducer; @@ -32,7 +33,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 7/19/16. */ -public class Nd4jKafkaRouteTest { +public class Nd4jKafkaRouteTest extends BaseND4JTest { private EmbeddedKafkaCluster kafka; private EmbeddedZookeeper zk; private CamelContext camelContext; diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml index 82770d51a..75d251a29 100644 --- a/nd4j/nd4j-serde/nd4j-gson/pom.xml +++ b/nd4j/nd4j-serde/nd4j-gson/pom.xml @@ -45,6 +45,14 @@ junit test + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java b/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java index 0695a6cb8..a64641fe3 100644 --- a/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java +++ b/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java @@ -17,13 +17,14 @@ package org.nd4j.serde.gson; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; -public class GsonDeserializationUtilsTest { +public class GsonDeserializationUtilsTest extends BaseND4JTest { @Test public void deserializeRawJson_PassInInRank3Array_ExpectCorrectDeserialization() { String serializedRawArray = "[[[1.00, 11.00, 3.00],\n" + "[13.00, 5.00, 15.00],\n" + "[7.00, 17.00, 9.00]]]"; diff --git a/nd4j/nd4j-serde/nd4j-kryo/pom.xml b/nd4j/nd4j-serde/nd4j-kryo/pom.xml index 850413b1d..5c91c8024 100644 --- a/nd4j/nd4j-serde/nd4j-kryo/pom.xml +++ b/nd4j/nd4j-serde/nd4j-kryo/pom.xml @@ -113,6 +113,14 @@ junit test + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + diff --git a/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java b/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java index 1d51060ae..44db9d95f 100644 --- a/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java +++ b/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java @@ -42,7 +42,7 @@ import static org.junit.Assert.assertTrue; /** * Created by Alex on 04/07/2016. */ -public class TestNd4jKryoSerialization { +public class TestNd4jKryoSerialization extends BaseND4JTest { private JavaSparkContext sc; diff --git a/nd4j/pom.xml b/nd4j/pom.xml index f043d7299..ad87566df 100644 --- a/nd4j/pom.xml +++ b/nd4j/pom.xml @@ -63,6 +63,7 @@ nd4j-uberjar nd4j-tensorflow nd4j-remote + nd4j-common-tests diff --git a/pom.xml b/pom.xml index ada833f12..062ecf5bd 100644 --- a/pom.xml +++ b/pom.xml @@ -854,5 +854,29 @@ arm + + + + integration-tests + + false + + + + + maven-surefire-plugin + ${maven-surefire-plugin.version} + true + + + true + + + + + +