Unit/integration test split + test speedup (#166)

* Add maven profile + base tests methods for integration tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Switch from system property to environment variable; seems more reliable in intellij

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add nd4j-common-tests module, and common base test; cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ensure all ND4J tests extend BaseND4JTest

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Test spam reduction, import fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add test logging to nd4j-aeron

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix unintended change

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Reduce sprint test log spam

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More test spam cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Significantly speed up TSNE tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* W2V iterator test unit/integration split

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More NLP test speedups

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Avoid debug/verbose mode leaking between tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* test tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Arbiter extends base DL4J test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Arbiter test speedup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* nlp-uima test speedup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More test speedups

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix ND4J base test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Few small ND4J test speed improvements

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J tests speedup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More tweaks

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Even more test speedups

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More tweaks

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Various test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* More test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Add ability to specify number of threads for C++ ops in BaseDL4JTest and BaseND4JTest

Signed-off-by: Alex Black <blacka101@gmail.com>

* nd4j-aeron test profile fix for CUDA

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-01-22 22:27:01 +11:00 committed by GitHub
parent 2717b25931
commit a25bb6a11c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
191 changed files with 1468 additions and 794 deletions

View File

@ -84,6 +84,13 @@
<artifactId>jackson</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
@ -32,7 +33,7 @@ import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusLis
import org.junit.Assert;
import org.junit.Test;
public class TestGeneticSearch {
public class TestGeneticSearch extends BaseDL4JTest {
public class TestSelectionOperator extends SelectionOperator {
@Override

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
@ -26,7 +27,7 @@ import java.util.Map;
import static org.junit.Assert.*;
public class TestGridSearch {
public class TestGridSearch extends BaseDL4JTest {
@Test
public void testIndexing() {

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize;
import org.apache.commons.math3.distribution.LogNormalDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 02/02/2017.
*/
public class TestJson {
public class TestJson extends BaseDL4JTest {
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
ObjectMapper om = new ObjectMapper(factory);

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
@ -34,7 +35,7 @@ import java.util.Map;
* Test random search on the Branin Function:
* http://www.sfu.ca/~ssurjano/branin.html
*/
public class TestRandomSearch {
public class TestRandomSearch extends BaseDL4JTest {
@Test
public void test() throws Exception {

View File

@ -17,12 +17,13 @@
package org.deeplearning4j.arbiter.optimize.distribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestLogUniform {
public class TestLogUniform extends BaseDL4JTest {
@Test
public void testSimple(){

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
public class ArithmeticCrossoverTests {
public class ArithmeticCrossoverTests extends BaseDL4JTest {
@Test
public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator;
@ -23,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
public class CrossoverOperatorTests {
public class CrossoverOperatorTests extends BaseDL4JTest {
@Test
public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
@ -24,7 +25,7 @@ import org.junit.Test;
import java.util.Deque;
public class CrossoverPointsGeneratorTests {
public class CrossoverPointsGeneratorTests extends BaseDL4JTest {
@Test
public void CrossoverPointsGenerator_FixedNumberCrossovers() {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
@ -25,7 +26,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
public class KPointCrossoverTests {
public class KPointCrossoverTests extends BaseDL4JTest {
@Test
public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
import org.junit.Assert;
@ -24,7 +25,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class ParentSelectionTests {
public class ParentSelectionTests extends BaseDL4JTest {
@Test
public void ParentSelection_InitializeInstance_ShouldInitPopulation() {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
@ -26,7 +27,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class RandomTwoParentSelectionTests {
public class RandomTwoParentSelectionTests extends BaseDL4JTest {
@Test
public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() {
RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null);

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
public class SinglePointCrossoverTests {
public class SinglePointCrossoverTests extends BaseDL4JTest {
@Test
public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0});

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
@ -27,7 +28,7 @@ import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class TwoParentsCrossoverOperatorTests {
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
public class UniformCrossoverTests {
public class UniformCrossoverTests extends BaseDL4JTest {
@Test
public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.culling;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@ -27,7 +28,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class LeastFitCullOperatorTests {
public class LeastFitCullOperatorTests extends BaseDL4JTest {
@Test
public void LeastFitCullingOperation_ShouldCullLastElements() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.culling;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@ -27,7 +28,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import java.util.List;
public class RatioCullOperatorTests {
public class RatioCullOperatorTests extends BaseDL4JTest {
class TestRatioCullOperator extends RatioCullOperator {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.mutation;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
@ -24,7 +25,7 @@ import org.junit.Test;
import java.lang.reflect.Field;
import java.util.Arrays;
public class RandomMutationOperatorTests {
public class RandomMutationOperatorTests extends BaseDL4JTest {
@Test
public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() {
RandomMutationOperator sut = new RandomMutationOperator.Builder().build();

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.population;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@ -27,7 +28,7 @@ import org.junit.Test;
import java.util.List;
public class PopulationModelTests {
public class PopulationModelTests extends BaseDL4JTest {
private class TestCullOperator implements CullOperator {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
@ -36,7 +37,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import static org.junit.Assert.assertArrayEquals;
public class GeneticSelectionOperatorTests {
public class GeneticSelectionOperatorTests extends BaseDL4JTest {
private class TestCullOperator implements CullOperator {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
@ -25,7 +26,7 @@ import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class SelectionOperatorTests {
public class SelectionOperatorTests extends BaseDL4JTest {
private class TestSelectionOperator extends SelectionOperator {
public PopulationModel getPopulationModel() {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.parameter;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
@ -25,7 +26,7 @@ import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class TestParameterSpaces {
public class TestParameterSpaces extends BaseDL4JTest {
@Test

View File

@ -63,6 +63,13 @@
<artifactId>gson</artifactId>
<version>${gson.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.computationgraph;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.TestUtils;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
@ -44,7 +45,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestComputationGraphSpace {
public class TestComputationGraphSpace extends BaseDL4JTest {
@Test
public void testBasic() {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.computationgraph;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
@ -85,7 +86,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
public class TestGraphLocalExecution {
public class TestGraphLocalExecution extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@ -126,7 +127,7 @@ public class TestGraphLocalExecution {
if(dataApproach == 0){
ds = TestDL4JLocalExecution.MnistDataSource.class;
dsP = new Properties();
dsP.setProperty("minibatch", "8");
dsP.setProperty("minibatch", "2");
candidateGenerator = new RandomSearchGenerator(mls);
} else if(dataApproach == 1) {
//DataProvider approach
@ -150,8 +151,8 @@ public class TestGraphLocalExecution {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
new MaxCandidatesCondition(5))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator()));
@ -159,7 +160,7 @@ public class TestGraphLocalExecution {
runner.execute();
List<ResultReference> results = runner.getResults();
assertEquals(5, results.size());
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
}
@ -203,8 +204,8 @@ public class TestGraphLocalExecution {
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,
@ -223,7 +224,7 @@ public class TestGraphLocalExecution {
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1)))
.l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in")
.setInputTypes(InputType.feedForward(4))
.setInputTypes(InputType.feedForward(784))
.addLayer("layer0",
new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10))
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
@ -250,8 +251,8 @@ public class TestGraphLocalExecution {
.candidateGenerator(candidateGenerator)
.dataProvider(new TestMdsDataProvider(1, 32))
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.scoreFunction(ScoreFunctions.testSetAccuracy())
.build();
@ -279,7 +280,7 @@ public class TestGraphLocalExecution {
@Override
public Object trainData(Map<String, Object> dataParameters) {
try {
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 10 * batchSize), false, true, true, 12345);
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 3 * batchSize), false, true, true, 12345);
return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying));
} catch (IOException e) {
throw new RuntimeException(e);
@ -289,7 +290,7 @@ public class TestGraphLocalExecution {
@Override
public Object testData(Map<String, Object> dataParameters) {
try {
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 5 * batchSize), false, false, false, 12345);
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 2 * batchSize), false, false, false, 12345);
return new MultiDataSetIteratorAdapter(underlying);
} catch (IOException e) {
throw new RuntimeException(e);
@ -305,7 +306,7 @@ public class TestGraphLocalExecution {
@Test
public void testLocalExecutionEarlyStopping() throws Exception {
EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(4))
.epochTerminationConditions(new MaxEpochsTerminationCondition(2))
.scoreCalculator(new ScoreProvider())
.modelSaver(new InMemoryModelSaver()).build();
Map<String, Object> commands = new HashMap<>();
@ -348,8 +349,8 @@ public class TestGraphLocalExecution {
.dataProvider(dataProvider)
.scoreFunction(ScoreFunctions.testSetF1())
.modelSaver(new FileModelSaver(modelSavePath))
.terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.terminationConditions(new MaxTimeCondition(15, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
@ -364,7 +365,7 @@ public class TestGraphLocalExecution {
@Override
public ScoreCalculator get() {
try {
return new DataSetLossCalculatorCG(new MnistDataSetIterator(128, 1280), true);
return new DataSetLossCalculatorCG(new MnistDataSetIterator(4, 8), true);
} catch (Exception e){
throw new RuntimeException(e);
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.computationgraph;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
@ -79,11 +80,16 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
public class TestGraphLocalExecutionGenetic {
public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 45000L;
}
@Test
public void testLocalExecutionDataSources() throws Exception {
for (int dataApproach = 0; dataApproach < 3; dataApproach++) {
@ -115,7 +121,7 @@ public class TestGraphLocalExecutionGenetic {
if (dataApproach == 0) {
ds = TestDL4JLocalExecution.MnistDataSource.class;
dsP = new Properties();
dsP.setProperty("minibatch", "8");
dsP.setProperty("minibatch", "2");
candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction)
.populationModel(new PopulationModel.Builder().populationSize(5).build())
@ -148,7 +154,7 @@ public class TestGraphLocalExecutionGenetic {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.build();
@ -157,7 +163,7 @@ public class TestGraphLocalExecutionGenetic {
runner.execute();
List<ResultReference> results = runner.getResults();
assertEquals(10, results.size());
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.json;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace;
@ -71,7 +72,7 @@ import static org.junit.Assert.assertNotNull;
/**
* Created by Alex on 14/02/2017.
*/
public class TestJson {
public class TestJson extends BaseDL4JTest {
@Test
public void testMultiLayerSpaceJson() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace;
@ -59,7 +60,7 @@ import java.util.concurrent.TimeUnit;
// import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener;
/** Not strictly a unit test. Rather: part example, part debugging on MNIST */
public class MNISTOptimizationTest {
public class MNISTOptimizationTest extends BaseDL4JTest {
public static void main(String[] args) throws Exception {
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator;
@ -72,9 +73,10 @@ import java.util.Properties;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
public class TestDL4JLocalExecution {
public class TestDL4JLocalExecution extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@ -112,7 +114,7 @@ public class TestDL4JLocalExecution {
if(dataApproach == 0){
ds = MnistDataSource.class;
dsP = new Properties();
dsP.setProperty("minibatch", "8");
dsP.setProperty("minibatch", "2");
candidateGenerator = new RandomSearchGenerator(mls);
} else if(dataApproach == 1) {
//DataProvider approach
@ -136,7 +138,7 @@ public class TestDL4JLocalExecution {
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(5))
.build();
@ -146,7 +148,7 @@ public class TestDL4JLocalExecution {
runner.execute();
List<ResultReference> results = runner.getResults();
assertEquals(5, results.size());
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
@ -39,7 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
public class TestErrors {
public class TestErrors extends BaseDL4JTest {
@Rule
public TemporaryFolder temp = new TemporaryFolder();
@ -60,7 +61,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
@ -87,7 +88,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
@ -116,7 +117,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxCandidatesCondition(5))
.build();
@ -143,7 +144,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.TestUtils;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.*;
@ -44,7 +45,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestLayerSpace {
public class TestLayerSpace extends BaseDL4JTest {
@Test
public void testBasic1() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.DL4JConfiguration;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.TestUtils;
@ -86,7 +87,7 @@ import java.util.*;
import static org.junit.Assert.*;
public class TestMultiLayerSpace {
public class TestMultiLayerSpace extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.multilayernetwork;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
@ -60,7 +61,13 @@ import java.util.Map;
import static org.junit.Assert.assertEquals;
@Slf4j
public class TestScoreFunctions {
public class TestScoreFunctions extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 60000L;
}
@Test
public void testROCScoreFunctions() throws Exception {
@ -107,7 +114,7 @@ public class TestScoreFunctions {
List<ResultReference> list = runner.getResults();
for (ResultReference rr : list) {
DataSetIterator testIter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
DataSetIterator testIter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
testIter.setPreProcessor(new PreProc(rocType));
OptimizationResult or = rr.getResult();
@ -141,10 +148,10 @@ public class TestScoreFunctions {
}
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
iter.setPreProcessor(new PreProc(rocType));
assertEquals(msg, expScore, or.getScore(), 1e-5);
assertEquals(msg, expScore, or.getScore(), 1e-4);
}
}
}
@ -158,7 +165,7 @@ public class TestScoreFunctions {
@Override
public Object trainData(Map<String, Object> dataParameters) {
try {
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
iter.setPreProcessor(new PreProc(rocType));
return iter;
} catch (IOException e){
@ -169,7 +176,7 @@ public class TestScoreFunctions {
@Override
public Object testData(Map<String, Object> dataParameters) {
try {
DataSetIterator iter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
iter.setPreProcessor(new PreProc(rocType));
return iter;
} catch (IOException e){

View File

@ -49,6 +49,13 @@
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.server;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
@ -52,7 +53,7 @@ import static org.junit.Assert.assertEquals;
* Created by agibsonccc on 3/12/17.
*/
@Slf4j
public class ArbiterCLIRunnerTest {
public class ArbiterCLIRunnerTest extends BaseDL4JTest {
@Test

View File

@ -24,6 +24,7 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.nd4j.base.Preconditions;
import org.nd4j.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -36,6 +37,8 @@ import java.util.List;
import java.util.Map;
import java.util.Properties;
import static org.junit.Assume.assumeTrue;
@Slf4j
public abstract class BaseDL4JTest {
@ -47,6 +50,17 @@ public abstract class BaseDL4JTest {
protected long startTime;
protected int threadCountBefore;
private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
/**
* Override this to specify the number of threads for C++ execution, via
* {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
* @return Number of threads to use for C++ op execution
*/
public int numThreads(){
return DEFAULT_THREADS;
}
/**
* Override this method to set the default timeout for methods in the test class
*/
@ -72,6 +86,28 @@ public abstract class BaseDL4JTest {
return getDataType();
}
protected Boolean integrationTest;
/**
* @return True if integration tests maven profile is enabled, false otherwise.
*/
public boolean isIntegrationTests(){
if(integrationTest == null){
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
integrationTest = Boolean.parseBoolean(prop);
}
return integrationTest;
}
/**
* Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled.
* This can be used to dynamically skip integration tests when the integration test profile is not enabled.
* Note that the integration test profile is not enabled by default - "integration-tests" profile
*/
public void skipUnlessIntegrationTests(){
assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
}
@Before
public void beforeTest(){
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
@ -81,6 +117,14 @@ public abstract class BaseDL4JTest {
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j.getExecutioner().enableDebugMode(false);
Nd4j.getExecutioner().enableVerboseMode(false);
int numThreads = numThreads();
Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
}
startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
}

View File

@ -78,8 +78,16 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
@Test
public void hasNextWithResetAndLoad() throws Exception {
int[] prefetchSizes;
if(isIntegrationTests()){
prefetchSizes = new int[]{2, 3, 4, 5, 6, 7, 8};
} else {
prefetchSizes = new int[]{2, 3, 8};
}
for (int iter = 0; iter < ITERATIONS; iter++) {
for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) {
for(int prefetchSize : prefetchSizes){
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL);
int cnt = 0;
@ -161,8 +169,14 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
@Test
public void testVariableTimeSeries1() throws Exception {
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
AsyncDataSetIterator adsi = new AsyncDataSetIterator(
new VariableTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10), 2, true);
new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true);
for (int e = 0; e < 10; e++) {
int cnt = 0;

View File

@ -18,21 +18,10 @@ package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator;
import org.junit.Test;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
import java.util.Arrays;
import static org.junit.Assert.assertEquals;
@ -49,7 +38,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
*/
@Test
public void testVariableTimeSeries1() throws Exception {
val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10);
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
iterator.reset();
iterator.hasNext();
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
@ -81,7 +76,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
@Test
public void testVariableTimeSeries2() throws Exception {
val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10);
int numBatches = isIntegrationTests() ? 1000 : 100;
int batchSize = isIntegrationTests() ? 32 : 8;
int timeStepsMin = 10;
int timeStepsMax = isIntegrationTests() ? 500 : 100;
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
for (int e = 0; e < 10; e++) {
iterator.reset();

View File

@ -46,17 +46,17 @@ public class TestEmnistDataSetIterator extends BaseDL4JTest {
@Test
public void testEmnistDataSetIterator() throws Exception {
// EmnistFetcher fetcher = new EmnistFetcher(EmnistDataSetIterator.Set.COMPLETE);
// File baseEmnistDir = fetcher.getFILE_DIR();
// if(baseEmnistDir.exists()){
// FileUtils.deleteDirectory(baseEmnistDir);
// }
// assertFalse(baseEmnistDir.exists());
int batchSize = 128;
for (EmnistDataSetIterator.Set s : EmnistDataSetIterator.Set.values()) {
EmnistDataSetIterator.Set[] sets;
if(isIntegrationTests()){
sets = EmnistDataSetIterator.Set.values();
} else {
sets = new EmnistDataSetIterator.Set[]{EmnistDataSetIterator.Set.MNIST, EmnistDataSetIterator.Set.LETTERS};
}
for (EmnistDataSetIterator.Set s : sets) {
boolean isBalanced = EmnistDataSetIterator.isBalanced(s);
int numLabels = EmnistDataSetIterator.numLabels(s);
INDArray labelCounts = null;

View File

@ -476,7 +476,7 @@ public class EvalTest extends BaseDL4JTest {
net.setListeners(new EvaluativeListener(iterTest, 3));
for( int i=0; i<10; i++ ){
for( int i=0; i<3; i++ ){
net.fit(iter);
}
}

View File

@ -339,9 +339,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
log.info(msg);
for (int j = 0; j < net.getnLayers(); j++) {
log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
@ -623,13 +620,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
log.info(msg);
// for (int j = 0; j < net.getnLayers(); j++) {
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
// }
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
.labels(labels).subset(true).maxPerParam(128));
.labels(labels).subset(true).maxPerParam(64));
assertTrue(msg, gradOK);

View File

@ -557,20 +557,19 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
int[] minibatchSizes = {2};
int width = 5;
int height = 5;
int[] inputDepths = {1, 2, 4};
Activation[] activations = {Activation.SIGMOID, Activation.TANH};
Nd4j.getRandom().setSeed(12345);
for (int inputDepth : inputDepths) {
for (Activation afn : activations) {
for (int minibatchSize : minibatchSizes) {
int[] inputDepths = new int[]{1, 2, 4};
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS};
int[] minibatch = {2, 1, 3};
for( int i=0; i<inputDepths.length; i++ ){
int inputDepth = inputDepths[i];
Activation afn = activations[i];
int minibatchSize = minibatch[i];
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
}
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
.dataType(DataType.DOUBLE)
@ -594,10 +593,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int i = 0; i < 4; i++) {
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
}
String msg = "Minibatch=" + minibatchSize + ", activationFn="
+ afn;
System.out.println(msg);
@ -609,9 +604,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
TestUtils.testModelSerialization(net);
}
}
}
}
@Test
@ -1106,26 +1098,23 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
public void testCropping2DLayer() {
Nd4j.getRandom().setSeed(12345);
int nOut = 2;
int[] minibatchSizes = {1, 3};
int width = 12;
int height = 11;
int[] inputDepths = {1, 3};
int[] kernel = {2, 2};
int[] stride = {1, 1};
int[] padding = {0, 0};
int[][] cropTestCases = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}, {1, 2, 3, 4}};
int[] inputDepths = {1, 2, 3, 2};
int[] minibatchSizes = {2, 1, 3, 2};
for (int inputDepth : inputDepths) {
for (int minibatchSize : minibatchSizes) {
for (int i = 0; i < cropTestCases.length; i++) {
int inputDepth = inputDepths[i];
int minibatchSize = minibatchSizes[i];
int[] crop = cropTestCases[i];
INDArray input = Nd4j.rand(new int[]{minibatchSize, inputDepth, height, width});
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
for (int i = 0; i < minibatchSize; i++) {
labels.putScalar(new int[]{i, i % nOut}, 1.0);
}
for (int[] crop : cropTestCases) {
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder()
@ -1159,8 +1148,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
@ -1171,8 +1158,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
TestUtils.testModelSerialization(net);
}
}
}
}
@Test
public void testDepthwiseConv2D() {

View File

@ -142,9 +142,9 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345L);
int timeSeriesLength = 5;
int nIn = 5;
int nIn = 3;
int layerSize = 3;
int nOut = 3;
int nOut = 2;
int miniBatchSize = 2;
@ -170,24 +170,16 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
Random r = new Random(12345L);
INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}, 'f').subi(0.5);
INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < nIn; j++) {
labels.putScalar(i, r.nextInt(nOut), j, 1.0);
}
}
INDArray labels = TestUtils.randomOneHotTimeSeries(miniBatchSize, nOut, timeSeriesLength);
if (PRINT_RESULTS) {
System.out.println("testBidirectionalLSTMMasking() - testNum = " + testNum++);
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
.labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(16));
.labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(12));
assertTrue(gradOK);
TestUtils.testModelSerialization(mln);

View File

@ -123,8 +123,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation="
+ outputActivation + ", doLearningFirst=" + doLearningFirst);
for (int j = 0; j < mln.getnLayers(); j++)
System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@ -214,8 +212,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
System.out.println(testName + "- activationFn=" + afn + ", lossFn=" + lf
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
+ doLearningFirst);
for (int j = 0; j < mln.getnLayers(); j++)
System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
// for (int j = 0; j < mln.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@ -277,8 +275,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
System.out.println(msg);
for (int j = 0; j < net.getnLayers(); j++)
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
@ -340,8 +338,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
System.out.println(msg);
for (int j = 0; j < net.getnLayers(); j++)
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
@ -397,8 +395,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
System.out.println(msg);
for (int j = 0; j < net.getnLayers(); j++)
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@ -468,8 +466,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
System.out.println(msg);
for (int j = 0; j < net.getnLayers(); j++)
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@ -602,9 +600,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int i = 0; i < 4; i++) {
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
}
// for (int i = 0; i < 4; i++) {
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
// }
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
+ afn;
@ -663,9 +661,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int j = 0; j < net.getLayers().length; j++) {
System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
}
// for (int j = 0; j < net.getLayers().length; j++) {
// System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
// }
String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height
+ ", kernelSize=" + k;
@ -726,9 +724,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int i = 0; i < net.getLayers().length; i++) {
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
}
// for (int i = 0; i < net.getLayers().length; i++) {
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
// }
String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height
+ ", kernelSize=" + k + ", stride = " + stride + ", convLayer first = "
@ -806,8 +804,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
System.out.println(msg);
for (int j = 0; j < net.getnLayers(); j++)
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
// for (int j = 0; j < net.getnLayers(); j++)
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
@ -872,9 +870,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int j = 0; j < net.getLayers().length; j++) {
System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
}
// for (int j = 0; j < net.getLayers().length; j++) {
// System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
// }
String msg = " - mb=" + minibatchSize + ", k="
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
@ -943,9 +941,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int i = 0; i < net.getLayers().length; i++) {
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
}
// for (int i = 0; i < net.getLayers().length; i++) {
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
// }
String msg = " - mb=" + minibatchSize + ", k="
+ k + ", nIn=" + nIn + ", depthMul=" + depthMultiplier + ", s=" + s + ", cm=" + cm;
@ -1018,9 +1016,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int i = 0; i < net.getLayers().length; i++) {
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
}
// for (int i = 0; i < net.getLayers().length; i++) {
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
// }
String msg = " - mb=" + minibatchSize + ", k="
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
@ -1104,9 +1102,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
for (int i = 0; i < net.getLayers().length; i++) {
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
}
// for (int i = 0; i < net.getLayers().length; i++) {
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
// }
String msg = (subsampling ? "subsampling" : "conv") + " - mb=" + minibatchSize + ", k="
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
@ -1179,8 +1177,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
if (PRINT_RESULTS) {
System.out.println(msg);
for (int j = 0; j < net.getnLayers(); j++)
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(

View File

@ -177,7 +177,7 @@ public class KDTreeTest extends BaseDL4JTest {
@Test
public void testKNN() {
int dimensions = 512;
int vectorsNo = 50000;
int vectorsNo = isIntegrationTests() ? 50000 : 1000;
// make a KD-tree of dimension {#dimensions}
Stopwatch stopwatch = Stopwatch.createStarted();
KDTree kdTree = new KDTree(dimensions);

View File

@ -92,13 +92,13 @@ public class SPTreeTest extends BaseDL4JTest {
@Test
//@Ignore
public void testLargeTree() {
int num = 100000;
int num = isIntegrationTests() ? 100000 : 1000;
StopWatch watch = new StopWatch();
watch.start();
INDArray arr = Nd4j.linspace(1, num, num, Nd4j.dataType()).reshape(num, 1);
SpTree tree = new SpTree(arr);
watch.stop();
System.out.println("Tree created in " + watch);
System.out.println("Tree of size " + num + " created in " + watch);
}
}

View File

@ -45,19 +45,19 @@ public class RandomizedInputTest extends RandomizedTest {
private Tokenizer tokenizer = new Tokenizer();
@Test
@Repeat(iterations = 50)
@Repeat(iterations = 10)
public void testRandomizedUnicodeInput() {
assertCanTokenizeString(randomUnicodeOfLength(LENGTH), tokenizer);
}
@Test
@Repeat(iterations = 50)
@Repeat(iterations = 10)
public void testRandomizedRealisticUnicodeInput() {
assertCanTokenizeString(randomRealisticUnicodeOfLength(LENGTH), tokenizer);
}
@Test
@Repeat(iterations = 50)
@Repeat(iterations = 10)
public void testRandomizedAsciiInput() {
assertCanTokenizeString(randomAsciiOfLength(LENGTH), tokenizer);
}

View File

@ -406,11 +406,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
double simD = arraysSimilarity(day1, day2);
double simN = arraysSimilarity(night1, night2);
logger.info("Vec1 day: " + day1);
logger.info("Vec2 day: " + day2);
// logger.info("Vec1 day: " + day1);
// logger.info("Vec2 day: " + day2);
logger.info("Vec1 night: " + night1);
logger.info("Vec2 night: " + night2);
// logger.info("Vec1 night: " + night1);
// logger.info("Vec2 night: " + night2);
logger.info("Day/day cross-model similarity: " + simD);
logger.info("Night/night cross-model similarity: " + simN);

View File

@ -16,6 +16,9 @@
package org.deeplearning4j.models.word2vec;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
import org.junit.Rule;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.primitives.Doubles;
@ -51,8 +54,8 @@ import org.nd4j.resources.Resources;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import static org.junit.Assert.*;
@ -185,7 +188,12 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testWord2VecMultiEpoch() throws Exception {
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
SentenceIterator iter;
if(isIntegrationTests()){
iter = new BasicLineIterator(inputFile.getAbsolutePath());
} else {
iter = new CollectionSentenceIterator(firstNLines(inputFile, 50000));
}
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
@ -389,7 +397,12 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void testW2VnegativeOnRestore() throws Exception {
// Strip white space before and after for each line
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
SentenceIterator iter;
if(isIntegrationTests()){
iter = new BasicLineIterator(inputFile.getAbsolutePath());
} else {
iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
}
// Split on white spaces in the line to get words
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
@ -491,7 +504,12 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void orderIsCorrect_WhenParallelized() throws Exception {
// Strip white space before and after for each line
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
SentenceIterator iter;
if(isIntegrationTests()){
iter = new BasicLineIterator(inputFile.getAbsolutePath());
} else {
iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
}
// Split on white spaces in the line to get words
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
@ -510,9 +528,10 @@ public class Word2VecTests extends BaseDL4JTest {
System.out.println(vec.getVocab().numWords());
val words = vec.getVocab().words();
for (val word : words) {
System.out.println(word);
}
assertTrue(words.size() > 0);
// for (val word : words) {
// System.out.println(word);
// }
}
@Test
@ -755,7 +774,16 @@ public class Word2VecTests extends BaseDL4JTest {
@Test
public void weightsNotUpdated_WhenLocked() throws Exception {
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
boolean isIntegration = isIntegrationTests();
SentenceIterator iter;
SentenceIterator iter2;
if(isIntegration){
iter = new BasicLineIterator(inputFile);
iter2 = new BasicLineIterator(inputFile2.getAbsolutePath());
} else {
iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
iter2 = new CollectionSentenceIterator(firstNLines(inputFile2, 300));
}
Word2Vec vec1 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(64).layerSize(100)
.stopWords(new ArrayList<String>()).seed(42).learningRate(0.025).minLearningRate(0.001)
@ -767,13 +795,12 @@ public class Word2VecTests extends BaseDL4JTest {
vec1.fit();
iter = new BasicLineIterator(inputFile2.getAbsolutePath());
Word2Vec vec2 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(32).layerSize(100)
.stopWords(new ArrayList<String>()).seed(32).learningRate(0.021).minLearningRate(0.001)
.sampling(0).elementsLearningAlgorithm(new SkipGram<VocabWord>())
.epochs(1).windowSize(5).allowParallelTokenization(true)
.workers(1)
.iterate(iter)
.iterate(iter2)
.intersectModel(vec1, true)
.modelUtils(new BasicModelUtils<VocabWord>()).build();
@ -861,6 +888,22 @@ public class Word2VecTests extends BaseDL4JTest {
}
System.out.print("\n");
}
//
public static List<String> firstNLines(File f, int n){
List<String> lines = new ArrayList<>();
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
try{
for( int i=0; i<n && lineIter.hasNext(); i++ ){
lines.add(lineIter.next());
}
} finally {
lineIter.close();
}
return lines;
} catch (IOException e){
throw new RuntimeException(e);
}
}
}

View File

@ -48,7 +48,7 @@ public class TsneTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 120000L;
return 60000L;
}
@Rule
@ -58,8 +58,9 @@ public class TsneTest extends BaseDL4JTest {
public void testSimple() throws Exception {
//Simple sanity check
for (boolean syntheticData : new boolean[]{false, true}) {
for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
for( int test=0; test <=1; test++){
boolean syntheticData = test == 1;
WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED;
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
//STEP 1: Initialization
@ -71,7 +72,7 @@ public class TsneTest extends BaseDL4JTest {
//STEP 2: Turn text input into a list of words
INDArray weights;
if(syntheticData){
weights = Nd4j.rand(1000, 200);
weights = Nd4j.rand(250, 200);
} else {
log.info("Load & Vectorize data....");
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
@ -103,28 +104,27 @@ public class TsneTest extends BaseDL4JTest {
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
}
}
}
//Elapsed time : 01:01:57.988
@Test
public void testPerformance() throws Exception {
StopWatch watch = new StopWatch();
watch.start();
for (boolean syntheticData : new boolean[]{false, true}) {
for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
for( int test=0; test <=1; test++){
boolean syntheticData = test == 1;
WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED;
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
//STEP 1: Initialization
int iterations = 100;
int iterations = 50;
//create an n-dimensional array of doubles
Nd4j.setDataType(DataType.DOUBLE);
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
//STEP 2: Turn text input into a list of words
INDArray weights;
if(syntheticData){
weights = Nd4j.rand(5000, 20);
weights = Nd4j.rand(DataType.FLOAT, 250, 20);
} else {
log.info("Load & Vectorize data....");
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
@ -155,7 +155,6 @@ public class TsneTest extends BaseDL4JTest {
tsne.fit(weights);
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
}
}
watch.stop();
System.out.println("Elapsed time : " + watch);
}

View File

@ -20,6 +20,8 @@ package org.deeplearning4j.models.paragraphvectors;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
@ -27,6 +29,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator;
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.ParallelTransformerIterator;
import org.deeplearning4j.text.sentenceiterator.*;
import org.junit.Rule;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
@ -46,10 +49,6 @@ import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
@ -66,8 +65,8 @@ import org.nd4j.resources.Resources;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
@ -372,7 +371,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
LabelsSource source = new LabelsSource("DOC_");
ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(3)
ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1)
.layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter)
.trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0)
.useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true)
@ -425,6 +424,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
@Test(timeout = 300000)
public void testParagraphVectorsDBOW() throws Exception {
skipUnlessIntegrationTests();
File file = Resources.asFile("/big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(file);
@ -657,7 +658,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
}
}
@Test(timeout = 300000)
@Test
public void testIterator() throws IOException {
val folder_labeled = testDir.newFolder();
val folder_unlabeled = testDir.newFolder();
@ -672,7 +673,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
SentenceIterator iter = new BasicLineIterator(resource_sentences);
int i = 0;
for (; i < 10000; ++i) {
for (; i < 10; ++i) {
int j = 0;
int labels = 0;
int words = 0;
@ -721,7 +722,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(1)
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
.elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
.allowParallelTokenization(true)
@ -1009,7 +1010,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(1)
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
.elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
.iterate(iter).tokenizerFactory(t).build();
@ -1151,8 +1152,27 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
@Test(timeout = 300000)
public void testDoubleFit() throws Exception {
boolean isIntegration = isIntegrationTests();
File resource = Resources.asFile("/big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(resource);
SentenceIterator iter;
if(isIntegration){
iter = new BasicLineIterator(resource);
} else {
List<String> lines = new ArrayList<>();
try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){
LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
try{
for( int i=0; i<500 && lineIter.hasNext(); i++ ){
lines.add(lineIter.next());
}
} finally {
lineIter.close();
}
}
iter = new CollectionSentenceIterator(lines);
}
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());

View File

@ -49,7 +49,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 240000L;
return 60000L;
}
/**
@ -57,6 +57,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
*/
@Test
public void testIterator1() throws Exception {
File inputFile = Resources.asFile("big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
@ -77,10 +78,14 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1);
INDArray array = iterator.next().getFeatures();
int count = 0;
while (iterator.hasNext()) {
DataSet ds = iterator.next();
assertArrayEquals(array.shape(), ds.getFeatures().shape());
if(!isIntegrationTests() && count++ > 20)
break; //raw_sentences.txt is 2.81 MB, takes quite some time to process. We'll only first 20 minibatches when doing unit tests
}
}

View File

@ -45,9 +45,15 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
*/
@Test
public void testStore1() throws Exception {
int numParams = 100000;
int workers[] = new int[] {2, 4, 8};
int numParams;
int[] workers;
if(isIntegrationTests()){
numParams = 100000;
workers = new int[] {2, 4, 8};
} else {
numParams = 10000;
workers = new int[] {2, 3};
}
for (int numWorkers : workers) {
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3),null, null, false);
@ -77,7 +83,13 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
*/
@Test
public void testEncodingLimits1() throws Exception {
int numParams = 100000;
int numParams;
if(isIntegrationTests()){
numParams = 100000;
} else {
numParams = 10000;
}
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, null, false);
for (int e = 10; e < numParams / 5; e++) {

View File

@ -242,7 +242,7 @@ public class IndexedTailTest extends BaseDL4JTest {
final long[] sums = new long[numReaders];
val readers = new ArrayList<Thread>();
for (int e = 0; e < numReaders; e++) {
val f = e;
final int f = e;
val t = new Thread(new Runnable() {
@Override
public void run() {
@ -297,7 +297,7 @@ public class IndexedTailTest extends BaseDL4JTest {
final long[] sums = new long[numReaders];
val readers = new ArrayList<Thread>();
for (int e = 0; e < numReaders; e++) {
val f = e;
final int f = e;
val t = new Thread(new Runnable() {
@Override
public void run() {
@ -371,7 +371,7 @@ public class IndexedTailTest extends BaseDL4JTest {
final long[] sums = new long[numReaders];
val readers = new ArrayList<Thread>();
for (int e = 0; e < numReaders; e++) {
val f = e;
final int f = e;
val t = new Thread(new Runnable() {
@Override
public void run() {

View File

@ -35,6 +35,7 @@ import org.deeplearning4j.remote.helpers.House;
import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
import org.deeplearning4j.remote.helpers.PredictedPrice;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.autodiff.samediff.SDVariable;
@ -58,6 +59,7 @@ import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
@ -66,7 +68,6 @@ import static org.junit.Assert.*;
@Slf4j
public class JsonModelServerTest extends BaseDL4JTest {
private static final MultiLayerNetwork model;
private final int PORT = 18080;
static {
val conf = new NeuralNetConfiguration.Builder()
@ -84,10 +85,18 @@ public class JsonModelServerTest extends BaseDL4JTest {
@After
public void pause() throws Exception {
// TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
// Need to wait for server shutdown; without sleep, tests will fail if starting immediately after shutdown
TimeUnit.SECONDS.sleep(2);
}
private AtomicInteger portCount = new AtomicInteger(18080);
private int PORT;
@Before
public void setPort(){
PORT = portCount.getAndIncrement();
}
@Test
public void testStartStopParallel() throws Exception {
@ -343,7 +352,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(null)
.port(18080)
.port(PORT)
.build();
}
@ -382,7 +391,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
return null;
}
})
.endpointAddress("http://localhost:18080/v1/serving")
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;

View File

@ -485,7 +485,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
List<INDArray> exp = new ArrayList<>();
Random r = new Random();
for (int i = 0; i < 500; i++) {
int runs = isIntegrationTests() ? 500 : 30;
for (int i = 0; i < runs; i++) {
int[] shape = defaultSize;
if (r.nextDouble() < 0.4) {
shape = new int[]{r.nextInt(5) + 1, 10, r.nextInt(10) + 1};
@ -597,7 +598,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
List<INDArray> arrs = new ArrayList<>();
List<INDArray> exp = new ArrayList<>();
Random r = new Random();
for( int i=0; i<500; i++ ){
int runs = isIntegrationTests() ? 500 : 20;
for( int i=0; i<runs; i++ ){
int[] shape = defaultShape;
if(r.nextDouble() < 0.4){
shape = new int[]{r.nextInt(5)+1, nIn, 10, r.nextInt(10)+1};
@ -679,8 +681,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
}
}
@Test(timeout = 120000)
public void testInputMasking() throws Exception {
private void testInputMasking() throws Exception {
Nd4j.getRandom().setSeed(12345);
int nIn = 10;
@ -698,12 +699,15 @@ public class ParallelInferenceTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
// InferenceMode[] inferenceModes = new InferenceMode[]{InferenceMode.SEQUENTIAL, InferenceMode.BATCHED, InferenceMode.INPLACE, InferenceMode.SEQUENTIAL};
// int[] workers = new int[]{2, 2, 2, 1};
// boolean[] randomTS = new boolean[]{true, false, true, false};
Random r = new Random();
for( InferenceMode m : InferenceMode.values()) {
log.info("Testing inference mode: [{}]", m);
for( int w : new int[]{1,2}) {
for (boolean randomTSLength : new boolean[]{false, true}) {
final ParallelInference inf =
new ParallelInference.Builder(net)
.inferenceMode(m)
@ -714,7 +718,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
List<INDArray> in = new ArrayList<>();
List<INDArray> inMasks = new ArrayList<>();
List<INDArray> exp = new ArrayList<>();
for (int i = 0; i < 100; i++) {
int nRuns = isIntegrationTests() ? 100 : 10;
for (int i = 0; i < nRuns; i++) {
int currTSLength = (randomTSLength ? 1 + r.nextInt(tsLength) : tsLength);
int currNumEx = 1 + r.nextInt(3);
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn, currTSLength});
@ -847,6 +852,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
List<INDArray[]> in = new ArrayList<>();
List<INDArray[]> exp = new ArrayList<>();
int runs = isIntegrationTests() ? 100 : 20;
for (int i = 0; i < 100; i++) {
int currNumEx = 1 + r.nextInt(3);
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn});

View File

@ -62,8 +62,8 @@ public class ParallelWrapperTest extends BaseDL4JTest {
int seed = 123;
log.info("Load data....");
DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 100);
DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 10);
DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 15);
DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 4);
assertTrue(mnistTrain.hasNext());
val t0 = mnistTrain.next();

View File

@ -47,6 +47,12 @@
<groupId>org.nd4j</groupId>
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
<version>${nd4j.version}</version>
<exclusions>
<exclusion>
<groupId>net.jpountz.lz4</groupId>
<artifactId>lz4</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>junit</groupId>

View File

@ -0,0 +1,31 @@
################################################################################
# Copyright (c) 2015-2019 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
log4j.rootLogger=ERROR, Console
log4j.appender.Console=org.apache.log4j.ConsoleAppender
log4j.appender.Console.layout=org.apache.log4j.PatternLayout
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
log4j.appender.org.springframework=DEBUG
log4j.appender.org.deeplearning4j=DEBUG
log4j.appender.org.nd4j=DEBUG
log4j.logger.org.springframework=INFO
log4j.logger.org.deeplearning4j=DEBUG
log4j.logger.org.nd4j=DEBUG
log4j.logger.org.apache.spark=WARN

View File

@ -0,0 +1,53 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.apache.catalina.core" level="DEBUG" />
<logger name="org.springframework" level="DEBUG" />
<logger name="org.deeplearning4j" level="DEBUG" />
<logger name="org.datavec" level="INFO" />
<logger name="org.nd4j" level="INFO" />
<logger name="opennlp.uima.util" level="OFF" />
<logger name="org.apache.uima" level="OFF" />
<logger name="org.cleartk" level="OFF" />
<logger name="org.apache.spark" level="WARN" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -0,0 +1,31 @@
################################################################################
# Copyright (c) 2015-2019 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
log4j.rootLogger=ERROR, Console
log4j.appender.Console=org.apache.log4j.ConsoleAppender
log4j.appender.Console.layout=org.apache.log4j.PatternLayout
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
log4j.appender.org.springframework=DEBUG
log4j.appender.org.deeplearning4j=DEBUG
log4j.appender.org.nd4j=DEBUG
log4j.logger.org.springframework=INFO
log4j.logger.org.deeplearning4j=DEBUG
log4j.logger.org.nd4j=DEBUG
log4j.logger.org.apache.spark=WARN

View File

@ -0,0 +1,53 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.apache.catalina.core" level="DEBUG" />
<logger name="org.springframework" level="DEBUG" />
<logger name="org.deeplearning4j" level="DEBUG" />
<logger name="org.datavec" level="INFO" />
<logger name="org.nd4j" level="INFO" />
<logger name="opennlp.uima.util" level="OFF" />
<logger name="org.apache.uima" level="OFF" />
<logger name="org.cleartk" level="OFF" />
<logger name="org.apache.spark" level="WARN" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -25,6 +25,7 @@ import org.datavec.spark.transform.misc.StringToWritablesFunction;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.spark.BaseSparkTest;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.io.ClassPathResource;
@ -35,6 +36,15 @@ import static org.junit.Assert.assertEquals;
public class TestIteratorUtils extends BaseSparkTest {
@Override
public DataType getDataType() {
return DataType.FLOAT;
}
@Override
public DataType getDefaultFPDataType() {
return DataType.FLOAT;
}
@Test
public void testIrisRRMDSI() throws Exception {

View File

@ -453,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
tempDirF.deleteOnExit();
int dataSetObjSize = 1;
int batchSizePerExecutor = 16;
int numSplits = 5;
int batchSizePerExecutor = 4;
int numSplits = 3;
int averagingFrequency = 3;
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
@ -506,7 +506,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
INDArray paramsAfter = sparkNet.getNetwork().params().dup();
assertNotEquals(paramsBefore, paramsAfter);
Thread.sleep(2000);
Thread.sleep(200);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
//Expect
@ -517,7 +517,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertEquals(numSplits * numExecutors() * averagingFrequency, list.size());
for (EventStats es : list) {
ExampleCountEventStats e = (ExampleCountEventStats) es;
assertTrue(batchSizePerExecutor * averagingFrequency - 10 >= e.getTotalExampleCount());
assertTrue(batchSizePerExecutor * averagingFrequency >= e.getTotalExampleCount());
}
@ -535,9 +535,9 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
tempDirF.deleteOnExit();
tempDirF2.deleteOnExit();
int dataSetObjSize = 5;
int batchSizePerExecutor = 25;
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 1000, false);
int dataSetObjSize = 4;
int batchSizePerExecutor = 8;
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 128, false);
int i = 0;
while (iter.hasNext()) {
File nextFile = new File(tempDirF, i + ".bin");
@ -981,7 +981,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.setOutputs("out")
.build();
DataSetIterator iter = new IrisDataSetIterator(1, 150);
DataSetIterator iter = new IrisDataSetIterator(1, 50);
List<DataSet> l = new ArrayList<>();
while(iter.hasNext()){
@ -992,9 +992,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
int rddDataSetNumExamples = 1;
int averagingFrequency = 3;
int averagingFrequency = 2;
int batch = 2;
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(rddDataSetNumExamples)
.averagingFrequency(averagingFrequency).batchSizePerWorker(rddDataSetNumExamples)
.averagingFrequency(averagingFrequency).batchSizePerWorker(batch)
.saveUpdater(true).workerPrefetchNumBatches(0).build();
Nd4j.getRandom().setSeed(12345);
@ -1003,7 +1004,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkComputationGraph sn2 = new SparkComputationGraph(sc, conf2.clone(), tm);
for(int i=0; i<4; i++ ){
for(int i=0; i<3; i++ ){
assertEquals(i, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount());
assertEquals(i, sn2.getNetwork().getConfiguration().getEpochCount());
sn1.fit(rdd);

View File

@ -42,6 +42,11 @@ import static org.junit.Assert.assertTrue;
*/
public class TestRepartitioning extends BaseSparkTest {
@Override
public long getTimeoutMilliseconds() {
return isIntegrationTests() ? 240000 : 60000;
}
@Test
public void testRepartitioning() {
List<String> list = new ArrayList<>();
@ -66,7 +71,12 @@ public class TestRepartitioning extends BaseSparkTest {
@Test
public void testRepartitioning2() throws Exception {
int[] ns = {320, 321, 25600, 25601, 25615};
int[] ns;
if(isIntegrationTests()){
ns = new int[]{320, 321, 25600, 25601, 25615};
} else {
ns = new int[]{320, 2561};
}
for (int n : ns) {

View File

@ -32,6 +32,11 @@ import java.io.File;
public class MiscTests extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 120000L;
}
@Test
public void testTransferVGG() throws Exception {
//https://github.com/deeplearning4j/deeplearning4j/issues/5167

View File

@ -48,6 +48,11 @@ import static org.junit.Assert.assertEquals;
@Slf4j
public class TestDownload extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return isIntegrationTests() ? 480000L : 60000L;
}
@ClassRule
public static TemporaryFolder testDir = new TemporaryFolder();
private static File f;
@ -67,12 +72,20 @@ public class TestDownload extends BaseDL4JTest {
public void testDownloadAllModels() throws Exception {
// iterate through each available model
ZooModel[] models = new ZooModel[]{
ZooModel[] models;
if(isIntegrationTests()){
models = new ZooModel[]{
LeNet.builder().build(),
SimpleCNN.builder().build(),
UNet.builder().build(),
NASNet.builder().build()
};
NASNet.builder().build()};
} else {
models = new ZooModel[]{
LeNet.builder().build(),
SimpleCNN.builder().build()};
}
for (int i = 0; i < models.length; i++) {

View File

@ -57,6 +57,11 @@ import static org.junit.Assert.assertTrue;
@Slf4j
public class TestImageNet extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Override
public DataType getDataType(){
return DataType.FLOAT;

View File

@ -64,9 +64,14 @@ public class DistributionUniform extends DynamicCustomOp {
}
public DistributionUniform(INDArray shape, INDArray out, double min, double max) {
this(shape, out, min, max, null);
}
public DistributionUniform(INDArray shape, INDArray out, double min, double max, DataType dataType){
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
this.min = min;
this.max = max;
this.dataType = dataType;
}

View File

@ -310,6 +310,13 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -19,6 +19,7 @@ package org.nd4j.jita.allocator;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;
@ -29,7 +30,7 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@Slf4j
public class DeviceLocalNDArrayTests {
public class DeviceLocalNDArrayTests extends BaseND4JTest {
@Test
public void testDeviceLocalArray_1() throws Exception{

View File

@ -19,13 +19,14 @@ package org.nd4j.jita.allocator.impl;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.*;
@Slf4j
public class MemoryTrackerTest {
public class MemoryTrackerTest extends BaseND4JTest {
@Test
public void testAllocatedDelta() {

View File

@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.workspace.CudaWorkspace;
import org.nd4j.linalg.api.buffer.DataType;
@ -20,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.*;
@Slf4j
public class BaseCudaDataBufferTest {
public class BaseCudaDataBufferTest extends BaseND4JTest {
@Before
public void setUp() {

View File

@ -87,6 +87,13 @@
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -20,6 +20,7 @@ package org.nd4j.tensorflow.conversion;
import junit.framework.TestCase;
import org.apache.commons.io.FileUtils;
import org.bytedeco.tensorflow.TF_Tensor;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.resources.Resources;
import org.nd4j.shade.protobuf.Descriptors;
@ -46,7 +47,17 @@ import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class GraphRunnerTest {
public class GraphRunnerTest extends BaseND4JTest {
@Override
public DataType getDataType() {
return DataType.FLOAT;
}
@Override
public DataType getDefaultFPDataType() {
return DataType.FLOAT;
}
public static ConfigProto getConfig(){
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");

View File

@ -18,6 +18,7 @@ package org.nd4j.tensorflow.conversion;
import org.apache.commons.io.IOUtils;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
public class TensorflowConversionTest {
public class TensorflowConversionTest extends BaseND4JTest {
@Test
public void testView() {

View File

@ -17,6 +17,7 @@
package org.nd4j.tensorflow.conversion;
import org.nd4j.BaseND4JTest;
import org.nd4j.shade.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils;
import org.junit.Test;
@ -37,7 +38,7 @@ import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class GpuGraphRunnerTest {
public class GpuGraphRunnerTest extends BaseND4JTest {
@Test
public void testGraphRunner() throws Exception {

View File

@ -127,6 +127,13 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<reporting>

View File

@ -16,9 +16,6 @@
package org.nd4j.autodiff.opvalidation;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -46,6 +43,8 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.ops.transforms.Transforms;
import static org.junit.Assert.*;
@Slf4j
public class LayerOpValidation extends BaseOpValidation {
public LayerOpValidation(Nd4jBackend backend) {

View File

@ -45,7 +45,7 @@ public class LossOpValidation extends BaseOpValidation {
}
@Override
public long testTimeoutMilliseconds() {
public long getTimeoutMilliseconds() {
return 90000L;
}

View File

@ -54,8 +54,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.*;
@Slf4j
@RunWith(Parameterized.class)

View File

@ -2434,8 +2434,6 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testPermute4(){
Nd4j.getExecutioner().enableDebugMode(true);
Nd4j.getExecutioner().enableVerboseMode(true);
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0);

View File

@ -64,7 +64,11 @@ public class CheckpointListenerTest extends BaseNd4jTest {
}
public static DataSetIterator getIter() {
return new IrisDataSetIterator(15, 150);
return getIter(15, 150);
}
public static DataSetIterator getIter(int batch, int totalExamples){
return new IrisDataSetIterator(batch, totalExamples);
}
@ -148,15 +152,15 @@ public class CheckpointListenerTest extends BaseNd4jTest {
CheckpointListener l = new CheckpointListener.Builder(dir)
.keepLast(2)
.saveEvery(3, TimeUnit.SECONDS)
.saveEvery(1, TimeUnit.SECONDS)
.build();
sd.setListeners(l);
DataSetIterator iter = getIter();
DataSetIterator iter = getIter(15, 150);
for(int i=0; i<5; i++ ){ //10 iterations total
sd.fit(iter, 1);
Thread.sleep(4000);
Thread.sleep(1000);
}
//Expect models saved at iterations: 10, 20, 30, 40

View File

@ -29,6 +29,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Random;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**

View File

@ -34,8 +34,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* Created by Alex on 05/07/2017.

View File

@ -33,8 +33,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* Created by Alex on 04/11/2016.

View File

@ -33,6 +33,7 @@ import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.Arrays;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@Slf4j

View File

@ -121,7 +121,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"fused_batch_norm/.*",
// AB 2020/01/04 - https://github.com/eclipse/deeplearning4j/issues/8592
"emptyArrayTests/reshape/rank2_shape2-0_2-0--1"
"emptyArrayTests/reshape/rank2_shape2-0_2-0--1",
//AB 2020/01/07 - Known issues
"bitcast/from_float64_to_int64",
"bitcast/from_rank2_float64_to_int64",
"bitcast/from_float64_to_uint64"
};
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have

View File

@ -252,7 +252,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
System.out.println(Arrays.toString(shape));
// this is NHWC weights. will be changed soon.
assertArrayEquals(new int[]{5,5,1,32}, shape);
assertArrayEquals(new long[]{5,5,1,32}, shape);
System.out.println(convNode);
}

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.junit.After;
@ -26,6 +27,7 @@ import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.BaseND4JTest;
import org.nd4j.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -40,30 +42,16 @@ import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory;
import java.util.*;
import static org.junit.Assume.assumeTrue;
/**
* Base Nd4j test
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public abstract class BaseNd4jTest {
private static Logger log = LoggerFactory.getLogger(BaseNd4jTest.class);
@Rule
public TestName testName = new TestName();
@Rule
public Timeout timeout = Timeout.seconds(testTimeoutMilliseconds());
/**
* Override this method to set the default timeout for methods in the class
*/
public long testTimeoutMilliseconds(){
return 30000L;
}
protected long startTime;
protected int threadCountBefore;
@Slf4j
public abstract class BaseNd4jTest extends BaseND4JTest {
protected Nd4jBackend backend;
protected String name;
@ -80,16 +68,10 @@ public abstract class BaseNd4jTest {
public BaseNd4jTest(String name, Nd4jBackend backend) {
this.backend = backend;
this.name = name;
//Suppress ND4J initialization - don't need this logged for every test...
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
System.gc();
}
public BaseNd4jTest(Nd4jBackend backend) {
this(backend.getClass().getName() + UUID.randomUUID().toString(), backend);
}
private static List<Nd4jBackend> backends;
@ -104,79 +86,6 @@ public abstract class BaseNd4jTest {
if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) || backendsToRun.isEmpty())
backends.add(backend);
}
}
public static void assertArrayEquals(String string, Object[] expecteds, Object[] actuals) {
org.junit.Assert.assertArrayEquals(string, expecteds, actuals);
}
public static void assertArrayEquals(Object[] expecteds, Object[] actuals) {
org.junit.Assert.assertArrayEquals(expecteds, actuals);
}
public static void assertArrayEquals(String string, long[] shapeA, long[] shapeB) {
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
}
public static void assertArrayEquals(String string, byte[] shapeA, byte[] shapeB) {
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
}
public static void assertArrayEquals(byte[] shapeA, byte[] shapeB) {
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
}
public static void assertArrayEquals(long[] shapeA, long[] shapeB) {
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
}
public static void assertArrayEquals(String string, int[] shapeA, long[] shapeB) {
org.junit.Assert.assertArrayEquals(string, ArrayUtil.toLongArray(shapeA), shapeB);
}
public static void assertArrayEquals(int[] shapeA, long[] shapeB) {
org.junit.Assert.assertArrayEquals(ArrayUtil.toLongArray(shapeA), shapeB);
}
public static void assertArrayEquals(String string, long[] shapeA, int[] shapeB) {
org.junit.Assert.assertArrayEquals(string, shapeA, ArrayUtil.toLongArray(shapeB));
}
public static void assertArrayEquals(long[] shapeA, int[] shapeB) {
org.junit.Assert.assertArrayEquals(shapeA, ArrayUtil.toLongArray(shapeB));
}
public static void assertArrayEquals(String string, int[] shapeA, int[] shapeB) {
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
}
public static void assertArrayEquals(int[] shapeA, int[] shapeB) {
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
}
public static void assertArrayEquals(String string, boolean[] shapeA, boolean[] shapeB) {
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
}
public static void assertArrayEquals(boolean[] shapeA, boolean[] shapeB) {
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
}
public static void assertArrayEquals(float[] shapeA, float[] shapeB, float delta) {
org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta);
}
public static void assertArrayEquals(double[] shapeA, double[] shapeB, double delta) {
org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta);
}
public static void assertArrayEquals(String string, float[] shapeA, float[] shapeB, float delta) {
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta);
}
public static void assertArrayEquals(String string, double[] shapeA, double[] shapeB, double delta) {
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta);
}
@Parameterized.Parameters(name = "{index}: backend({0})={1}")
@ -187,6 +96,13 @@ public abstract class BaseNd4jTest {
return ret;
}
@Override
@Before
public void beforeTest(){
super.beforeTest();
Nd4j.factory().setOrder(ordering());
}
/**
* Get the default backend (jblas)
* The default backend can be overridden by also passing:
@ -207,106 +123,6 @@ public abstract class BaseNd4jTest {
}
@Before
public void before() throws Exception {
//
log.info("Running {}.{} on {}", getClass().getName(), testName.getMethodName(), backend.getClass().getSimpleName());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j nd4j = new Nd4j();
nd4j.initWithBackend(backend);
Nd4j.factory().setOrder(ordering());
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
Nd4j.getExecutioner().enableDebugMode(false);
Nd4j.getExecutioner().enableVerboseMode(false);
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
}
@After
public void after() throws Exception {
long totalTime = System.currentTimeMillis() - startTime;
Nd4j.getMemoryManager().purgeCaches();
logTestCompletion(totalTime);
if (System.getProperties().getProperty("backends") != null
&& !System.getProperty("backends").contains(backend.getClass().getName()))
return;
Nd4j nd4j = new Nd4j();
nd4j.initWithBackend(backend);
Nd4j.factory().setOrder(ordering());
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
Nd4j.getExecutioner().enableDebugMode(false);
Nd4j.getExecutioner().enableVerboseMode(false);
//Attempt to keep workspaces isolated between tests
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
val currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
if(currWS != null){
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
System.exit(1);
}
}
public void logTestCompletion( long totalTime){
StringBuilder sb = new StringBuilder();
long maxPhys = Pointer.maxPhysicalBytes();
long maxBytes = Pointer.maxBytes();
long currPhys = Pointer.physicalBytes();
long currBytes = Pointer.totalBytes();
long jvmTotal = Runtime.getRuntime().totalMemory();
long jvmMax = Runtime.getRuntime().maxMemory();
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
sb.append(getClass().getSimpleName()).append(".").append(testName.getMethodName())
.append(": ").append(totalTime).append(" ms")
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
.append(", jvmTotal=").append(jvmTotal)
.append(", jvmMax=").append(jvmMax)
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
if(ws != null && ws.size() > 0){
long currSize = 0;
for(MemoryWorkspace w : ws){
currSize += w.getCurrentSize();
}
if(currSize > 0){
sb.append(", threadWSSize=").append(currSize)
.append(" (").append(ws.size()).append(" WSs)");
}
}
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
Object o = p.get("cuda.devicesInformation");
if(o instanceof List){
List<Map<String,Object>> l = (List<Map<String, Object>>) o;
if(l.size() > 0) {
sb.append(" [").append(l.size())
.append(" GPUs: ");
for (int i = 0; i < l.size(); i++) {
Map<String,Object> m = l.get(i);
if(i > 0)
sb.append(",");
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
.append(m.get("cuda.totalMemory")).append(" total)");
}
sb.append("]");
}
}
log.info(sb.toString());
}
/**
* The ordering for this test
* This test will only be invoked for
@ -315,15 +131,10 @@ public abstract class BaseNd4jTest {
* @return the ordering for this test
*/
public char ordering() {
return 'a';
return 'c';
}
public String getFailureMessage() {
return "Failed with backend " + backend.getClass().getName() + " and ordering " + ordering();
}
}

View File

@ -65,16 +65,6 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
super(backend);
}
@Before
public void before() throws Exception {
super.before();
}
@After
public void after() throws Exception {
super.after();
}
@Test

View File

@ -133,13 +133,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
@Override
public long testTimeoutMilliseconds() {
public long getTimeoutMilliseconds() {
return 90000;
}
@Before
public void before() throws Exception {
super.before();
super.beforeTest();
Nd4j.setDataType(DataType.DOUBLE);
Nd4j.getRandom().setSeed(123);
Nd4j.getExecutioner().enableDebugMode(false);
@ -148,7 +148,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@After
public void after() throws Exception {
super.after();
super.afterTest();
Nd4j.setDataType(initialType);
}
@ -5331,7 +5331,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testNativeSort3() {
INDArray array = Nd4j.linspace(1, 1048576, 1048576, DataType.DOUBLE).reshape(1, -1);
int length = isIntegrationTests() ? 1048576 : 16484;
INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1);
INDArray exp = array.dup();
Nd4j.shuffle(array, 0);
@ -7196,19 +7197,19 @@ public class Nd4jTestsC extends BaseNd4jTest {
for( int i=-3; i<3; i++ ){
INDArray out = Nd4j.stack(i, in, in2);
int[] expShape;
long[] expShape;
switch (i){
case -3:
case 0:
expShape = new int[]{2,3,4};
expShape = new long[]{2,3,4};
break;
case -2:
case 1:
expShape = new int[]{3,2,4};
expShape = new long[]{3,2,4};
break;
case -1:
case 2:
expShape = new int[]{3,4,2};
expShape = new long[]{3,4,2};
break;
default:
throw new RuntimeException(String.valueOf(i));
@ -7602,6 +7603,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
String wsName = "testRollingMeanWs";
try {
System.gc();
int iterations1 = isIntegrationTests() ? 5 : 2;
for (int e = 0; e < 5; e++) {
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) {
val array = Nd4j.create(DataType.FLOAT, 32, 128, 256, 256);
@ -7609,7 +7611,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
int iterations = 20;
int iterations = isIntegrationTests() ? 20 : 3;
val timeStart = System.nanoTime();
for (int e = 0; e < iterations; e++) {
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) {

View File

@ -57,13 +57,13 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
@Before
public void before() throws Exception {
super.before();
super.beforeTest();
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
}
@After
public void after() throws Exception {
super.after();
super.afterTest();
DataTypeUtil.setDTypeForContext(initialType);
}

View File

@ -37,8 +37,7 @@ import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* Tests comparing Nd4j ops to other libraries
@ -59,7 +58,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
@Before
public void before() throws Exception {
super.before();
super.beforeTest();
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
Nd4j.getRandom().setSeed(SEED);
@ -67,7 +66,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
@After
public void after() throws Exception {
super.after();
super.afterTest();
DataTypeUtil.setDTypeForContext(initialType);
}
@ -197,7 +196,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
INDArray gemv = m.mmul(v);
RealMatrix gemv2 = rm.multiply(rv);
assertArrayEquals(new int[] {rows, 1}, gemv.shape());
assertArrayEquals(new long[] {rows, 1}, gemv.shape());
assertArrayEquals(new int[] {rows, 1},
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});

View File

@ -25,6 +25,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import static org.junit.Assert.assertArrayEquals;
/**
* Created by Alex on 30/04/2016.
*/

View File

@ -38,8 +38,7 @@ import org.nd4j.linalg.util.SerializationUtils;
import java.io.*;
import java.util.Arrays;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* Double data buffer tests

View File

@ -37,8 +37,7 @@ import org.nd4j.linalg.util.SerializationUtils;
import java.io.*;
import java.nio.ByteBuffer;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* Float data buffer tests

View File

@ -31,8 +31,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.*;
import java.util.Arrays;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* Tests for INT INDArrays and DataBuffers serialization

View File

@ -33,8 +33,7 @@ import org.nd4j.linalg.util.ArrayUtil;
import java.util.Arrays;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.*;
/**

View File

@ -27,8 +27,7 @@ import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.PointIndex;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* @author Adam Gibson

Some files were not shown because too many files have changed in this diff Show More