Unit/integration test split + test speedup (#166)
* Add maven profile + base tests methods for integration tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Switch from system property to environment variable; seems more reliable in intellij Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add nd4j-common-tests module, and common base test; cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Ensure all ND4J tests extend BaseND4JTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test spam reduction, import fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add test logging to nd4j-aeron Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix unintended change Signed-off-by: AlexDBlack <blacka101@gmail.com> * Reduce sprint test log spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Significantly speed up TSNE tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * W2V iterator test unit/integration split Signed-off-by: AlexDBlack <blacka101@gmail.com> * More NLP test speedups Signed-off-by: AlexDBlack <blacka101@gmail.com> * Avoid debug/verbose mode leaking between tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * test tweak Signed-off-by: AlexDBlack <blacka101@gmail.com> * Arbiter extends base DL4J test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Arbiter test speedup Signed-off-by: AlexDBlack <blacka101@gmail.com> * nlp-uima test speedup Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test speedups Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix ND4J base test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Few small ND4J test speed improvements Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J tests speedup Signed-off-by: AlexDBlack <blacka101@gmail.com> * More tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * Even more test speedups Signed-off-by: AlexDBlack <blacka101@gmail.com> * More tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * Various test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * More test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Add ability to specify number of threads for C++ ops in BaseDL4JTest and BaseND4JTest Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-aeron test profile fix for CUDA Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
2717b25931
commit
a25bb6a11c
|
@ -84,6 +84,13 @@
|
|||
<artifactId>jackson</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||
|
@ -32,7 +33,7 @@ import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusLis
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestGeneticSearch {
|
||||
public class TestGeneticSearch extends BaseDL4JTest {
|
||||
public class TestSelectionOperator extends SelectionOperator {
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
||||
|
@ -26,7 +27,7 @@ import java.util.Map;
|
|||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class TestGridSearch {
|
||||
public class TestGridSearch extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testIndexing() {
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize;
|
|||
import org.apache.commons.math3.distribution.LogNormalDistribution;
|
||||
import org.apache.commons.math3.distribution.NormalDistribution;
|
||||
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||
|
@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
|
|||
/**
|
||||
* Created by Alex on 02/02/2017.
|
||||
*/
|
||||
public class TestJson {
|
||||
public class TestJson extends BaseDL4JTest {
|
||||
|
||||
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
|
||||
ObjectMapper om = new ObjectMapper(factory);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||
|
@ -34,7 +35,7 @@ import java.util.Map;
|
|||
* Test random search on the Branin Function:
|
||||
* http://www.sfu.ca/~ssurjano/branin.html
|
||||
*/
|
||||
public class TestRandomSearch {
|
||||
public class TestRandomSearch extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void test() throws Exception {
|
||||
|
|
|
@ -17,12 +17,13 @@
|
|||
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class TestLogUniform {
|
||||
public class TestLogUniform extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testSimple(){
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||
|
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class ArithmeticCrossoverTests {
|
||||
public class ArithmeticCrossoverTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator;
|
||||
|
@ -23,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class CrossoverOperatorTests {
|
||||
public class CrossoverOperatorTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||
import org.junit.Assert;
|
||||
|
@ -24,7 +25,7 @@ import org.junit.Test;
|
|||
|
||||
import java.util.Deque;
|
||||
|
||||
public class CrossoverPointsGeneratorTests {
|
||||
public class CrossoverPointsGeneratorTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void CrossoverPointsGenerator_FixedNumberCrossovers() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
|
@ -25,7 +26,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class KPointCrossoverTests {
|
||||
public class KPointCrossoverTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||
import org.junit.Assert;
|
||||
|
@ -24,7 +25,7 @@ import org.junit.Test;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ParentSelectionTests {
|
||||
public class ParentSelectionTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void ParentSelection_InitializeInstance_ShouldInitPopulation() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||
|
@ -26,7 +27,7 @@ import org.junit.Test;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class RandomTwoParentSelectionTests {
|
||||
public class RandomTwoParentSelectionTests extends BaseDL4JTest {
|
||||
@Test
|
||||
public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() {
|
||||
RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||
|
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class SinglePointCrossoverTests {
|
||||
public class SinglePointCrossoverTests extends BaseDL4JTest {
|
||||
@Test
|
||||
public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
||||
RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0});
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
|
@ -27,7 +28,7 @@ import org.junit.Assert;
|
|||
import org.junit.Test;
|
||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
||||
|
||||
public class TwoParentsCrossoverOperatorTests {
|
||||
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
|
||||
|
||||
class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator {
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||
|
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class UniformCrossoverTests {
|
||||
public class UniformCrossoverTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
|
@ -27,7 +28,7 @@ import org.junit.Test;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class LeastFitCullOperatorTests {
|
||||
public class LeastFitCullOperatorTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void LeastFitCullingOperation_ShouldCullLastElements() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
|
@ -27,7 +28,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
public class RatioCullOperatorTests {
|
||||
public class RatioCullOperatorTests extends BaseDL4JTest {
|
||||
|
||||
class TestRatioCullOperator extends RatioCullOperator {
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.mutation;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||
import org.junit.Assert;
|
||||
|
@ -24,7 +25,7 @@ import org.junit.Test;
|
|||
import java.lang.reflect.Field;
|
||||
import java.util.Arrays;
|
||||
|
||||
public class RandomMutationOperatorTests {
|
||||
public class RandomMutationOperatorTests extends BaseDL4JTest {
|
||||
@Test
|
||||
public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() {
|
||||
RandomMutationOperator sut = new RandomMutationOperator.Builder().build();
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.population;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
|
@ -27,7 +28,7 @@ import org.junit.Test;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
public class PopulationModelTests {
|
||||
public class PopulationModelTests extends BaseDL4JTest {
|
||||
|
||||
private class TestCullOperator implements CullOperator {
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
|
@ -36,7 +37,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
|||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
public class GeneticSelectionOperatorTests {
|
||||
public class GeneticSelectionOperatorTests extends BaseDL4JTest {
|
||||
|
||||
private class TestCullOperator implements CullOperator {
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
@ -25,7 +26,7 @@ import org.junit.Assert;
|
|||
import org.junit.Test;
|
||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
||||
|
||||
public class SelectionOperatorTests {
|
||||
public class SelectionOperatorTests extends BaseDL4JTest {
|
||||
private class TestSelectionOperator extends SelectionOperator {
|
||||
|
||||
public PopulationModel getPopulationModel() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.parameter;
|
||||
|
||||
import org.apache.commons.math3.distribution.NormalDistribution;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
||||
|
@ -25,7 +26,7 @@ import org.junit.Test;
|
|||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class TestParameterSpaces {
|
||||
public class TestParameterSpaces extends BaseDL4JTest {
|
||||
|
||||
|
||||
@Test
|
||||
|
|
|
@ -63,6 +63,13 @@
|
|||
<artifactId>gson</artifactId>
|
||||
<version>${gson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.computationgraph;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.TestUtils;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
|
@ -44,7 +45,7 @@ import java.util.Random;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class TestComputationGraphSpace {
|
||||
public class TestComputationGraphSpace extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testBasic() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.computationgraph;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
|
@ -85,7 +86,7 @@ import static org.junit.Assert.assertEquals;
|
|||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
public class TestGraphLocalExecution {
|
||||
public class TestGraphLocalExecution extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
@ -126,7 +127,7 @@ public class TestGraphLocalExecution {
|
|||
if(dataApproach == 0){
|
||||
ds = TestDL4JLocalExecution.MnistDataSource.class;
|
||||
dsP = new Properties();
|
||||
dsP.setProperty("minibatch", "8");
|
||||
dsP.setProperty("minibatch", "2");
|
||||
candidateGenerator = new RandomSearchGenerator(mls);
|
||||
} else if(dataApproach == 1) {
|
||||
//DataProvider approach
|
||||
|
@ -150,8 +151,8 @@ public class TestGraphLocalExecution {
|
|||
.dataSource(ds, dsP)
|
||||
.modelSaver(new FileModelSaver(modelSave))
|
||||
.scoreFunction(new TestSetLossScoreFunction())
|
||||
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
|
||||
new MaxCandidatesCondition(5))
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.build();
|
||||
|
||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator()));
|
||||
|
@ -159,7 +160,7 @@ public class TestGraphLocalExecution {
|
|||
runner.execute();
|
||||
|
||||
List<ResultReference> results = runner.getResults();
|
||||
assertEquals(5, results.size());
|
||||
assertTrue(results.size() > 0);
|
||||
|
||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||
}
|
||||
|
@ -203,8 +204,8 @@ public class TestGraphLocalExecution {
|
|||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
||||
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
||||
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(10))
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.build();
|
||||
|
||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,
|
||||
|
@ -223,7 +224,7 @@ public class TestGraphLocalExecution {
|
|||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1)))
|
||||
.l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in")
|
||||
.setInputTypes(InputType.feedForward(4))
|
||||
.setInputTypes(InputType.feedForward(784))
|
||||
.addLayer("layer0",
|
||||
new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10))
|
||||
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
||||
|
@ -250,8 +251,8 @@ public class TestGraphLocalExecution {
|
|||
.candidateGenerator(candidateGenerator)
|
||||
.dataProvider(new TestMdsDataProvider(1, 32))
|
||||
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
||||
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(10))
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.scoreFunction(ScoreFunctions.testSetAccuracy())
|
||||
.build();
|
||||
|
||||
|
@ -279,7 +280,7 @@ public class TestGraphLocalExecution {
|
|||
@Override
|
||||
public Object trainData(Map<String, Object> dataParameters) {
|
||||
try {
|
||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 10 * batchSize), false, true, true, 12345);
|
||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 3 * batchSize), false, true, true, 12345);
|
||||
return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
|
@ -289,7 +290,7 @@ public class TestGraphLocalExecution {
|
|||
@Override
|
||||
public Object testData(Map<String, Object> dataParameters) {
|
||||
try {
|
||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 5 * batchSize), false, false, false, 12345);
|
||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 2 * batchSize), false, false, false, 12345);
|
||||
return new MultiDataSetIteratorAdapter(underlying);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
|
@ -305,7 +306,7 @@ public class TestGraphLocalExecution {
|
|||
@Test
|
||||
public void testLocalExecutionEarlyStopping() throws Exception {
|
||||
EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
|
||||
.epochTerminationConditions(new MaxEpochsTerminationCondition(4))
|
||||
.epochTerminationConditions(new MaxEpochsTerminationCondition(2))
|
||||
.scoreCalculator(new ScoreProvider())
|
||||
.modelSaver(new InMemoryModelSaver()).build();
|
||||
Map<String, Object> commands = new HashMap<>();
|
||||
|
@ -348,8 +349,8 @@ public class TestGraphLocalExecution {
|
|||
.dataProvider(dataProvider)
|
||||
.scoreFunction(ScoreFunctions.testSetF1())
|
||||
.modelSaver(new FileModelSaver(modelSavePath))
|
||||
.terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(10))
|
||||
.terminationConditions(new MaxTimeCondition(15, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.build();
|
||||
|
||||
|
||||
|
@ -364,7 +365,7 @@ public class TestGraphLocalExecution {
|
|||
@Override
|
||||
public ScoreCalculator get() {
|
||||
try {
|
||||
return new DataSetLossCalculatorCG(new MnistDataSetIterator(128, 1280), true);
|
||||
return new DataSetLossCalculatorCG(new MnistDataSetIterator(4, 8), true);
|
||||
} catch (Exception e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.computationgraph;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
|
@ -79,11 +80,16 @@ import static org.junit.Assert.assertEquals;
|
|||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
public class TestGraphLocalExecutionGenetic {
|
||||
public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 45000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLocalExecutionDataSources() throws Exception {
|
||||
for (int dataApproach = 0; dataApproach < 3; dataApproach++) {
|
||||
|
@ -115,7 +121,7 @@ public class TestGraphLocalExecutionGenetic {
|
|||
if (dataApproach == 0) {
|
||||
ds = TestDL4JLocalExecution.MnistDataSource.class;
|
||||
dsP = new Properties();
|
||||
dsP.setProperty("minibatch", "8");
|
||||
dsP.setProperty("minibatch", "2");
|
||||
|
||||
candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction)
|
||||
.populationModel(new PopulationModel.Builder().populationSize(5).build())
|
||||
|
@ -148,7 +154,7 @@ public class TestGraphLocalExecutionGenetic {
|
|||
.dataSource(ds, dsP)
|
||||
.modelSaver(new FileModelSaver(modelSave))
|
||||
.scoreFunction(new TestSetLossScoreFunction())
|
||||
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(10))
|
||||
.build();
|
||||
|
||||
|
@ -157,7 +163,7 @@ public class TestGraphLocalExecutionGenetic {
|
|||
runner.execute();
|
||||
|
||||
List<ResultReference> results = runner.getResults();
|
||||
assertEquals(10, results.size());
|
||||
assertTrue(results.size() > 0);
|
||||
|
||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.json;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace;
|
||||
|
@ -71,7 +72,7 @@ import static org.junit.Assert.assertNotNull;
|
|||
/**
|
||||
* Created by Alex on 14/02/2017.
|
||||
*/
|
||||
public class TestJson {
|
||||
public class TestJson extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testMultiLayerSpaceJson() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace;
|
||||
|
@ -59,7 +60,7 @@ import java.util.concurrent.TimeUnit;
|
|||
// import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener;
|
||||
|
||||
/** Not strictly a unit test. Rather: part example, part debugging on MNIST */
|
||||
public class MNISTOptimizationTest {
|
||||
public class MNISTOptimizationTest extends BaseDL4JTest {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator;
|
||||
|
@ -72,9 +73,10 @@ import java.util.Properties;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
public class TestDL4JLocalExecution {
|
||||
public class TestDL4JLocalExecution extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
@ -112,7 +114,7 @@ public class TestDL4JLocalExecution {
|
|||
if(dataApproach == 0){
|
||||
ds = MnistDataSource.class;
|
||||
dsP = new Properties();
|
||||
dsP.setProperty("minibatch", "8");
|
||||
dsP.setProperty("minibatch", "2");
|
||||
candidateGenerator = new RandomSearchGenerator(mls);
|
||||
} else if(dataApproach == 1) {
|
||||
//DataProvider approach
|
||||
|
@ -136,7 +138,7 @@ public class TestDL4JLocalExecution {
|
|||
.dataSource(ds, dsP)
|
||||
.modelSaver(new FileModelSaver(modelSave))
|
||||
.scoreFunction(new TestSetLossScoreFunction())
|
||||
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(5))
|
||||
.build();
|
||||
|
||||
|
@ -146,7 +148,7 @@ public class TestDL4JLocalExecution {
|
|||
runner.execute();
|
||||
|
||||
List<ResultReference> results = runner.getResults();
|
||||
assertEquals(5, results.size());
|
||||
assertTrue(results.size() > 0);
|
||||
|
||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
||||
|
@ -39,7 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|||
|
||||
import java.io.File;
|
||||
|
||||
public class TestErrors {
|
||||
public class TestErrors extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder temp = new TemporaryFolder();
|
||||
|
@ -60,7 +61,7 @@ public class TestErrors {
|
|||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(
|
||||
new MaxCandidatesCondition(5))
|
||||
|
@ -87,7 +88,7 @@ public class TestErrors {
|
|||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(
|
||||
new MaxCandidatesCondition(5))
|
||||
|
@ -116,7 +117,7 @@ public class TestErrors {
|
|||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(new MaxCandidatesCondition(5))
|
||||
.build();
|
||||
|
@ -143,7 +144,7 @@ public class TestErrors {
|
|||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(
|
||||
new MaxCandidatesCondition(5))
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.TestUtils;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
import org.deeplearning4j.arbiter.layers.*;
|
||||
|
@ -44,7 +45,7 @@ import java.util.Random;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class TestLayerSpace {
|
||||
public class TestLayerSpace extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testBasic1() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.DL4JConfiguration;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.TestUtils;
|
||||
|
@ -86,7 +87,7 @@ import java.util.*;
|
|||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class TestMultiLayerSpace {
|
||||
public class TestMultiLayerSpace extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.multilayernetwork;
|
|||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
|
||||
|
@ -60,7 +61,13 @@ import java.util.Map;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
public class TestScoreFunctions {
|
||||
public class TestScoreFunctions extends BaseDL4JTest {
|
||||
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 60000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testROCScoreFunctions() throws Exception {
|
||||
|
@ -107,7 +114,7 @@ public class TestScoreFunctions {
|
|||
List<ResultReference> list = runner.getResults();
|
||||
|
||||
for (ResultReference rr : list) {
|
||||
DataSetIterator testIter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
|
||||
DataSetIterator testIter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||
testIter.setPreProcessor(new PreProc(rocType));
|
||||
|
||||
OptimizationResult or = rr.getResult();
|
||||
|
@ -141,10 +148,10 @@ public class TestScoreFunctions {
|
|||
}
|
||||
|
||||
|
||||
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
|
||||
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||
iter.setPreProcessor(new PreProc(rocType));
|
||||
|
||||
assertEquals(msg, expScore, or.getScore(), 1e-5);
|
||||
assertEquals(msg, expScore, or.getScore(), 1e-4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -158,7 +165,7 @@ public class TestScoreFunctions {
|
|||
@Override
|
||||
public Object trainData(Map<String, Object> dataParameters) {
|
||||
try {
|
||||
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
|
||||
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||
iter.setPreProcessor(new PreProc(rocType));
|
||||
return iter;
|
||||
} catch (IOException e){
|
||||
|
@ -169,7 +176,7 @@ public class TestScoreFunctions {
|
|||
@Override
|
||||
public Object testData(Map<String, Object> dataParameters) {
|
||||
try {
|
||||
DataSetIterator iter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
|
||||
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||
iter.setPreProcessor(new PreProc(rocType));
|
||||
return iter;
|
||||
} catch (IOException e){
|
||||
|
|
|
@ -49,6 +49,13 @@
|
|||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.server;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
||||
|
@ -52,7 +53,7 @@ import static org.junit.Assert.assertEquals;
|
|||
* Created by agibsonccc on 3/12/17.
|
||||
*/
|
||||
@Slf4j
|
||||
public class ArbiterCLIRunnerTest {
|
||||
public class ArbiterCLIRunnerTest extends BaseDL4JTest {
|
||||
|
||||
|
||||
@Test
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.junit.Before;
|
|||
import org.junit.Rule;
|
||||
import org.junit.rules.TestName;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.config.ND4JSystemProperties;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
|
@ -36,6 +37,8 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Properties;
|
||||
|
||||
import static org.junit.Assume.assumeTrue;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseDL4JTest {
|
||||
|
||||
|
@ -47,6 +50,17 @@ public abstract class BaseDL4JTest {
|
|||
protected long startTime;
|
||||
protected int threadCountBefore;
|
||||
|
||||
private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
|
||||
|
||||
/**
|
||||
* Override this to specify the number of threads for C++ execution, via
|
||||
* {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
|
||||
* @return Number of threads to use for C++ op execution
|
||||
*/
|
||||
public int numThreads(){
|
||||
return DEFAULT_THREADS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Override this method to set the default timeout for methods in the test class
|
||||
*/
|
||||
|
@ -72,6 +86,28 @@ public abstract class BaseDL4JTest {
|
|||
return getDataType();
|
||||
}
|
||||
|
||||
protected Boolean integrationTest;
|
||||
|
||||
/**
|
||||
* @return True if integration tests maven profile is enabled, false otherwise.
|
||||
*/
|
||||
public boolean isIntegrationTests(){
|
||||
if(integrationTest == null){
|
||||
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
|
||||
integrationTest = Boolean.parseBoolean(prop);
|
||||
}
|
||||
return integrationTest;
|
||||
}
|
||||
|
||||
/**
|
||||
* Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled.
|
||||
* This can be used to dynamically skip integration tests when the integration test profile is not enabled.
|
||||
* Note that the integration test profile is not enabled by default - "integration-tests" profile
|
||||
*/
|
||||
public void skipUnlessIntegrationTests(){
|
||||
assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
|
||||
}
|
||||
|
||||
@Before
|
||||
public void beforeTest(){
|
||||
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
|
||||
|
@ -81,6 +117,14 @@ public abstract class BaseDL4JTest {
|
|||
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
|
||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
||||
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
|
||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
||||
Nd4j.getExecutioner().enableDebugMode(false);
|
||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
||||
int numThreads = numThreads();
|
||||
Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
|
||||
if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
|
||||
Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
|
||||
}
|
||||
startTime = System.currentTimeMillis();
|
||||
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
|
||||
}
|
||||
|
|
|
@ -158,7 +158,7 @@ public class LayerHelperValidationUtil {
|
|||
double d2 = arr2.dup('c').getDouble(idx);
|
||||
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
|
||||
}
|
||||
assertTrue(s + layerName + "activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
|
||||
assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
|
||||
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
|
||||
}
|
||||
|
||||
|
|
|
@ -78,8 +78,16 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void hasNextWithResetAndLoad() throws Exception {
|
||||
int[] prefetchSizes;
|
||||
if(isIntegrationTests()){
|
||||
prefetchSizes = new int[]{2, 3, 4, 5, 6, 7, 8};
|
||||
} else {
|
||||
prefetchSizes = new int[]{2, 3, 8};
|
||||
}
|
||||
|
||||
|
||||
for (int iter = 0; iter < ITERATIONS; iter++) {
|
||||
for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) {
|
||||
for(int prefetchSize : prefetchSizes){
|
||||
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
|
||||
TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL);
|
||||
int cnt = 0;
|
||||
|
@ -161,8 +169,14 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void testVariableTimeSeries1() throws Exception {
|
||||
int numBatches = isIntegrationTests() ? 1000 : 100;
|
||||
int batchSize = isIntegrationTests() ? 32 : 8;
|
||||
int timeStepsMin = 10;
|
||||
int timeStepsMax = isIntegrationTests() ? 500 : 100;
|
||||
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
|
||||
|
||||
AsyncDataSetIterator adsi = new AsyncDataSetIterator(
|
||||
new VariableTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10), 2, true);
|
||||
new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true);
|
||||
|
||||
for (int e = 0; e < 10; e++) {
|
||||
int cnt = 0;
|
||||
|
|
|
@ -18,21 +18,10 @@ package org.deeplearning4j.datasets.iterator;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.records.reader.SequenceRecordReader;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||
import org.datavec.api.split.NumberedFileInputSplit;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
|
@ -49,7 +38,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
|
|||
*/
|
||||
@Test
|
||||
public void testVariableTimeSeries1() throws Exception {
|
||||
val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10);
|
||||
int numBatches = isIntegrationTests() ? 1000 : 100;
|
||||
int batchSize = isIntegrationTests() ? 32 : 8;
|
||||
int timeStepsMin = 10;
|
||||
int timeStepsMax = isIntegrationTests() ? 500 : 100;
|
||||
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
|
||||
|
||||
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
|
||||
iterator.reset();
|
||||
iterator.hasNext();
|
||||
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
|
||||
|
@ -81,7 +76,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void testVariableTimeSeries2() throws Exception {
|
||||
val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10);
|
||||
int numBatches = isIntegrationTests() ? 1000 : 100;
|
||||
int batchSize = isIntegrationTests() ? 32 : 8;
|
||||
int timeStepsMin = 10;
|
||||
int timeStepsMax = isIntegrationTests() ? 500 : 100;
|
||||
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
|
||||
|
||||
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
|
||||
|
||||
for (int e = 0; e < 10; e++) {
|
||||
iterator.reset();
|
||||
|
|
|
@ -46,17 +46,17 @@ public class TestEmnistDataSetIterator extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testEmnistDataSetIterator() throws Exception {
|
||||
|
||||
// EmnistFetcher fetcher = new EmnistFetcher(EmnistDataSetIterator.Set.COMPLETE);
|
||||
// File baseEmnistDir = fetcher.getFILE_DIR();
|
||||
// if(baseEmnistDir.exists()){
|
||||
// FileUtils.deleteDirectory(baseEmnistDir);
|
||||
// }
|
||||
// assertFalse(baseEmnistDir.exists());
|
||||
|
||||
|
||||
int batchSize = 128;
|
||||
|
||||
for (EmnistDataSetIterator.Set s : EmnistDataSetIterator.Set.values()) {
|
||||
EmnistDataSetIterator.Set[] sets;
|
||||
if(isIntegrationTests()){
|
||||
sets = EmnistDataSetIterator.Set.values();
|
||||
} else {
|
||||
sets = new EmnistDataSetIterator.Set[]{EmnistDataSetIterator.Set.MNIST, EmnistDataSetIterator.Set.LETTERS};
|
||||
}
|
||||
|
||||
for (EmnistDataSetIterator.Set s : sets) {
|
||||
boolean isBalanced = EmnistDataSetIterator.isBalanced(s);
|
||||
int numLabels = EmnistDataSetIterator.numLabels(s);
|
||||
INDArray labelCounts = null;
|
||||
|
|
|
@ -476,7 +476,7 @@ public class EvalTest extends BaseDL4JTest {
|
|||
|
||||
net.setListeners(new EvaluativeListener(iterTest, 3));
|
||||
|
||||
for( int i=0; i<10; i++ ){
|
||||
for( int i=0; i<3; i++ ){
|
||||
net.fit(iter);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -339,9 +339,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
if (PRINT_RESULTS) {
|
||||
log.info(msg);
|
||||
for (int j = 0; j < net.getnLayers(); j++) {
|
||||
log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
}
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
|
||||
|
@ -623,13 +620,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
if (PRINT_RESULTS) {
|
||||
log.info(msg);
|
||||
// for (int j = 0; j < net.getnLayers(); j++) {
|
||||
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
// }
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
|
||||
.labels(labels).subset(true).maxPerParam(128));
|
||||
.labels(labels).subset(true).maxPerParam(64));
|
||||
|
||||
assertTrue(msg, gradOK);
|
||||
|
||||
|
|
|
@ -557,60 +557,52 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
int[] minibatchSizes = {2};
|
||||
int width = 5;
|
||||
int height = 5;
|
||||
int[] inputDepths = {1, 2, 4};
|
||||
|
||||
Activation[] activations = {Activation.SIGMOID, Activation.TANH};
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
for (int inputDepth : inputDepths) {
|
||||
for (Activation afn : activations) {
|
||||
for (int minibatchSize : minibatchSizes) {
|
||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||
for (int i = 0; i < minibatchSize; i++) {
|
||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||
}
|
||||
int[] inputDepths = new int[]{1, 2, 4};
|
||||
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS};
|
||||
int[] minibatch = {2, 1, 3};
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
|
||||
.dataType(DataType.DOUBLE)
|
||||
.activation(afn)
|
||||
.list()
|
||||
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1)
|
||||
.padding(0, 0).nIn(inputDepth).nOut(2).build())//output: (5-2+0)/1+1 = 4
|
||||
.layer(1, new LocallyConnected2D.Builder().nIn(2).nOut(7).kernelSize(2, 2)
|
||||
.setInputSize(4, 4).convolutionMode(ConvolutionMode.Strict).hasBias(false)
|
||||
.stride(1, 1).padding(0, 0).build()) //(4-2+0)/1+1 = 3
|
||||
.layer(2, new ConvolutionLayer.Builder().nIn(7).nOut(2).kernelSize(2, 2)
|
||||
.stride(1, 1).padding(0, 0).build()) //(3-2+0)/1+1 = 2
|
||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
|
||||
.build())
|
||||
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
|
||||
for( int i=0; i<inputDepths.length; i++ ){
|
||||
int inputDepth = inputDepths[i];
|
||||
Activation afn = activations[i];
|
||||
int minibatchSize = minibatch[i];
|
||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
||||
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
||||
|
||||
assertEquals(ConvolutionMode.Truncate,
|
||||
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
|
||||
.dataType(DataType.DOUBLE)
|
||||
.activation(afn)
|
||||
.list()
|
||||
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1)
|
||||
.padding(0, 0).nIn(inputDepth).nOut(2).build())//output: (5-2+0)/1+1 = 4
|
||||
.layer(1, new LocallyConnected2D.Builder().nIn(2).nOut(7).kernelSize(2, 2)
|
||||
.setInputSize(4, 4).convolutionMode(ConvolutionMode.Strict).hasBias(false)
|
||||
.stride(1, 1).padding(0, 0).build()) //(4-2+0)/1+1 = 3
|
||||
.layer(2, new ConvolutionLayer.Builder().nIn(7).nOut(2).kernelSize(2, 2)
|
||||
.stride(1, 1).padding(0, 0).build()) //(3-2+0)/1+1 = 2
|
||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
|
||||
.build())
|
||||
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
assertEquals(ConvolutionMode.Truncate,
|
||||
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
}
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
String msg = "Minibatch=" + minibatchSize + ", activationFn="
|
||||
+ afn;
|
||||
System.out.println(msg);
|
||||
String msg = "Minibatch=" + minibatchSize + ", activationFn="
|
||||
+ afn;
|
||||
System.out.println(msg);
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||
|
||||
assertTrue(msg, gradOK);
|
||||
assertTrue(msg, gradOK);
|
||||
|
||||
TestUtils.testModelSerialization(net);
|
||||
}
|
||||
|
||||
}
|
||||
TestUtils.testModelSerialization(net);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1106,71 +1098,64 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
public void testCropping2DLayer() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nOut = 2;
|
||||
|
||||
int[] minibatchSizes = {1, 3};
|
||||
int width = 12;
|
||||
int height = 11;
|
||||
int[] inputDepths = {1, 3};
|
||||
|
||||
int[] kernel = {2, 2};
|
||||
int[] stride = {1, 1};
|
||||
int[] padding = {0, 0};
|
||||
|
||||
int[][] cropTestCases = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}, {1, 2, 3, 4}};
|
||||
int[] inputDepths = {1, 2, 3, 2};
|
||||
int[] minibatchSizes = {2, 1, 3, 2};
|
||||
|
||||
for (int inputDepth : inputDepths) {
|
||||
for (int minibatchSize : minibatchSizes) {
|
||||
INDArray input = Nd4j.rand(new int[]{minibatchSize, inputDepth, height, width});
|
||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
||||
for (int i = 0; i < minibatchSize; i++) {
|
||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
||||
}
|
||||
for (int[] crop : cropTestCases) {
|
||||
for (int i = 0; i < cropTestCases.length; i++) {
|
||||
int inputDepth = inputDepths[i];
|
||||
int minibatchSize = minibatchSizes[i];
|
||||
int[] crop = cropTestCases[i];
|
||||
INDArray input = Nd4j.rand(new int[]{minibatchSize, inputDepth, height, width});
|
||||
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
||||
|
||||
MultiLayerConfiguration conf =
|
||||
new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.DOUBLE)
|
||||
.updater(new NoOp())
|
||||
.convolutionMode(ConvolutionMode.Same)
|
||||
.weightInit(new NormalDistribution(0, 1)).list()
|
||||
.layer(new ConvolutionLayer.Builder(kernel, stride, padding)
|
||||
.nIn(inputDepth).nOut(2).build())//output: (6-2+0)/1+1 = 5
|
||||
.layer(new Cropping2D(crop))
|
||||
.layer(new ConvolutionLayer.Builder(kernel, stride, padding).nIn(2).nOut(2).build())
|
||||
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build())
|
||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||
.setInputType(InputType.convolutional(height, width, inputDepth))
|
||||
.build();
|
||||
MultiLayerConfiguration conf =
|
||||
new NeuralNetConfiguration.Builder()
|
||||
.dataType(DataType.DOUBLE)
|
||||
.updater(new NoOp())
|
||||
.convolutionMode(ConvolutionMode.Same)
|
||||
.weightInit(new NormalDistribution(0, 1)).list()
|
||||
.layer(new ConvolutionLayer.Builder(kernel, stride, padding)
|
||||
.nIn(inputDepth).nOut(2).build())//output: (6-2+0)/1+1 = 5
|
||||
.layer(new Cropping2D(crop))
|
||||
.layer(new ConvolutionLayer.Builder(kernel, stride, padding).nIn(2).nOut(2).build())
|
||||
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build())
|
||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||
.setInputType(InputType.convolutional(height, width, inputDepth))
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
//Check cropping activation shape
|
||||
org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl =
|
||||
(org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1);
|
||||
val expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
|
||||
width - crop[2] - crop[3]};
|
||||
INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(expShape, out.shape());
|
||||
//Check cropping activation shape
|
||||
org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl =
|
||||
(org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1);
|
||||
val expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
|
||||
width - crop[2] - crop[3]};
|
||||
INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertArrayEquals(expShape, out.shape());
|
||||
|
||||
String msg = "minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
|
||||
+ Arrays.toString(crop);
|
||||
String msg = "minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
|
||||
+ Arrays.toString(crop);
|
||||
|
||||
if (PRINT_RESULTS) {
|
||||
System.out.println(msg);
|
||||
// for (int j = 0; j < net.getnLayers(); j++)
|
||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
|
||||
.labels(labels).subset(true).maxPerParam(160));
|
||||
|
||||
assertTrue(msg, gradOK);
|
||||
|
||||
TestUtils.testModelSerialization(net);
|
||||
}
|
||||
if (PRINT_RESULTS) {
|
||||
System.out.println(msg);
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
|
||||
.labels(labels).subset(true).maxPerParam(160));
|
||||
|
||||
assertTrue(msg, gradOK);
|
||||
|
||||
TestUtils.testModelSerialization(net);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -142,9 +142,9 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
|
|||
Nd4j.getRandom().setSeed(12345L);
|
||||
|
||||
int timeSeriesLength = 5;
|
||||
int nIn = 5;
|
||||
int nIn = 3;
|
||||
int layerSize = 3;
|
||||
int nOut = 3;
|
||||
int nOut = 2;
|
||||
|
||||
int miniBatchSize = 2;
|
||||
|
||||
|
@ -170,24 +170,16 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
|
|||
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
|
||||
mln.init();
|
||||
|
||||
Random r = new Random(12345L);
|
||||
INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}, 'f').subi(0.5);
|
||||
|
||||
INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
|
||||
for (int i = 0; i < miniBatchSize; i++) {
|
||||
for (int j = 0; j < nIn; j++) {
|
||||
labels.putScalar(i, r.nextInt(nOut), j, 1.0);
|
||||
}
|
||||
}
|
||||
INDArray labels = TestUtils.randomOneHotTimeSeries(miniBatchSize, nOut, timeSeriesLength);
|
||||
|
||||
if (PRINT_RESULTS) {
|
||||
System.out.println("testBidirectionalLSTMMasking() - testNum = " + testNum++);
|
||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
|
||||
.labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(16));
|
||||
.labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(12));
|
||||
|
||||
assertTrue(gradOK);
|
||||
TestUtils.testModelSerialization(mln);
|
||||
|
|
|
@ -123,8 +123,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
if (PRINT_RESULTS) {
|
||||
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation="
|
||||
+ outputActivation + ", doLearningFirst=" + doLearningFirst);
|
||||
for (int j = 0; j < mln.getnLayers(); j++)
|
||||
System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
|
@ -214,8 +212,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
System.out.println(testName + "- activationFn=" + afn + ", lossFn=" + lf
|
||||
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
||||
+ doLearningFirst);
|
||||
for (int j = 0; j < mln.getnLayers(); j++)
|
||||
System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
|
@ -277,8 +275,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
if (PRINT_RESULTS) {
|
||||
System.out.println(msg);
|
||||
for (int j = 0; j < net.getnLayers(); j++)
|
||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
// for (int j = 0; j < net.getnLayers(); j++)
|
||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
}
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||
|
@ -340,8 +338,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
if (PRINT_RESULTS) {
|
||||
System.out.println(msg);
|
||||
for (int j = 0; j < net.getnLayers(); j++)
|
||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
// for (int j = 0; j < net.getnLayers(); j++)
|
||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
}
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||
|
@ -397,8 +395,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
if (PRINT_RESULTS) {
|
||||
System.out.println(msg);
|
||||
for (int j = 0; j < net.getnLayers(); j++)
|
||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
// for (int j = 0; j < net.getnLayers(); j++)
|
||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
|
@ -468,8 +466,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
if (PRINT_RESULTS) {
|
||||
System.out.println(msg);
|
||||
for (int j = 0; j < net.getnLayers(); j++)
|
||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
// for (int j = 0; j < net.getnLayers(); j++)
|
||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
|
@ -602,9 +600,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
}
|
||||
// for (int i = 0; i < 4; i++) {
|
||||
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
// }
|
||||
|
||||
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
|
||||
+ afn;
|
||||
|
@ -663,9 +661,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
for (int j = 0; j < net.getLayers().length; j++) {
|
||||
System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
|
||||
}
|
||||
// for (int j = 0; j < net.getLayers().length; j++) {
|
||||
// System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
|
||||
// }
|
||||
|
||||
String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height
|
||||
+ ", kernelSize=" + k;
|
||||
|
@ -726,9 +724,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
for (int i = 0; i < net.getLayers().length; i++) {
|
||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
}
|
||||
// for (int i = 0; i < net.getLayers().length; i++) {
|
||||
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
// }
|
||||
|
||||
String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height
|
||||
+ ", kernelSize=" + k + ", stride = " + stride + ", convLayer first = "
|
||||
|
@ -806,8 +804,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
if (PRINT_RESULTS) {
|
||||
System.out.println(msg);
|
||||
for (int j = 0; j < net.getnLayers(); j++)
|
||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
// for (int j = 0; j < net.getnLayers(); j++)
|
||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||
|
@ -872,9 +870,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
for (int j = 0; j < net.getLayers().length; j++) {
|
||||
System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
|
||||
}
|
||||
// for (int j = 0; j < net.getLayers().length; j++) {
|
||||
// System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
|
||||
// }
|
||||
|
||||
String msg = " - mb=" + minibatchSize + ", k="
|
||||
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
|
||||
|
@ -943,9 +941,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
for (int i = 0; i < net.getLayers().length; i++) {
|
||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
}
|
||||
// for (int i = 0; i < net.getLayers().length; i++) {
|
||||
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
// }
|
||||
|
||||
String msg = " - mb=" + minibatchSize + ", k="
|
||||
+ k + ", nIn=" + nIn + ", depthMul=" + depthMultiplier + ", s=" + s + ", cm=" + cm;
|
||||
|
@ -1018,9 +1016,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
for (int i = 0; i < net.getLayers().length; i++) {
|
||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
}
|
||||
// for (int i = 0; i < net.getLayers().length; i++) {
|
||||
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
// }
|
||||
|
||||
String msg = " - mb=" + minibatchSize + ", k="
|
||||
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
|
||||
|
@ -1104,9 +1102,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
for (int i = 0; i < net.getLayers().length; i++) {
|
||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
}
|
||||
// for (int i = 0; i < net.getLayers().length; i++) {
|
||||
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||
// }
|
||||
|
||||
String msg = (subsampling ? "subsampling" : "conv") + " - mb=" + minibatchSize + ", k="
|
||||
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
|
||||
|
@ -1179,8 +1177,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
if (PRINT_RESULTS) {
|
||||
System.out.println(msg);
|
||||
for (int j = 0; j < net.getnLayers(); j++)
|
||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||
}
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(
|
||||
|
|
|
@ -177,7 +177,7 @@ public class KDTreeTest extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testKNN() {
|
||||
int dimensions = 512;
|
||||
int vectorsNo = 50000;
|
||||
int vectorsNo = isIntegrationTests() ? 50000 : 1000;
|
||||
// make a KD-tree of dimension {#dimensions}
|
||||
Stopwatch stopwatch = Stopwatch.createStarted();
|
||||
KDTree kdTree = new KDTree(dimensions);
|
||||
|
|
|
@ -92,13 +92,13 @@ public class SPTreeTest extends BaseDL4JTest {
|
|||
@Test
|
||||
//@Ignore
|
||||
public void testLargeTree() {
|
||||
int num = 100000;
|
||||
int num = isIntegrationTests() ? 100000 : 1000;
|
||||
StopWatch watch = new StopWatch();
|
||||
watch.start();
|
||||
INDArray arr = Nd4j.linspace(1, num, num, Nd4j.dataType()).reshape(num, 1);
|
||||
SpTree tree = new SpTree(arr);
|
||||
watch.stop();
|
||||
System.out.println("Tree created in " + watch);
|
||||
System.out.println("Tree of size " + num + " created in " + watch);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -45,19 +45,19 @@ public class RandomizedInputTest extends RandomizedTest {
|
|||
private Tokenizer tokenizer = new Tokenizer();
|
||||
|
||||
@Test
|
||||
@Repeat(iterations = 50)
|
||||
@Repeat(iterations = 10)
|
||||
public void testRandomizedUnicodeInput() {
|
||||
assertCanTokenizeString(randomUnicodeOfLength(LENGTH), tokenizer);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Repeat(iterations = 50)
|
||||
@Repeat(iterations = 10)
|
||||
public void testRandomizedRealisticUnicodeInput() {
|
||||
assertCanTokenizeString(randomRealisticUnicodeOfLength(LENGTH), tokenizer);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Repeat(iterations = 50)
|
||||
@Repeat(iterations = 10)
|
||||
public void testRandomizedAsciiInput() {
|
||||
assertCanTokenizeString(randomAsciiOfLength(LENGTH), tokenizer);
|
||||
}
|
||||
|
|
|
@ -406,11 +406,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
|||
double simD = arraysSimilarity(day1, day2);
|
||||
double simN = arraysSimilarity(night1, night2);
|
||||
|
||||
logger.info("Vec1 day: " + day1);
|
||||
logger.info("Vec2 day: " + day2);
|
||||
// logger.info("Vec1 day: " + day1);
|
||||
// logger.info("Vec2 day: " + day2);
|
||||
|
||||
logger.info("Vec1 night: " + night1);
|
||||
logger.info("Vec2 night: " + night2);
|
||||
// logger.info("Vec1 night: " + night1);
|
||||
// logger.info("Vec2 night: " + night2);
|
||||
|
||||
logger.info("Day/day cross-model similarity: " + simD);
|
||||
logger.info("Night/night cross-model similarity: " + simN);
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
package org.deeplearning4j.models.word2vec;
|
||||
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.io.LineIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.shade.guava.primitives.Doubles;
|
||||
|
@ -51,8 +54,8 @@ import org.nd4j.resources.Resources;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.*;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
@ -185,7 +188,12 @@ public class Word2VecTests extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void testWord2VecMultiEpoch() throws Exception {
|
||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
SentenceIterator iter;
|
||||
if(isIntegrationTests()){
|
||||
iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
} else {
|
||||
iter = new CollectionSentenceIterator(firstNLines(inputFile, 50000));
|
||||
}
|
||||
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
@ -389,7 +397,12 @@ public class Word2VecTests extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testW2VnegativeOnRestore() throws Exception {
|
||||
// Strip white space before and after for each line
|
||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
SentenceIterator iter;
|
||||
if(isIntegrationTests()){
|
||||
iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
} else {
|
||||
iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
|
||||
}
|
||||
// Split on white spaces in the line to get words
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
@ -491,7 +504,12 @@ public class Word2VecTests extends BaseDL4JTest {
|
|||
@Test
|
||||
public void orderIsCorrect_WhenParallelized() throws Exception {
|
||||
// Strip white space before and after for each line
|
||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
SentenceIterator iter;
|
||||
if(isIntegrationTests()){
|
||||
iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
} else {
|
||||
iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
|
||||
}
|
||||
// Split on white spaces in the line to get words
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
@ -510,9 +528,10 @@ public class Word2VecTests extends BaseDL4JTest {
|
|||
System.out.println(vec.getVocab().numWords());
|
||||
|
||||
val words = vec.getVocab().words();
|
||||
for (val word : words) {
|
||||
System.out.println(word);
|
||||
}
|
||||
assertTrue(words.size() > 0);
|
||||
// for (val word : words) {
|
||||
// System.out.println(word);
|
||||
// }
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -755,7 +774,16 @@ public class Word2VecTests extends BaseDL4JTest {
|
|||
@Test
|
||||
public void weightsNotUpdated_WhenLocked() throws Exception {
|
||||
|
||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
boolean isIntegration = isIntegrationTests();
|
||||
SentenceIterator iter;
|
||||
SentenceIterator iter2;
|
||||
if(isIntegration){
|
||||
iter = new BasicLineIterator(inputFile);
|
||||
iter2 = new BasicLineIterator(inputFile2.getAbsolutePath());
|
||||
} else {
|
||||
iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
|
||||
iter2 = new CollectionSentenceIterator(firstNLines(inputFile2, 300));
|
||||
}
|
||||
|
||||
Word2Vec vec1 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(64).layerSize(100)
|
||||
.stopWords(new ArrayList<String>()).seed(42).learningRate(0.025).minLearningRate(0.001)
|
||||
|
@ -767,13 +795,12 @@ public class Word2VecTests extends BaseDL4JTest {
|
|||
|
||||
vec1.fit();
|
||||
|
||||
iter = new BasicLineIterator(inputFile2.getAbsolutePath());
|
||||
Word2Vec vec2 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(32).layerSize(100)
|
||||
.stopWords(new ArrayList<String>()).seed(32).learningRate(0.021).minLearningRate(0.001)
|
||||
.sampling(0).elementsLearningAlgorithm(new SkipGram<VocabWord>())
|
||||
.epochs(1).windowSize(5).allowParallelTokenization(true)
|
||||
.workers(1)
|
||||
.iterate(iter)
|
||||
.iterate(iter2)
|
||||
.intersectModel(vec1, true)
|
||||
.modelUtils(new BasicModelUtils<VocabWord>()).build();
|
||||
|
||||
|
@ -861,6 +888,22 @@ public class Word2VecTests extends BaseDL4JTest {
|
|||
}
|
||||
System.out.print("\n");
|
||||
}
|
||||
//
|
||||
|
||||
public static List<String> firstNLines(File f, int n){
|
||||
List<String> lines = new ArrayList<>();
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
|
||||
try{
|
||||
for( int i=0; i<n && lineIter.hasNext(); i++ ){
|
||||
lines.add(lineIter.next());
|
||||
}
|
||||
} finally {
|
||||
lineIter.close();
|
||||
}
|
||||
return lines;
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ public class TsneTest extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 120000L;
|
||||
return 60000L;
|
||||
}
|
||||
|
||||
@Rule
|
||||
|
@ -58,103 +58,102 @@ public class TsneTest extends BaseDL4JTest {
|
|||
public void testSimple() throws Exception {
|
||||
//Simple sanity check
|
||||
|
||||
for (boolean syntheticData : new boolean[]{false, true}) {
|
||||
for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
|
||||
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
||||
for( int test=0; test <=1; test++){
|
||||
boolean syntheticData = test == 1;
|
||||
WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED;
|
||||
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
||||
|
||||
//STEP 1: Initialization
|
||||
int iterations = 50;
|
||||
//create an n-dimensional array of doubles
|
||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
||||
//STEP 1: Initialization
|
||||
int iterations = 50;
|
||||
//create an n-dimensional array of doubles
|
||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
||||
|
||||
//STEP 2: Turn text input into a list of words
|
||||
INDArray weights;
|
||||
if(syntheticData){
|
||||
weights = Nd4j.rand(1000, 200);
|
||||
} else {
|
||||
log.info("Load & Vectorize data....");
|
||||
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
||||
//Get the data of all unique word vectors
|
||||
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
||||
VocabCache cache = vectors.getSecond();
|
||||
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
||||
//STEP 2: Turn text input into a list of words
|
||||
INDArray weights;
|
||||
if(syntheticData){
|
||||
weights = Nd4j.rand(250, 200);
|
||||
} else {
|
||||
log.info("Load & Vectorize data....");
|
||||
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
||||
//Get the data of all unique word vectors
|
||||
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
||||
VocabCache cache = vectors.getSecond();
|
||||
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
||||
|
||||
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
||||
cacheList.add(cache.wordAtIndex(i));
|
||||
}
|
||||
|
||||
//STEP 3: build a dual-tree tsne to use later
|
||||
log.info("Build model....");
|
||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
||||
.setMaxIter(iterations)
|
||||
.theta(0.5)
|
||||
.normalize(false)
|
||||
.learningRate(500)
|
||||
.useAdaGrad(false)
|
||||
.workspaceMode(wsm)
|
||||
.build();
|
||||
|
||||
|
||||
//STEP 4: establish the tsne values and save them to a file
|
||||
log.info("Store TSNE Coordinates for Plotting....");
|
||||
File outDir = testDir.newFolder();
|
||||
tsne.fit(weights);
|
||||
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
||||
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
||||
cacheList.add(cache.wordAtIndex(i));
|
||||
}
|
||||
|
||||
//STEP 3: build a dual-tree tsne to use later
|
||||
log.info("Build model....");
|
||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
||||
.setMaxIter(iterations)
|
||||
.theta(0.5)
|
||||
.normalize(false)
|
||||
.learningRate(500)
|
||||
.useAdaGrad(false)
|
||||
.workspaceMode(wsm)
|
||||
.build();
|
||||
|
||||
|
||||
//STEP 4: establish the tsne values and save them to a file
|
||||
log.info("Store TSNE Coordinates for Plotting....");
|
||||
File outDir = testDir.newFolder();
|
||||
tsne.fit(weights);
|
||||
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
||||
}
|
||||
}
|
||||
|
||||
//Elapsed time : 01:01:57.988
|
||||
@Test
|
||||
public void testPerformance() throws Exception {
|
||||
|
||||
StopWatch watch = new StopWatch();
|
||||
watch.start();
|
||||
for (boolean syntheticData : new boolean[]{false, true}) {
|
||||
for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
|
||||
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
||||
for( int test=0; test <=1; test++){
|
||||
boolean syntheticData = test == 1;
|
||||
WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED;
|
||||
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
||||
|
||||
//STEP 1: Initialization
|
||||
int iterations = 100;
|
||||
//create an n-dimensional array of doubles
|
||||
Nd4j.setDataType(DataType.DOUBLE);
|
||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
||||
//STEP 1: Initialization
|
||||
int iterations = 50;
|
||||
//create an n-dimensional array of doubles
|
||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
||||
|
||||
//STEP 2: Turn text input into a list of words
|
||||
INDArray weights;
|
||||
if(syntheticData){
|
||||
weights = Nd4j.rand(5000, 20);
|
||||
} else {
|
||||
log.info("Load & Vectorize data....");
|
||||
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
||||
//Get the data of all unique word vectors
|
||||
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
||||
VocabCache cache = vectors.getSecond();
|
||||
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
||||
//STEP 2: Turn text input into a list of words
|
||||
INDArray weights;
|
||||
if(syntheticData){
|
||||
weights = Nd4j.rand(DataType.FLOAT, 250, 20);
|
||||
} else {
|
||||
log.info("Load & Vectorize data....");
|
||||
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
||||
//Get the data of all unique word vectors
|
||||
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
||||
VocabCache cache = vectors.getSecond();
|
||||
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
||||
|
||||
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
||||
cacheList.add(cache.wordAtIndex(i));
|
||||
}
|
||||
|
||||
//STEP 3: build a dual-tree tsne to use later
|
||||
log.info("Build model....");
|
||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
||||
.setMaxIter(iterations)
|
||||
.theta(0.5)
|
||||
.normalize(false)
|
||||
.learningRate(500)
|
||||
.useAdaGrad(false)
|
||||
.workspaceMode(wsm)
|
||||
.build();
|
||||
|
||||
|
||||
//STEP 4: establish the tsne values and save them to a file
|
||||
log.info("Store TSNE Coordinates for Plotting....");
|
||||
File outDir = testDir.newFolder();
|
||||
tsne.fit(weights);
|
||||
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
||||
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
||||
cacheList.add(cache.wordAtIndex(i));
|
||||
}
|
||||
|
||||
//STEP 3: build a dual-tree tsne to use later
|
||||
log.info("Build model....");
|
||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
||||
.setMaxIter(iterations)
|
||||
.theta(0.5)
|
||||
.normalize(false)
|
||||
.learningRate(500)
|
||||
.useAdaGrad(false)
|
||||
.workspaceMode(wsm)
|
||||
.build();
|
||||
|
||||
|
||||
//STEP 4: establish the tsne values and save them to a file
|
||||
log.info("Store TSNE Coordinates for Plotting....");
|
||||
File outDir = testDir.newFolder();
|
||||
tsne.fit(weights);
|
||||
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
||||
}
|
||||
watch.stop();
|
||||
System.out.println("Elapsed time : " + watch);
|
||||
|
|
|
@ -20,6 +20,8 @@ package org.deeplearning4j.models.paragraphvectors;
|
|||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.io.LineIterator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
|
||||
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
|
||||
|
@ -27,6 +29,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
|||
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
|
||||
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator;
|
||||
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.ParallelTransformerIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.*;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
@ -46,10 +49,6 @@ import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator;
|
|||
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
|
||||
import org.deeplearning4j.text.documentiterator.LabelledDocument;
|
||||
import org.deeplearning4j.text.documentiterator.LabelsSource;
|
||||
import org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
|
||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||
|
@ -66,8 +65,8 @@ import org.nd4j.resources.Resources;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.*;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
|
@ -372,7 +371,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
|||
|
||||
LabelsSource source = new LabelsSource("DOC_");
|
||||
|
||||
ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(3)
|
||||
ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1)
|
||||
.layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter)
|
||||
.trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0)
|
||||
.useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true)
|
||||
|
@ -425,6 +424,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
|||
|
||||
@Test(timeout = 300000)
|
||||
public void testParagraphVectorsDBOW() throws Exception {
|
||||
skipUnlessIntegrationTests();
|
||||
|
||||
File file = Resources.asFile("/big/raw_sentences.txt");
|
||||
SentenceIterator iter = new BasicLineIterator(file);
|
||||
|
||||
|
@ -657,7 +658,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 300000)
|
||||
@Test
|
||||
public void testIterator() throws IOException {
|
||||
val folder_labeled = testDir.newFolder();
|
||||
val folder_unlabeled = testDir.newFolder();
|
||||
|
@ -672,7 +673,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
|||
SentenceIterator iter = new BasicLineIterator(resource_sentences);
|
||||
|
||||
int i = 0;
|
||||
for (; i < 10000; ++i) {
|
||||
for (; i < 10; ++i) {
|
||||
int j = 0;
|
||||
int labels = 0;
|
||||
int words = 0;
|
||||
|
@ -721,7 +722,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
|||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
||||
Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
|
||||
Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(1)
|
||||
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
|
||||
.elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
|
||||
.allowParallelTokenization(true)
|
||||
|
@ -1009,7 +1010,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
|||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
||||
Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
|
||||
Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(1)
|
||||
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
|
||||
.elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
|
||||
.iterate(iter).tokenizerFactory(t).build();
|
||||
|
@ -1151,8 +1152,27 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
|||
|
||||
@Test(timeout = 300000)
|
||||
public void testDoubleFit() throws Exception {
|
||||
boolean isIntegration = isIntegrationTests();
|
||||
File resource = Resources.asFile("/big/raw_sentences.txt");
|
||||
SentenceIterator iter = new BasicLineIterator(resource);
|
||||
SentenceIterator iter;
|
||||
if(isIntegration){
|
||||
iter = new BasicLineIterator(resource);
|
||||
} else {
|
||||
List<String> lines = new ArrayList<>();
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){
|
||||
LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
|
||||
try{
|
||||
for( int i=0; i<500 && lineIter.hasNext(); i++ ){
|
||||
lines.add(lineIter.next());
|
||||
}
|
||||
} finally {
|
||||
lineIter.close();
|
||||
}
|
||||
}
|
||||
|
||||
iter = new CollectionSentenceIterator(lines);
|
||||
}
|
||||
|
||||
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
|
|
@ -49,7 +49,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 240000L;
|
||||
return 60000L;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -57,6 +57,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
|||
*/
|
||||
@Test
|
||||
public void testIterator1() throws Exception {
|
||||
|
||||
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
|
||||
|
@ -77,10 +78,14 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
|||
|
||||
Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1);
|
||||
INDArray array = iterator.next().getFeatures();
|
||||
int count = 0;
|
||||
while (iterator.hasNext()) {
|
||||
DataSet ds = iterator.next();
|
||||
|
||||
assertArrayEquals(array.shape(), ds.getFeatures().shape());
|
||||
|
||||
if(!isIntegrationTests() && count++ > 20)
|
||||
break; //raw_sentences.txt is 2.81 MB, takes quite some time to process. We'll only first 20 minibatches when doing unit tests
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -45,9 +45,15 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
|||
*/
|
||||
@Test
|
||||
public void testStore1() throws Exception {
|
||||
int numParams = 100000;
|
||||
|
||||
int workers[] = new int[] {2, 4, 8};
|
||||
int numParams;
|
||||
int[] workers;
|
||||
if(isIntegrationTests()){
|
||||
numParams = 100000;
|
||||
workers = new int[] {2, 4, 8};
|
||||
} else {
|
||||
numParams = 10000;
|
||||
workers = new int[] {2, 3};
|
||||
}
|
||||
|
||||
for (int numWorkers : workers) {
|
||||
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3),null, null, false);
|
||||
|
@ -77,7 +83,13 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
|||
*/
|
||||
@Test
|
||||
public void testEncodingLimits1() throws Exception {
|
||||
int numParams = 100000;
|
||||
int numParams;
|
||||
if(isIntegrationTests()){
|
||||
numParams = 100000;
|
||||
} else {
|
||||
numParams = 10000;
|
||||
}
|
||||
|
||||
|
||||
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, null, false);
|
||||
for (int e = 10; e < numParams / 5; e++) {
|
||||
|
|
|
@ -242,7 +242,7 @@ public class IndexedTailTest extends BaseDL4JTest {
|
|||
final long[] sums = new long[numReaders];
|
||||
val readers = new ArrayList<Thread>();
|
||||
for (int e = 0; e < numReaders; e++) {
|
||||
val f = e;
|
||||
final int f = e;
|
||||
val t = new Thread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
|
@ -297,7 +297,7 @@ public class IndexedTailTest extends BaseDL4JTest {
|
|||
final long[] sums = new long[numReaders];
|
||||
val readers = new ArrayList<Thread>();
|
||||
for (int e = 0; e < numReaders; e++) {
|
||||
val f = e;
|
||||
final int f = e;
|
||||
val t = new Thread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
|
@ -371,7 +371,7 @@ public class IndexedTailTest extends BaseDL4JTest {
|
|||
final long[] sums = new long[numReaders];
|
||||
val readers = new ArrayList<Thread>();
|
||||
for (int e = 0; e < numReaders; e++) {
|
||||
val f = e;
|
||||
final int f = e;
|
||||
val t = new Thread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
|
|
|
@ -35,6 +35,7 @@ import org.deeplearning4j.remote.helpers.House;
|
|||
import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
|
||||
import org.deeplearning4j.remote.helpers.PredictedPrice;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -58,6 +59,7 @@ import java.util.Collections;
|
|||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
|
||||
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
|
||||
|
@ -66,7 +68,6 @@ import static org.junit.Assert.*;
|
|||
@Slf4j
|
||||
public class JsonModelServerTest extends BaseDL4JTest {
|
||||
private static final MultiLayerNetwork model;
|
||||
private final int PORT = 18080;
|
||||
|
||||
static {
|
||||
val conf = new NeuralNetConfiguration.Builder()
|
||||
|
@ -84,10 +85,18 @@ public class JsonModelServerTest extends BaseDL4JTest {
|
|||
|
||||
@After
|
||||
public void pause() throws Exception {
|
||||
// TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
|
||||
// Need to wait for server shutdown; without sleep, tests will fail if starting immediately after shutdown
|
||||
TimeUnit.SECONDS.sleep(2);
|
||||
}
|
||||
|
||||
private AtomicInteger portCount = new AtomicInteger(18080);
|
||||
private int PORT;
|
||||
|
||||
@Before
|
||||
public void setPort(){
|
||||
PORT = portCount.getAndIncrement();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testStartStopParallel() throws Exception {
|
||||
|
@ -343,7 +352,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
|
|||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(null)
|
||||
.port(18080)
|
||||
.port(PORT)
|
||||
.build();
|
||||
}
|
||||
|
||||
|
@ -382,7 +391,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
|
|||
return null;
|
||||
}
|
||||
})
|
||||
.endpointAddress("http://localhost:18080/v1/serving")
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
|
|
|
@ -485,7 +485,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
|||
List<INDArray> exp = new ArrayList<>();
|
||||
|
||||
Random r = new Random();
|
||||
for (int i = 0; i < 500; i++) {
|
||||
int runs = isIntegrationTests() ? 500 : 30;
|
||||
for (int i = 0; i < runs; i++) {
|
||||
int[] shape = defaultSize;
|
||||
if (r.nextDouble() < 0.4) {
|
||||
shape = new int[]{r.nextInt(5) + 1, 10, r.nextInt(10) + 1};
|
||||
|
@ -597,7 +598,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
|||
List<INDArray> arrs = new ArrayList<>();
|
||||
List<INDArray> exp = new ArrayList<>();
|
||||
Random r = new Random();
|
||||
for( int i=0; i<500; i++ ){
|
||||
int runs = isIntegrationTests() ? 500 : 20;
|
||||
for( int i=0; i<runs; i++ ){
|
||||
int[] shape = defaultShape;
|
||||
if(r.nextDouble() < 0.4){
|
||||
shape = new int[]{r.nextInt(5)+1, nIn, 10, r.nextInt(10)+1};
|
||||
|
@ -679,8 +681,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 120000)
|
||||
public void testInputMasking() throws Exception {
|
||||
private void testInputMasking() throws Exception {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
int nIn = 10;
|
||||
|
@ -698,12 +699,15 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
// InferenceMode[] inferenceModes = new InferenceMode[]{InferenceMode.SEQUENTIAL, InferenceMode.BATCHED, InferenceMode.INPLACE, InferenceMode.SEQUENTIAL};
|
||||
// int[] workers = new int[]{2, 2, 2, 1};
|
||||
// boolean[] randomTS = new boolean[]{true, false, true, false};
|
||||
|
||||
Random r = new Random();
|
||||
for( InferenceMode m : InferenceMode.values()) {
|
||||
log.info("Testing inference mode: [{}]", m);
|
||||
for( int w : new int[]{1,2}) {
|
||||
for (boolean randomTSLength : new boolean[]{false, true}) {
|
||||
|
||||
final ParallelInference inf =
|
||||
new ParallelInference.Builder(net)
|
||||
.inferenceMode(m)
|
||||
|
@ -714,7 +718,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
|||
List<INDArray> in = new ArrayList<>();
|
||||
List<INDArray> inMasks = new ArrayList<>();
|
||||
List<INDArray> exp = new ArrayList<>();
|
||||
for (int i = 0; i < 100; i++) {
|
||||
int nRuns = isIntegrationTests() ? 100 : 10;
|
||||
for (int i = 0; i < nRuns; i++) {
|
||||
int currTSLength = (randomTSLength ? 1 + r.nextInt(tsLength) : tsLength);
|
||||
int currNumEx = 1 + r.nextInt(3);
|
||||
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn, currTSLength});
|
||||
|
@ -847,6 +852,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
|||
|
||||
List<INDArray[]> in = new ArrayList<>();
|
||||
List<INDArray[]> exp = new ArrayList<>();
|
||||
int runs = isIntegrationTests() ? 100 : 20;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
int currNumEx = 1 + r.nextInt(3);
|
||||
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn});
|
||||
|
|
|
@ -62,8 +62,8 @@ public class ParallelWrapperTest extends BaseDL4JTest {
|
|||
int seed = 123;
|
||||
|
||||
log.info("Load data....");
|
||||
DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 100);
|
||||
DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 10);
|
||||
DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 15);
|
||||
DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 4);
|
||||
|
||||
assertTrue(mnistTrain.hasNext());
|
||||
val t0 = mnistTrain.next();
|
||||
|
|
|
@ -47,6 +47,12 @@
|
|||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>net.jpountz.lz4</groupId>
|
||||
<artifactId>lz4</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
################################################################################
|
||||
# Copyright (c) 2015-2019 Skymind, Inc.
|
||||
#
|
||||
# This program and the accompanying materials are made available under the
|
||||
# terms of the Apache License, Version 2.0 which is available at
|
||||
# https://www.apache.org/licenses/LICENSE-2.0.
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
################################################################################
|
||||
|
||||
log4j.rootLogger=ERROR, Console
|
||||
log4j.appender.Console=org.apache.log4j.ConsoleAppender
|
||||
log4j.appender.Console.layout=org.apache.log4j.PatternLayout
|
||||
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
|
||||
|
||||
log4j.appender.org.springframework=DEBUG
|
||||
log4j.appender.org.deeplearning4j=DEBUG
|
||||
log4j.appender.org.nd4j=DEBUG
|
||||
|
||||
log4j.logger.org.springframework=INFO
|
||||
log4j.logger.org.deeplearning4j=DEBUG
|
||||
log4j.logger.org.nd4j=DEBUG
|
||||
log4j.logger.org.apache.spark=WARN
|
||||
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<configuration>
|
||||
|
||||
|
||||
|
||||
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||
<file>logs/application.log</file>
|
||||
<encoder>
|
||||
<pattern>%date - [%level] - from %logger in %thread
|
||||
%n%message%n%xException%n</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder>
|
||||
<pattern> %logger{15} - %message%n%xException{5}
|
||||
</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<logger name="org.apache.catalina.core" level="DEBUG" />
|
||||
<logger name="org.springframework" level="DEBUG" />
|
||||
<logger name="org.deeplearning4j" level="DEBUG" />
|
||||
<logger name="org.datavec" level="INFO" />
|
||||
<logger name="org.nd4j" level="INFO" />
|
||||
<logger name="opennlp.uima.util" level="OFF" />
|
||||
<logger name="org.apache.uima" level="OFF" />
|
||||
<logger name="org.cleartk" level="OFF" />
|
||||
<logger name="org.apache.spark" level="WARN" />
|
||||
|
||||
|
||||
|
||||
<root level="ERROR">
|
||||
<appender-ref ref="STDOUT" />
|
||||
<appender-ref ref="FILE" />
|
||||
</root>
|
||||
|
||||
</configuration>
|
|
@ -0,0 +1,31 @@
|
|||
################################################################################
|
||||
# Copyright (c) 2015-2019 Skymind, Inc.
|
||||
#
|
||||
# This program and the accompanying materials are made available under the
|
||||
# terms of the Apache License, Version 2.0 which is available at
|
||||
# https://www.apache.org/licenses/LICENSE-2.0.
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
################################################################################
|
||||
|
||||
log4j.rootLogger=ERROR, Console
|
||||
log4j.appender.Console=org.apache.log4j.ConsoleAppender
|
||||
log4j.appender.Console.layout=org.apache.log4j.PatternLayout
|
||||
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
|
||||
|
||||
log4j.appender.org.springframework=DEBUG
|
||||
log4j.appender.org.deeplearning4j=DEBUG
|
||||
log4j.appender.org.nd4j=DEBUG
|
||||
|
||||
log4j.logger.org.springframework=INFO
|
||||
log4j.logger.org.deeplearning4j=DEBUG
|
||||
log4j.logger.org.nd4j=DEBUG
|
||||
log4j.logger.org.apache.spark=WARN
|
||||
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<configuration>
|
||||
|
||||
|
||||
|
||||
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||
<file>logs/application.log</file>
|
||||
<encoder>
|
||||
<pattern>%date - [%level] - from %logger in %thread
|
||||
%n%message%n%xException%n</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder>
|
||||
<pattern> %logger{15} - %message%n%xException{5}
|
||||
</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<logger name="org.apache.catalina.core" level="DEBUG" />
|
||||
<logger name="org.springframework" level="DEBUG" />
|
||||
<logger name="org.deeplearning4j" level="DEBUG" />
|
||||
<logger name="org.datavec" level="INFO" />
|
||||
<logger name="org.nd4j" level="INFO" />
|
||||
<logger name="opennlp.uima.util" level="OFF" />
|
||||
<logger name="org.apache.uima" level="OFF" />
|
||||
<logger name="org.cleartk" level="OFF" />
|
||||
<logger name="org.apache.spark" level="WARN" />
|
||||
|
||||
|
||||
|
||||
<root level="ERROR">
|
||||
<appender-ref ref="STDOUT" />
|
||||
<appender-ref ref="FILE" />
|
||||
</root>
|
||||
|
||||
</configuration>
|
|
@ -25,6 +25,7 @@ import org.datavec.spark.transform.misc.StringToWritablesFunction;
|
|||
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
|
||||
import org.deeplearning4j.spark.BaseSparkTest;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
||||
|
@ -35,6 +36,15 @@ import static org.junit.Assert.assertEquals;
|
|||
|
||||
public class TestIteratorUtils extends BaseSparkTest {
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDefaultFPDataType() {
|
||||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIrisRRMDSI() throws Exception {
|
||||
|
|
|
@ -453,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
tempDirF.deleteOnExit();
|
||||
|
||||
int dataSetObjSize = 1;
|
||||
int batchSizePerExecutor = 16;
|
||||
int numSplits = 5;
|
||||
int batchSizePerExecutor = 4;
|
||||
int numSplits = 3;
|
||||
int averagingFrequency = 3;
|
||||
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
|
||||
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
|
||||
|
@ -506,7 +506,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
INDArray paramsAfter = sparkNet.getNetwork().params().dup();
|
||||
assertNotEquals(paramsBefore, paramsAfter);
|
||||
|
||||
Thread.sleep(2000);
|
||||
Thread.sleep(200);
|
||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||
|
||||
//Expect
|
||||
|
@ -517,7 +517,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
assertEquals(numSplits * numExecutors() * averagingFrequency, list.size());
|
||||
for (EventStats es : list) {
|
||||
ExampleCountEventStats e = (ExampleCountEventStats) es;
|
||||
assertTrue(batchSizePerExecutor * averagingFrequency - 10 >= e.getTotalExampleCount());
|
||||
assertTrue(batchSizePerExecutor * averagingFrequency >= e.getTotalExampleCount());
|
||||
}
|
||||
|
||||
|
||||
|
@ -535,9 +535,9 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
tempDirF.deleteOnExit();
|
||||
tempDirF2.deleteOnExit();
|
||||
|
||||
int dataSetObjSize = 5;
|
||||
int batchSizePerExecutor = 25;
|
||||
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 1000, false);
|
||||
int dataSetObjSize = 4;
|
||||
int batchSizePerExecutor = 8;
|
||||
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 128, false);
|
||||
int i = 0;
|
||||
while (iter.hasNext()) {
|
||||
File nextFile = new File(tempDirF, i + ".bin");
|
||||
|
@ -981,7 +981,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
.setOutputs("out")
|
||||
.build();
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(1, 150);
|
||||
DataSetIterator iter = new IrisDataSetIterator(1, 50);
|
||||
|
||||
List<DataSet> l = new ArrayList<>();
|
||||
while(iter.hasNext()){
|
||||
|
@ -992,9 +992,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
|
||||
|
||||
int rddDataSetNumExamples = 1;
|
||||
int averagingFrequency = 3;
|
||||
int averagingFrequency = 2;
|
||||
int batch = 2;
|
||||
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(rddDataSetNumExamples)
|
||||
.averagingFrequency(averagingFrequency).batchSizePerWorker(rddDataSetNumExamples)
|
||||
.averagingFrequency(averagingFrequency).batchSizePerWorker(batch)
|
||||
.saveUpdater(true).workerPrefetchNumBatches(0).build();
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -1003,7 +1004,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
|||
SparkComputationGraph sn2 = new SparkComputationGraph(sc, conf2.clone(), tm);
|
||||
|
||||
|
||||
for(int i=0; i<4; i++ ){
|
||||
for(int i=0; i<3; i++ ){
|
||||
assertEquals(i, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount());
|
||||
assertEquals(i, sn2.getNetwork().getConfiguration().getEpochCount());
|
||||
sn1.fit(rdd);
|
||||
|
|
|
@ -42,6 +42,11 @@ import static org.junit.Assert.assertTrue;
|
|||
*/
|
||||
public class TestRepartitioning extends BaseSparkTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return isIntegrationTests() ? 240000 : 60000;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRepartitioning() {
|
||||
List<String> list = new ArrayList<>();
|
||||
|
@ -66,7 +71,12 @@ public class TestRepartitioning extends BaseSparkTest {
|
|||
@Test
|
||||
public void testRepartitioning2() throws Exception {
|
||||
|
||||
int[] ns = {320, 321, 25600, 25601, 25615};
|
||||
int[] ns;
|
||||
if(isIntegrationTests()){
|
||||
ns = new int[]{320, 321, 25600, 25601, 25615};
|
||||
} else {
|
||||
ns = new int[]{320, 2561};
|
||||
}
|
||||
|
||||
for (int n : ns) {
|
||||
|
||||
|
|
|
@ -32,6 +32,11 @@ import java.io.File;
|
|||
|
||||
public class MiscTests extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 120000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTransferVGG() throws Exception {
|
||||
//https://github.com/deeplearning4j/deeplearning4j/issues/5167
|
||||
|
|
|
@ -48,6 +48,11 @@ import static org.junit.Assert.assertEquals;
|
|||
@Slf4j
|
||||
public class TestDownload extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return isIntegrationTests() ? 480000L : 60000L;
|
||||
}
|
||||
|
||||
@ClassRule
|
||||
public static TemporaryFolder testDir = new TemporaryFolder();
|
||||
private static File f;
|
||||
|
@ -67,12 +72,20 @@ public class TestDownload extends BaseDL4JTest {
|
|||
public void testDownloadAllModels() throws Exception {
|
||||
|
||||
// iterate through each available model
|
||||
ZooModel[] models = new ZooModel[]{
|
||||
LeNet.builder().build(),
|
||||
SimpleCNN.builder().build(),
|
||||
UNet.builder().build(),
|
||||
NASNet.builder().build()
|
||||
};
|
||||
ZooModel[] models;
|
||||
|
||||
if(isIntegrationTests()){
|
||||
models = new ZooModel[]{
|
||||
LeNet.builder().build(),
|
||||
SimpleCNN.builder().build(),
|
||||
UNet.builder().build(),
|
||||
NASNet.builder().build()};
|
||||
} else {
|
||||
models = new ZooModel[]{
|
||||
LeNet.builder().build(),
|
||||
SimpleCNN.builder().build()};
|
||||
}
|
||||
|
||||
|
||||
|
||||
for (int i = 0; i < models.length; i++) {
|
||||
|
|
|
@ -57,6 +57,11 @@ import static org.junit.Assert.assertTrue;
|
|||
@Slf4j
|
||||
public class TestImageNet extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType(){
|
||||
return DataType.FLOAT;
|
||||
|
|
|
@ -63,10 +63,15 @@ public class DistributionUniform extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public DistributionUniform(INDArray shape, INDArray out, double min, double max){
|
||||
public DistributionUniform(INDArray shape, INDArray out, double min, double max) {
|
||||
this(shape, out, min, max, null);
|
||||
}
|
||||
|
||||
public DistributionUniform(INDArray shape, INDArray out, double min, double max, DataType dataType){
|
||||
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
|
||||
this.min = min;
|
||||
this.max = max;
|
||||
this.dataType = dataType;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -310,6 +310,13 @@
|
|||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.jita.allocator;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
||||
|
@ -29,7 +30,7 @@ import static org.junit.Assert.assertArrayEquals;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
public class DeviceLocalNDArrayTests {
|
||||
public class DeviceLocalNDArrayTests extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testDeviceLocalArray_1() throws Exception{
|
||||
|
|
|
@ -19,13 +19,14 @@ package org.nd4j.jita.allocator.impl;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Slf4j
|
||||
public class MemoryTrackerTest {
|
||||
public class MemoryTrackerTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testAllocatedDelta() {
|
||||
|
|
|
@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import lombok.val;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.BaseND4JTest;
|
||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||
import org.nd4j.jita.workspace.CudaWorkspace;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -20,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
import static org.junit.Assert.*;
|
||||
|
||||
@Slf4j
|
||||
public class BaseCudaDataBufferTest {
|
||||
public class BaseCudaDataBufferTest extends BaseND4JTest {
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
|
|
|
@ -87,6 +87,13 @@
|
|||
<version>${logback.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.nd4j.tensorflow.conversion;
|
|||
import junit.framework.TestCase;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.bytedeco.tensorflow.TF_Tensor;
|
||||
import org.nd4j.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.resources.Resources;
|
||||
import org.nd4j.shade.protobuf.Descriptors;
|
||||
|
@ -46,7 +47,17 @@ import java.util.Map;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
public class GraphRunnerTest {
|
||||
public class GraphRunnerTest extends BaseND4JTest {
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDefaultFPDataType() {
|
||||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
public static ConfigProto getConfig(){
|
||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.tensorflow.conversion;
|
|||
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals;
|
|||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
public class TensorflowConversionTest {
|
||||
public class TensorflowConversionTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testView() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion;
|
||||
|
||||
import org.nd4j.BaseND4JTest;
|
||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.junit.Test;
|
||||
|
@ -37,7 +38,7 @@ import java.util.Map;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
public class GpuGraphRunnerTest {
|
||||
public class GpuGraphRunnerTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testGraphRunner() throws Exception {
|
||||
|
|
|
@ -127,6 +127,13 @@
|
|||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<reporting>
|
||||
|
|
|
@ -16,9 +16,6 @@
|
|||
|
||||
package org.nd4j.autodiff.opvalidation;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNull;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -46,6 +43,8 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Slf4j
|
||||
public class LayerOpValidation extends BaseOpValidation {
|
||||
public LayerOpValidation(Nd4jBackend backend) {
|
||||
|
|
|
@ -45,7 +45,7 @@ public class LossOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Override
|
||||
public long testTimeoutMilliseconds() {
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L;
|
||||
}
|
||||
|
||||
|
|
|
@ -54,8 +54,7 @@ import java.util.Arrays;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
|
|
|
@ -2434,8 +2434,6 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
@Test
|
||||
public void testPermute4(){
|
||||
Nd4j.getExecutioner().enableDebugMode(true);
|
||||
Nd4j.getExecutioner().enableVerboseMode(true);
|
||||
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
|
||||
INDArray permute = Nd4j.createFromArray(1,0);
|
||||
|
||||
|
|
|
@ -63,8 +63,12 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
return sd;
|
||||
}
|
||||
|
||||
public static DataSetIterator getIter(){
|
||||
return new IrisDataSetIterator(15, 150);
|
||||
public static DataSetIterator getIter() {
|
||||
return getIter(15, 150);
|
||||
}
|
||||
|
||||
public static DataSetIterator getIter(int batch, int totalExamples){
|
||||
return new IrisDataSetIterator(batch, totalExamples);
|
||||
}
|
||||
|
||||
|
||||
|
@ -148,15 +152,15 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
|
||||
CheckpointListener l = new CheckpointListener.Builder(dir)
|
||||
.keepLast(2)
|
||||
.saveEvery(3, TimeUnit.SECONDS)
|
||||
.saveEvery(1, TimeUnit.SECONDS)
|
||||
.build();
|
||||
sd.setListeners(l);
|
||||
|
||||
DataSetIterator iter = getIter();
|
||||
DataSetIterator iter = getIter(15, 150);
|
||||
|
||||
for(int i=0; i<5; i++ ){ //10 iterations total
|
||||
sd.fit(iter, 1);
|
||||
Thread.sleep(4000);
|
||||
Thread.sleep(1000);
|
||||
}
|
||||
|
||||
//Expect models saved at iterations: 10, 20, 30, 40
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
|||
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
|
|
|
@ -34,8 +34,7 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Created by Alex on 05/07/2017.
|
||||
|
|
|
@ -33,8 +33,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
|
|||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Created by Alex on 04/11/2016.
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.nd4j.nativeblas.NativeOpsHolder;
|
|||
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
|
|
|
@ -121,7 +121,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
"fused_batch_norm/.*",
|
||||
|
||||
// AB 2020/01/04 - https://github.com/eclipse/deeplearning4j/issues/8592
|
||||
"emptyArrayTests/reshape/rank2_shape2-0_2-0--1"
|
||||
"emptyArrayTests/reshape/rank2_shape2-0_2-0--1",
|
||||
|
||||
//AB 2020/01/07 - Known issues
|
||||
"bitcast/from_float64_to_int64",
|
||||
"bitcast/from_rank2_float64_to_int64",
|
||||
"bitcast/from_float64_to_uint64"
|
||||
};
|
||||
|
||||
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
|
||||
|
|
|
@ -252,7 +252,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
|
|||
System.out.println(Arrays.toString(shape));
|
||||
|
||||
// this is NHWC weights. will be changed soon.
|
||||
assertArrayEquals(new int[]{5,5,1,32}, shape);
|
||||
assertArrayEquals(new long[]{5,5,1,32}, shape);
|
||||
System.out.println(convNode);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.nd4j.linalg;
|
||||
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.junit.After;
|
||||
|
@ -26,6 +27,7 @@ import org.junit.rules.TestName;
|
|||
import org.junit.rules.Timeout;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.BaseND4JTest;
|
||||
import org.nd4j.config.ND4JSystemProperties;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
|
@ -40,30 +42,16 @@ import org.slf4j.LoggerFactory;
|
|||
import java.lang.management.ManagementFactory;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assume.assumeTrue;
|
||||
|
||||
|
||||
/**
|
||||
* Base Nd4j test
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public abstract class BaseNd4jTest {
|
||||
private static Logger log = LoggerFactory.getLogger(BaseNd4jTest.class);
|
||||
|
||||
@Rule
|
||||
public TestName testName = new TestName();
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(testTimeoutMilliseconds());
|
||||
|
||||
/**
|
||||
* Override this method to set the default timeout for methods in the class
|
||||
*/
|
||||
public long testTimeoutMilliseconds(){
|
||||
return 30000L;
|
||||
}
|
||||
|
||||
protected long startTime;
|
||||
protected int threadCountBefore;
|
||||
@Slf4j
|
||||
public abstract class BaseNd4jTest extends BaseND4JTest {
|
||||
|
||||
protected Nd4jBackend backend;
|
||||
protected String name;
|
||||
|
@ -80,16 +68,10 @@ public abstract class BaseNd4jTest {
|
|||
public BaseNd4jTest(String name, Nd4jBackend backend) {
|
||||
this.backend = backend;
|
||||
this.name = name;
|
||||
|
||||
//Suppress ND4J initialization - don't need this logged for every test...
|
||||
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
|
||||
System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
|
||||
System.gc();
|
||||
}
|
||||
|
||||
public BaseNd4jTest(Nd4jBackend backend) {
|
||||
this(backend.getClass().getName() + UUID.randomUUID().toString(), backend);
|
||||
|
||||
}
|
||||
|
||||
private static List<Nd4jBackend> backends;
|
||||
|
@ -104,79 +86,6 @@ public abstract class BaseNd4jTest {
|
|||
if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) || backendsToRun.isEmpty())
|
||||
backends.add(backend);
|
||||
}
|
||||
|
||||
}
|
||||
public static void assertArrayEquals(String string, Object[] expecteds, Object[] actuals) {
|
||||
org.junit.Assert.assertArrayEquals(string, expecteds, actuals);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(Object[] expecteds, Object[] actuals) {
|
||||
org.junit.Assert.assertArrayEquals(expecteds, actuals);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(String string, long[] shapeA, long[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(String string, byte[] shapeA, byte[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(byte[] shapeA, byte[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(long[] shapeA, long[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(String string, int[] shapeA, long[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(string, ArrayUtil.toLongArray(shapeA), shapeB);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(int[] shapeA, long[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(ArrayUtil.toLongArray(shapeA), shapeB);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(String string, long[] shapeA, int[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(string, shapeA, ArrayUtil.toLongArray(shapeB));
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(long[] shapeA, int[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(shapeA, ArrayUtil.toLongArray(shapeB));
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(String string, int[] shapeA, int[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(int[] shapeA, int[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(String string, boolean[] shapeA, boolean[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(boolean[] shapeA, boolean[] shapeB) {
|
||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
|
||||
}
|
||||
|
||||
|
||||
public static void assertArrayEquals(float[] shapeA, float[] shapeB, float delta) {
|
||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(double[] shapeA, double[] shapeB, double delta) {
|
||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(String string, float[] shapeA, float[] shapeB, float delta) {
|
||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta);
|
||||
}
|
||||
|
||||
public static void assertArrayEquals(String string, double[] shapeA, double[] shapeB, double delta) {
|
||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta);
|
||||
}
|
||||
|
||||
@Parameterized.Parameters(name = "{index}: backend({0})={1}")
|
||||
|
@ -187,6 +96,13 @@ public abstract class BaseNd4jTest {
|
|||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Before
|
||||
public void beforeTest(){
|
||||
super.beforeTest();
|
||||
Nd4j.factory().setOrder(ordering());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the default backend (jblas)
|
||||
* The default backend can be overridden by also passing:
|
||||
|
@ -207,106 +123,6 @@ public abstract class BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
//
|
||||
log.info("Running {}.{} on {}", getClass().getName(), testName.getMethodName(), backend.getClass().getSimpleName());
|
||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
||||
Nd4j nd4j = new Nd4j();
|
||||
nd4j.initWithBackend(backend);
|
||||
Nd4j.factory().setOrder(ordering());
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||
Nd4j.getExecutioner().enableDebugMode(false);
|
||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
||||
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
||||
startTime = System.currentTimeMillis();
|
||||
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
|
||||
}
|
||||
|
||||
@After
|
||||
public void after() throws Exception {
|
||||
long totalTime = System.currentTimeMillis() - startTime;
|
||||
Nd4j.getMemoryManager().purgeCaches();
|
||||
|
||||
logTestCompletion(totalTime);
|
||||
if (System.getProperties().getProperty("backends") != null
|
||||
&& !System.getProperty("backends").contains(backend.getClass().getName()))
|
||||
return;
|
||||
Nd4j nd4j = new Nd4j();
|
||||
nd4j.initWithBackend(backend);
|
||||
Nd4j.factory().setOrder(ordering());
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||
Nd4j.getExecutioner().enableDebugMode(false);
|
||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
||||
|
||||
//Attempt to keep workspaces isolated between tests
|
||||
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
|
||||
val currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(null);
|
||||
if(currWS != null){
|
||||
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
|
||||
// errors that are hard to track back to this
|
||||
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
|
||||
System.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
public void logTestCompletion( long totalTime){
|
||||
StringBuilder sb = new StringBuilder();
|
||||
long maxPhys = Pointer.maxPhysicalBytes();
|
||||
long maxBytes = Pointer.maxBytes();
|
||||
long currPhys = Pointer.physicalBytes();
|
||||
long currBytes = Pointer.totalBytes();
|
||||
|
||||
long jvmTotal = Runtime.getRuntime().totalMemory();
|
||||
long jvmMax = Runtime.getRuntime().maxMemory();
|
||||
|
||||
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
|
||||
sb.append(getClass().getSimpleName()).append(".").append(testName.getMethodName())
|
||||
.append(": ").append(totalTime).append(" ms")
|
||||
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
|
||||
.append(", jvmTotal=").append(jvmTotal)
|
||||
.append(", jvmMax=").append(jvmMax)
|
||||
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
|
||||
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
|
||||
|
||||
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
|
||||
if(ws != null && ws.size() > 0){
|
||||
long currSize = 0;
|
||||
for(MemoryWorkspace w : ws){
|
||||
currSize += w.getCurrentSize();
|
||||
}
|
||||
if(currSize > 0){
|
||||
sb.append(", threadWSSize=").append(currSize)
|
||||
.append(" (").append(ws.size()).append(" WSs)");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
|
||||
Object o = p.get("cuda.devicesInformation");
|
||||
if(o instanceof List){
|
||||
List<Map<String,Object>> l = (List<Map<String, Object>>) o;
|
||||
if(l.size() > 0) {
|
||||
|
||||
sb.append(" [").append(l.size())
|
||||
.append(" GPUs: ");
|
||||
|
||||
for (int i = 0; i < l.size(); i++) {
|
||||
Map<String,Object> m = l.get(i);
|
||||
if(i > 0)
|
||||
sb.append(",");
|
||||
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
|
||||
.append(m.get("cuda.totalMemory")).append(" total)");
|
||||
}
|
||||
sb.append("]");
|
||||
}
|
||||
}
|
||||
log.info(sb.toString());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* The ordering for this test
|
||||
* This test will only be invoked for
|
||||
|
@ -315,15 +131,10 @@ public abstract class BaseNd4jTest {
|
|||
* @return the ordering for this test
|
||||
*/
|
||||
public char ordering() {
|
||||
return 'a';
|
||||
return 'c';
|
||||
}
|
||||
|
||||
|
||||
|
||||
public String getFailureMessage() {
|
||||
return "Failed with backend " + backend.getClass().getName() + " and ordering " + ordering();
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -65,16 +65,6 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
super(backend);
|
||||
}
|
||||
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
super.before();
|
||||
}
|
||||
|
||||
@After
|
||||
public void after() throws Exception {
|
||||
super.after();
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
|
|
|
@ -133,13 +133,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Override
|
||||
public long testTimeoutMilliseconds() {
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000;
|
||||
}
|
||||
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
super.before();
|
||||
super.beforeTest();
|
||||
Nd4j.setDataType(DataType.DOUBLE);
|
||||
Nd4j.getRandom().setSeed(123);
|
||||
Nd4j.getExecutioner().enableDebugMode(false);
|
||||
|
@ -148,7 +148,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
@After
|
||||
public void after() throws Exception {
|
||||
super.after();
|
||||
super.afterTest();
|
||||
Nd4j.setDataType(initialType);
|
||||
}
|
||||
|
||||
|
@ -5331,7 +5331,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void testNativeSort3() {
|
||||
INDArray array = Nd4j.linspace(1, 1048576, 1048576, DataType.DOUBLE).reshape(1, -1);
|
||||
int length = isIntegrationTests() ? 1048576 : 16484;
|
||||
INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1);
|
||||
INDArray exp = array.dup();
|
||||
Nd4j.shuffle(array, 0);
|
||||
|
||||
|
@ -7196,19 +7197,19 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
|
||||
for( int i=-3; i<3; i++ ){
|
||||
INDArray out = Nd4j.stack(i, in, in2);
|
||||
int[] expShape;
|
||||
long[] expShape;
|
||||
switch (i){
|
||||
case -3:
|
||||
case 0:
|
||||
expShape = new int[]{2,3,4};
|
||||
expShape = new long[]{2,3,4};
|
||||
break;
|
||||
case -2:
|
||||
case 1:
|
||||
expShape = new int[]{3,2,4};
|
||||
expShape = new long[]{3,2,4};
|
||||
break;
|
||||
case -1:
|
||||
case 2:
|
||||
expShape = new int[]{3,4,2};
|
||||
expShape = new long[]{3,4,2};
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException(String.valueOf(i));
|
||||
|
@ -7602,6 +7603,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
String wsName = "testRollingMeanWs";
|
||||
try {
|
||||
System.gc();
|
||||
int iterations1 = isIntegrationTests() ? 5 : 2;
|
||||
for (int e = 0; e < 5; e++) {
|
||||
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) {
|
||||
val array = Nd4j.create(DataType.FLOAT, 32, 128, 256, 256);
|
||||
|
@ -7609,7 +7611,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
int iterations = 20;
|
||||
int iterations = isIntegrationTests() ? 20 : 3;
|
||||
val timeStart = System.nanoTime();
|
||||
for (int e = 0; e < iterations; e++) {
|
||||
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) {
|
||||
|
|
|
@ -57,13 +57,13 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
|
|||
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
super.before();
|
||||
super.beforeTest();
|
||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||
}
|
||||
|
||||
@After
|
||||
public void after() throws Exception {
|
||||
super.after();
|
||||
super.afterTest();
|
||||
DataTypeUtil.setDTypeForContext(initialType);
|
||||
}
|
||||
|
||||
|
|
|
@ -37,8 +37,7 @@ import org.slf4j.LoggerFactory;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Tests comparing Nd4j ops to other libraries
|
||||
|
@ -59,7 +58,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
super.before();
|
||||
super.beforeTest();
|
||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||
Nd4j.getRandom().setSeed(SEED);
|
||||
|
||||
|
@ -67,7 +66,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
|
||||
@After
|
||||
public void after() throws Exception {
|
||||
super.after();
|
||||
super.afterTest();
|
||||
DataTypeUtil.setDTypeForContext(initialType);
|
||||
}
|
||||
|
||||
|
@ -197,7 +196,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
INDArray gemv = m.mmul(v);
|
||||
RealMatrix gemv2 = rm.multiply(rv);
|
||||
|
||||
assertArrayEquals(new int[] {rows, 1}, gemv.shape());
|
||||
assertArrayEquals(new long[] {rows, 1}, gemv.shape());
|
||||
assertArrayEquals(new int[] {rows, 1},
|
||||
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});
|
||||
|
||||
|
|
|
@ -25,6 +25,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
|||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
/**
|
||||
* Created by Alex on 30/04/2016.
|
||||
*/
|
||||
|
|
|
@ -38,8 +38,7 @@ import org.nd4j.linalg.util.SerializationUtils;
|
|||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Double data buffer tests
|
||||
|
|
|
@ -37,8 +37,7 @@ import org.nd4j.linalg.util.SerializationUtils;
|
|||
import java.io.*;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Float data buffer tests
|
||||
|
|
|
@ -31,8 +31,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
|||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Tests for INT INDArrays and DataBuffers serialization
|
||||
|
|
|
@ -33,8 +33,7 @@ import org.nd4j.linalg.util.ArrayUtil;
|
|||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
||||
|
||||
/**
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue