Unit/integration test split + test speedup (#166)
* Add maven profile + base tests methods for integration tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Switch from system property to environment variable; seems more reliable in intellij Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add nd4j-common-tests module, and common base test; cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Ensure all ND4J tests extend BaseND4JTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test spam reduction, import fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add test logging to nd4j-aeron Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix unintended change Signed-off-by: AlexDBlack <blacka101@gmail.com> * Reduce sprint test log spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Significantly speed up TSNE tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * W2V iterator test unit/integration split Signed-off-by: AlexDBlack <blacka101@gmail.com> * More NLP test speedups Signed-off-by: AlexDBlack <blacka101@gmail.com> * Avoid debug/verbose mode leaking between tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * test tweak Signed-off-by: AlexDBlack <blacka101@gmail.com> * Arbiter extends base DL4J test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Arbiter test speedup Signed-off-by: AlexDBlack <blacka101@gmail.com> * nlp-uima test speedup Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test speedups Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix ND4J base test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Few small ND4J test speed improvements Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J tests speedup Signed-off-by: AlexDBlack <blacka101@gmail.com> * More tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * Even more test speedups Signed-off-by: AlexDBlack <blacka101@gmail.com> * More tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * Various test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * More test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Add ability to specify number of threads for C++ ops in BaseDL4JTest and BaseND4JTest Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-aeron test profile fix for CUDA Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
2717b25931
commit
a25bb6a11c
|
@ -84,6 +84,13 @@
|
||||||
<artifactId>jackson</artifactId>
|
<artifactId>jackson</artifactId>
|
||||||
<version>${nd4j.version}</version>
|
<version>${nd4j.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.deeplearning4j</groupId>
|
||||||
|
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
package org.deeplearning4j.arbiter.optimize;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||||
|
@ -32,7 +33,7 @@ import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusLis
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
public class TestGeneticSearch {
|
public class TestGeneticSearch extends BaseDL4JTest {
|
||||||
public class TestSelectionOperator extends SelectionOperator {
|
public class TestSelectionOperator extends SelectionOperator {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
package org.deeplearning4j.arbiter.optimize;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
||||||
|
@ -26,7 +27,7 @@ import java.util.Map;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class TestGridSearch {
|
public class TestGridSearch extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIndexing() {
|
public void testIndexing() {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize;
|
||||||
import org.apache.commons.math3.distribution.LogNormalDistribution;
|
import org.apache.commons.math3.distribution.LogNormalDistribution;
|
||||||
import org.apache.commons.math3.distribution.NormalDistribution;
|
import org.apache.commons.math3.distribution.NormalDistribution;
|
||||||
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||||
|
@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 02/02/2017.
|
* Created by Alex on 02/02/2017.
|
||||||
*/
|
*/
|
||||||
public class TestJson {
|
public class TestJson extends BaseDL4JTest {
|
||||||
|
|
||||||
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
|
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
|
||||||
ObjectMapper om = new ObjectMapper(factory);
|
ObjectMapper om = new ObjectMapper(factory);
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
package org.deeplearning4j.arbiter.optimize;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||||
|
@ -34,7 +35,7 @@ import java.util.Map;
|
||||||
* Test random search on the Branin Function:
|
* Test random search on the Branin Function:
|
||||||
* http://www.sfu.ca/~ssurjano/branin.html
|
* http://www.sfu.ca/~ssurjano/branin.html
|
||||||
*/
|
*/
|
||||||
public class TestRandomSearch {
|
public class TestRandomSearch extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test() throws Exception {
|
public void test() throws Exception {
|
||||||
|
|
|
@ -17,12 +17,13 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.RealDistribution;
|
import org.apache.commons.math3.distribution.RealDistribution;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class TestLogUniform {
|
public class TestLogUniform extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSimple(){
|
public void testSimple(){
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||||
|
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
public class ArithmeticCrossoverTests {
|
public class ArithmeticCrossoverTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() {
|
public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator;
|
||||||
|
@ -23,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
public class CrossoverOperatorTests {
|
public class CrossoverOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException {
|
public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
@ -24,7 +25,7 @@ import org.junit.Test;
|
||||||
|
|
||||||
import java.util.Deque;
|
import java.util.Deque;
|
||||||
|
|
||||||
public class CrossoverPointsGeneratorTests {
|
public class CrossoverPointsGeneratorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void CrossoverPointsGenerator_FixedNumberCrossovers() {
|
public void CrossoverPointsGenerator_FixedNumberCrossovers() {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||||
|
@ -25,7 +26,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
public class KPointCrossoverTests {
|
public class KPointCrossoverTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
@ -24,7 +25,7 @@ import org.junit.Test;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class ParentSelectionTests {
|
public class ParentSelectionTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void ParentSelection_InitializeInstance_ShouldInitPopulation() {
|
public void ParentSelection_InitializeInstance_ShouldInitPopulation() {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||||
|
@ -26,7 +27,7 @@ import org.junit.Test;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class RandomTwoParentSelectionTests {
|
public class RandomTwoParentSelectionTests extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() {
|
public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() {
|
||||||
RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null);
|
RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null);
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||||
|
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
public class SinglePointCrossoverTests {
|
public class SinglePointCrossoverTests extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
||||||
RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0});
|
RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0});
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||||
|
@ -27,7 +28,7 @@ import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
||||||
|
|
||||||
public class TwoParentsCrossoverOperatorTests {
|
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator {
|
class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator {
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||||
|
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
public class UniformCrossoverTests {
|
public class UniformCrossoverTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||||
|
@ -27,7 +28,7 @@ import org.junit.Test;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class LeastFitCullOperatorTests {
|
public class LeastFitCullOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void LeastFitCullingOperation_ShouldCullLastElements() {
|
public void LeastFitCullingOperation_ShouldCullLastElements() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||||
|
@ -27,7 +28,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class RatioCullOperatorTests {
|
public class RatioCullOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
class TestRatioCullOperator extends RatioCullOperator {
|
class TestRatioCullOperator extends RatioCullOperator {
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.mutation;
|
package org.deeplearning4j.arbiter.optimize.genetic.mutation;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
@ -24,7 +25,7 @@ import org.junit.Test;
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
public class RandomMutationOperatorTests {
|
public class RandomMutationOperatorTests extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() {
|
public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() {
|
||||||
RandomMutationOperator sut = new RandomMutationOperator.Builder().build();
|
RandomMutationOperator sut = new RandomMutationOperator.Builder().build();
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.population;
|
package org.deeplearning4j.arbiter.optimize.genetic.population;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||||
|
@ -27,7 +28,7 @@ import org.junit.Test;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class PopulationModelTests {
|
public class PopulationModelTests extends BaseDL4JTest {
|
||||||
|
|
||||||
private class TestCullOperator implements CullOperator {
|
private class TestCullOperator implements CullOperator {
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
||||||
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||||
|
@ -36,7 +37,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
||||||
|
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
|
||||||
public class GeneticSelectionOperatorTests {
|
public class GeneticSelectionOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
private class TestCullOperator implements CullOperator {
|
private class TestCullOperator implements CullOperator {
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||||
|
@ -25,7 +26,7 @@ import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
||||||
|
|
||||||
public class SelectionOperatorTests {
|
public class SelectionOperatorTests extends BaseDL4JTest {
|
||||||
private class TestSelectionOperator extends SelectionOperator {
|
private class TestSelectionOperator extends SelectionOperator {
|
||||||
|
|
||||||
public PopulationModel getPopulationModel() {
|
public PopulationModel getPopulationModel() {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter;
|
package org.deeplearning4j.arbiter.optimize.parameter;
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.NormalDistribution;
|
import org.apache.commons.math3.distribution.NormalDistribution;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
||||||
|
@ -25,7 +26,7 @@ import org.junit.Test;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
public class TestParameterSpaces {
|
public class TestParameterSpaces extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -63,6 +63,13 @@
|
||||||
<artifactId>gson</artifactId>
|
<artifactId>gson</artifactId>
|
||||||
<version>${gson.version}</version>
|
<version>${gson.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.deeplearning4j</groupId>
|
||||||
|
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.computationgraph;
|
package org.deeplearning4j.arbiter.computationgraph;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||||
import org.deeplearning4j.arbiter.TestUtils;
|
import org.deeplearning4j.arbiter.TestUtils;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||||
|
@ -44,7 +45,7 @@ import java.util.Random;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class TestComputationGraphSpace {
|
public class TestComputationGraphSpace extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasic() {
|
public void testBasic() {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.computationgraph;
|
package org.deeplearning4j.arbiter.computationgraph;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||||
|
@ -85,7 +86,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestGraphLocalExecution {
|
public class TestGraphLocalExecution extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
@ -126,7 +127,7 @@ public class TestGraphLocalExecution {
|
||||||
if(dataApproach == 0){
|
if(dataApproach == 0){
|
||||||
ds = TestDL4JLocalExecution.MnistDataSource.class;
|
ds = TestDL4JLocalExecution.MnistDataSource.class;
|
||||||
dsP = new Properties();
|
dsP = new Properties();
|
||||||
dsP.setProperty("minibatch", "8");
|
dsP.setProperty("minibatch", "2");
|
||||||
candidateGenerator = new RandomSearchGenerator(mls);
|
candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
} else if(dataApproach == 1) {
|
} else if(dataApproach == 1) {
|
||||||
//DataProvider approach
|
//DataProvider approach
|
||||||
|
@ -150,8 +151,8 @@ public class TestGraphLocalExecution {
|
||||||
.dataSource(ds, dsP)
|
.dataSource(ds, dsP)
|
||||||
.modelSaver(new FileModelSaver(modelSave))
|
.modelSaver(new FileModelSaver(modelSave))
|
||||||
.scoreFunction(new TestSetLossScoreFunction())
|
.scoreFunction(new TestSetLossScoreFunction())
|
||||||
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
|
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(5))
|
new MaxCandidatesCondition(3))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator()));
|
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator()));
|
||||||
|
@ -159,7 +160,7 @@ public class TestGraphLocalExecution {
|
||||||
runner.execute();
|
runner.execute();
|
||||||
|
|
||||||
List<ResultReference> results = runner.getResults();
|
List<ResultReference> results = runner.getResults();
|
||||||
assertEquals(5, results.size());
|
assertTrue(results.size() > 0);
|
||||||
|
|
||||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||||
}
|
}
|
||||||
|
@ -203,8 +204,8 @@ public class TestGraphLocalExecution {
|
||||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
||||||
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
||||||
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(10))
|
new MaxCandidatesCondition(3))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,
|
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,
|
||||||
|
@ -223,7 +224,7 @@ public class TestGraphLocalExecution {
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1)))
|
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1)))
|
||||||
.l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in")
|
.l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in")
|
||||||
.setInputTypes(InputType.feedForward(4))
|
.setInputTypes(InputType.feedForward(784))
|
||||||
.addLayer("layer0",
|
.addLayer("layer0",
|
||||||
new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10))
|
new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10))
|
||||||
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
||||||
|
@ -250,8 +251,8 @@ public class TestGraphLocalExecution {
|
||||||
.candidateGenerator(candidateGenerator)
|
.candidateGenerator(candidateGenerator)
|
||||||
.dataProvider(new TestMdsDataProvider(1, 32))
|
.dataProvider(new TestMdsDataProvider(1, 32))
|
||||||
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
||||||
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(10))
|
new MaxCandidatesCondition(3))
|
||||||
.scoreFunction(ScoreFunctions.testSetAccuracy())
|
.scoreFunction(ScoreFunctions.testSetAccuracy())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -279,7 +280,7 @@ public class TestGraphLocalExecution {
|
||||||
@Override
|
@Override
|
||||||
public Object trainData(Map<String, Object> dataParameters) {
|
public Object trainData(Map<String, Object> dataParameters) {
|
||||||
try {
|
try {
|
||||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 10 * batchSize), false, true, true, 12345);
|
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 3 * batchSize), false, true, true, 12345);
|
||||||
return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying));
|
return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
|
@ -289,7 +290,7 @@ public class TestGraphLocalExecution {
|
||||||
@Override
|
@Override
|
||||||
public Object testData(Map<String, Object> dataParameters) {
|
public Object testData(Map<String, Object> dataParameters) {
|
||||||
try {
|
try {
|
||||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 5 * batchSize), false, false, false, 12345);
|
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 2 * batchSize), false, false, false, 12345);
|
||||||
return new MultiDataSetIteratorAdapter(underlying);
|
return new MultiDataSetIteratorAdapter(underlying);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
|
@ -305,7 +306,7 @@ public class TestGraphLocalExecution {
|
||||||
@Test
|
@Test
|
||||||
public void testLocalExecutionEarlyStopping() throws Exception {
|
public void testLocalExecutionEarlyStopping() throws Exception {
|
||||||
EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
|
EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
|
||||||
.epochTerminationConditions(new MaxEpochsTerminationCondition(4))
|
.epochTerminationConditions(new MaxEpochsTerminationCondition(2))
|
||||||
.scoreCalculator(new ScoreProvider())
|
.scoreCalculator(new ScoreProvider())
|
||||||
.modelSaver(new InMemoryModelSaver()).build();
|
.modelSaver(new InMemoryModelSaver()).build();
|
||||||
Map<String, Object> commands = new HashMap<>();
|
Map<String, Object> commands = new HashMap<>();
|
||||||
|
@ -348,8 +349,8 @@ public class TestGraphLocalExecution {
|
||||||
.dataProvider(dataProvider)
|
.dataProvider(dataProvider)
|
||||||
.scoreFunction(ScoreFunctions.testSetF1())
|
.scoreFunction(ScoreFunctions.testSetF1())
|
||||||
.modelSaver(new FileModelSaver(modelSavePath))
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
.terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS),
|
.terminationConditions(new MaxTimeCondition(15, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(10))
|
new MaxCandidatesCondition(3))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
||||||
|
@ -364,7 +365,7 @@ public class TestGraphLocalExecution {
|
||||||
@Override
|
@Override
|
||||||
public ScoreCalculator get() {
|
public ScoreCalculator get() {
|
||||||
try {
|
try {
|
||||||
return new DataSetLossCalculatorCG(new MnistDataSetIterator(128, 1280), true);
|
return new DataSetLossCalculatorCG(new MnistDataSetIterator(4, 8), true);
|
||||||
} catch (Exception e){
|
} catch (Exception e){
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.computationgraph;
|
package org.deeplearning4j.arbiter.computationgraph;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||||
|
@ -79,11 +80,16 @@ import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestGraphLocalExecutionGenetic {
|
public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 45000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLocalExecutionDataSources() throws Exception {
|
public void testLocalExecutionDataSources() throws Exception {
|
||||||
for (int dataApproach = 0; dataApproach < 3; dataApproach++) {
|
for (int dataApproach = 0; dataApproach < 3; dataApproach++) {
|
||||||
|
@ -115,7 +121,7 @@ public class TestGraphLocalExecutionGenetic {
|
||||||
if (dataApproach == 0) {
|
if (dataApproach == 0) {
|
||||||
ds = TestDL4JLocalExecution.MnistDataSource.class;
|
ds = TestDL4JLocalExecution.MnistDataSource.class;
|
||||||
dsP = new Properties();
|
dsP = new Properties();
|
||||||
dsP.setProperty("minibatch", "8");
|
dsP.setProperty("minibatch", "2");
|
||||||
|
|
||||||
candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction)
|
candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction)
|
||||||
.populationModel(new PopulationModel.Builder().populationSize(5).build())
|
.populationModel(new PopulationModel.Builder().populationSize(5).build())
|
||||||
|
@ -148,7 +154,7 @@ public class TestGraphLocalExecutionGenetic {
|
||||||
.dataSource(ds, dsP)
|
.dataSource(ds, dsP)
|
||||||
.modelSaver(new FileModelSaver(modelSave))
|
.modelSaver(new FileModelSaver(modelSave))
|
||||||
.scoreFunction(new TestSetLossScoreFunction())
|
.scoreFunction(new TestSetLossScoreFunction())
|
||||||
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
|
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(10))
|
new MaxCandidatesCondition(10))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -157,7 +163,7 @@ public class TestGraphLocalExecutionGenetic {
|
||||||
runner.execute();
|
runner.execute();
|
||||||
|
|
||||||
List<ResultReference> results = runner.getResults();
|
List<ResultReference> results = runner.getResults();
|
||||||
assertEquals(10, results.size());
|
assertTrue(results.size() > 0);
|
||||||
|
|
||||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.json;
|
package org.deeplearning4j.arbiter.json;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace;
|
import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace;
|
||||||
|
@ -71,7 +72,7 @@ import static org.junit.Assert.assertNotNull;
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 14/02/2017.
|
* Created by Alex on 14/02/2017.
|
||||||
*/
|
*/
|
||||||
public class TestJson {
|
public class TestJson extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultiLayerSpaceJson() {
|
public void testMultiLayerSpaceJson() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||||
import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace;
|
import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace;
|
||||||
|
@ -59,7 +60,7 @@ import java.util.concurrent.TimeUnit;
|
||||||
// import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener;
|
// import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener;
|
||||||
|
|
||||||
/** Not strictly a unit test. Rather: part example, part debugging on MNIST */
|
/** Not strictly a unit test. Rather: part example, part debugging on MNIST */
|
||||||
public class MNISTOptimizationTest {
|
public class MNISTOptimizationTest extends BaseDL4JTest {
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
|
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||||
import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator;
|
import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator;
|
||||||
|
@ -72,9 +73,10 @@ import java.util.Properties;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestDL4JLocalExecution {
|
public class TestDL4JLocalExecution extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
@ -112,7 +114,7 @@ public class TestDL4JLocalExecution {
|
||||||
if(dataApproach == 0){
|
if(dataApproach == 0){
|
||||||
ds = MnistDataSource.class;
|
ds = MnistDataSource.class;
|
||||||
dsP = new Properties();
|
dsP = new Properties();
|
||||||
dsP.setProperty("minibatch", "8");
|
dsP.setProperty("minibatch", "2");
|
||||||
candidateGenerator = new RandomSearchGenerator(mls);
|
candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
} else if(dataApproach == 1) {
|
} else if(dataApproach == 1) {
|
||||||
//DataProvider approach
|
//DataProvider approach
|
||||||
|
@ -136,7 +138,7 @@ public class TestDL4JLocalExecution {
|
||||||
.dataSource(ds, dsP)
|
.dataSource(ds, dsP)
|
||||||
.modelSaver(new FileModelSaver(modelSave))
|
.modelSaver(new FileModelSaver(modelSave))
|
||||||
.scoreFunction(new TestSetLossScoreFunction())
|
.scoreFunction(new TestSetLossScoreFunction())
|
||||||
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
|
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||||
new MaxCandidatesCondition(5))
|
new MaxCandidatesCondition(5))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -146,7 +148,7 @@ public class TestDL4JLocalExecution {
|
||||||
runner.execute();
|
runner.execute();
|
||||||
|
|
||||||
List<ResultReference> results = runner.getResults();
|
List<ResultReference> results = runner.getResults();
|
||||||
assertEquals(5, results.size());
|
assertTrue(results.size() > 0);
|
||||||
|
|
||||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||||
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
||||||
|
@ -39,7 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
|
||||||
public class TestErrors {
|
public class TestErrors extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder temp = new TemporaryFolder();
|
public TemporaryFolder temp = new TemporaryFolder();
|
||||||
|
@ -60,7 +61,7 @@ public class TestErrors {
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
|
|
||||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
.terminationConditions(
|
.terminationConditions(
|
||||||
new MaxCandidatesCondition(5))
|
new MaxCandidatesCondition(5))
|
||||||
|
@ -87,7 +88,7 @@ public class TestErrors {
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
|
|
||||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
.terminationConditions(
|
.terminationConditions(
|
||||||
new MaxCandidatesCondition(5))
|
new MaxCandidatesCondition(5))
|
||||||
|
@ -116,7 +117,7 @@ public class TestErrors {
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
|
|
||||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
.terminationConditions(new MaxCandidatesCondition(5))
|
.terminationConditions(new MaxCandidatesCondition(5))
|
||||||
.build();
|
.build();
|
||||||
|
@ -143,7 +144,7 @@ public class TestErrors {
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
|
|
||||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
.terminationConditions(
|
.terminationConditions(
|
||||||
new MaxCandidatesCondition(5))
|
new MaxCandidatesCondition(5))
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||||
|
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.TestUtils;
|
import org.deeplearning4j.arbiter.TestUtils;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||||
import org.deeplearning4j.arbiter.layers.*;
|
import org.deeplearning4j.arbiter.layers.*;
|
||||||
|
@ -44,7 +45,7 @@ import java.util.Random;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class TestLayerSpace {
|
public class TestLayerSpace extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasic1() {
|
public void testBasic1() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||||
|
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.DL4JConfiguration;
|
import org.deeplearning4j.arbiter.DL4JConfiguration;
|
||||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||||
import org.deeplearning4j.arbiter.TestUtils;
|
import org.deeplearning4j.arbiter.TestUtils;
|
||||||
|
@ -86,7 +87,7 @@ import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class TestMultiLayerSpace {
|
public class TestMultiLayerSpace extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.multilayernetwork;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||||
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
|
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
|
||||||
|
@ -60,7 +61,13 @@ import java.util.Map;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestScoreFunctions {
|
public class TestScoreFunctions extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 60000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testROCScoreFunctions() throws Exception {
|
public void testROCScoreFunctions() throws Exception {
|
||||||
|
@ -107,7 +114,7 @@ public class TestScoreFunctions {
|
||||||
List<ResultReference> list = runner.getResults();
|
List<ResultReference> list = runner.getResults();
|
||||||
|
|
||||||
for (ResultReference rr : list) {
|
for (ResultReference rr : list) {
|
||||||
DataSetIterator testIter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
|
DataSetIterator testIter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||||
testIter.setPreProcessor(new PreProc(rocType));
|
testIter.setPreProcessor(new PreProc(rocType));
|
||||||
|
|
||||||
OptimizationResult or = rr.getResult();
|
OptimizationResult or = rr.getResult();
|
||||||
|
@ -141,10 +148,10 @@ public class TestScoreFunctions {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
|
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||||
iter.setPreProcessor(new PreProc(rocType));
|
iter.setPreProcessor(new PreProc(rocType));
|
||||||
|
|
||||||
assertEquals(msg, expScore, or.getScore(), 1e-5);
|
assertEquals(msg, expScore, or.getScore(), 1e-4);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -158,7 +165,7 @@ public class TestScoreFunctions {
|
||||||
@Override
|
@Override
|
||||||
public Object trainData(Map<String, Object> dataParameters) {
|
public Object trainData(Map<String, Object> dataParameters) {
|
||||||
try {
|
try {
|
||||||
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
|
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||||
iter.setPreProcessor(new PreProc(rocType));
|
iter.setPreProcessor(new PreProc(rocType));
|
||||||
return iter;
|
return iter;
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
|
@ -169,7 +176,7 @@ public class TestScoreFunctions {
|
||||||
@Override
|
@Override
|
||||||
public Object testData(Map<String, Object> dataParameters) {
|
public Object testData(Map<String, Object> dataParameters) {
|
||||||
try {
|
try {
|
||||||
DataSetIterator iter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
|
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||||
iter.setPreProcessor(new PreProc(rocType));
|
iter.setPreProcessor(new PreProc(rocType));
|
||||||
return iter;
|
return iter;
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
|
|
|
@ -49,6 +49,13 @@
|
||||||
<version>${junit.version}</version>
|
<version>${junit.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.deeplearning4j</groupId>
|
||||||
|
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.server;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||||
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
||||||
|
@ -52,7 +53,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
* Created by agibsonccc on 3/12/17.
|
* Created by agibsonccc on 3/12/17.
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ArbiterCLIRunnerTest {
|
public class ArbiterCLIRunnerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.junit.Before;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.rules.TestName;
|
import org.junit.rules.TestName;
|
||||||
import org.junit.rules.Timeout;
|
import org.junit.rules.Timeout;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.config.ND4JSystemProperties;
|
import org.nd4j.config.ND4JSystemProperties;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
|
@ -36,6 +37,8 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Properties;
|
import java.util.Properties;
|
||||||
|
|
||||||
|
import static org.junit.Assume.assumeTrue;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseDL4JTest {
|
public abstract class BaseDL4JTest {
|
||||||
|
|
||||||
|
@ -47,6 +50,17 @@ public abstract class BaseDL4JTest {
|
||||||
protected long startTime;
|
protected long startTime;
|
||||||
protected int threadCountBefore;
|
protected int threadCountBefore;
|
||||||
|
|
||||||
|
private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Override this to specify the number of threads for C++ execution, via
|
||||||
|
* {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
|
||||||
|
* @return Number of threads to use for C++ op execution
|
||||||
|
*/
|
||||||
|
public int numThreads(){
|
||||||
|
return DEFAULT_THREADS;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Override this method to set the default timeout for methods in the test class
|
* Override this method to set the default timeout for methods in the test class
|
||||||
*/
|
*/
|
||||||
|
@ -72,6 +86,28 @@ public abstract class BaseDL4JTest {
|
||||||
return getDataType();
|
return getDataType();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected Boolean integrationTest;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return True if integration tests maven profile is enabled, false otherwise.
|
||||||
|
*/
|
||||||
|
public boolean isIntegrationTests(){
|
||||||
|
if(integrationTest == null){
|
||||||
|
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
|
||||||
|
integrationTest = Boolean.parseBoolean(prop);
|
||||||
|
}
|
||||||
|
return integrationTest;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled.
|
||||||
|
* This can be used to dynamically skip integration tests when the integration test profile is not enabled.
|
||||||
|
* Note that the integration test profile is not enabled by default - "integration-tests" profile
|
||||||
|
*/
|
||||||
|
public void skipUnlessIntegrationTests(){
|
||||||
|
assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
|
||||||
|
}
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void beforeTest(){
|
public void beforeTest(){
|
||||||
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
|
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
|
||||||
|
@ -81,6 +117,14 @@ public abstract class BaseDL4JTest {
|
||||||
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
|
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
|
||||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
||||||
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
|
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
|
||||||
|
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
||||||
|
Nd4j.getExecutioner().enableDebugMode(false);
|
||||||
|
Nd4j.getExecutioner().enableVerboseMode(false);
|
||||||
|
int numThreads = numThreads();
|
||||||
|
Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
|
||||||
|
if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
|
||||||
|
Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
|
||||||
|
}
|
||||||
startTime = System.currentTimeMillis();
|
startTime = System.currentTimeMillis();
|
||||||
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
|
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,7 +158,7 @@ public class LayerHelperValidationUtil {
|
||||||
double d2 = arr2.dup('c').getDouble(idx);
|
double d2 = arr2.dup('c').getDouble(idx);
|
||||||
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
|
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
|
||||||
}
|
}
|
||||||
assertTrue(s + layerName + "activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
|
assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
|
||||||
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
|
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -78,8 +78,16 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void hasNextWithResetAndLoad() throws Exception {
|
public void hasNextWithResetAndLoad() throws Exception {
|
||||||
|
int[] prefetchSizes;
|
||||||
|
if(isIntegrationTests()){
|
||||||
|
prefetchSizes = new int[]{2, 3, 4, 5, 6, 7, 8};
|
||||||
|
} else {
|
||||||
|
prefetchSizes = new int[]{2, 3, 8};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
for (int iter = 0; iter < ITERATIONS; iter++) {
|
for (int iter = 0; iter < ITERATIONS; iter++) {
|
||||||
for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) {
|
for(int prefetchSize : prefetchSizes){
|
||||||
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
|
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
|
||||||
TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL);
|
TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL);
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
|
@ -161,8 +169,14 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableTimeSeries1() throws Exception {
|
public void testVariableTimeSeries1() throws Exception {
|
||||||
|
int numBatches = isIntegrationTests() ? 1000 : 100;
|
||||||
|
int batchSize = isIntegrationTests() ? 32 : 8;
|
||||||
|
int timeStepsMin = 10;
|
||||||
|
int timeStepsMax = isIntegrationTests() ? 500 : 100;
|
||||||
|
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
|
||||||
|
|
||||||
AsyncDataSetIterator adsi = new AsyncDataSetIterator(
|
AsyncDataSetIterator adsi = new AsyncDataSetIterator(
|
||||||
new VariableTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10), 2, true);
|
new VariableTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10), 2, true);
|
||||||
|
|
||||||
for (int e = 0; e < 10; e++) {
|
for (int e = 0; e < 10; e++) {
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
|
|
|
@ -18,21 +18,10 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
|
||||||
import org.datavec.api.records.reader.SequenceRecordReader;
|
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
|
||||||
import org.datavec.api.split.NumberedFileInputSplit;
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
|
|
||||||
import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.VariableMultiTimeseriesGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
|
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@ -49,7 +38,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testVariableTimeSeries1() throws Exception {
|
public void testVariableTimeSeries1() throws Exception {
|
||||||
val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10);
|
int numBatches = isIntegrationTests() ? 1000 : 100;
|
||||||
|
int batchSize = isIntegrationTests() ? 32 : 8;
|
||||||
|
int timeStepsMin = 10;
|
||||||
|
int timeStepsMax = isIntegrationTests() ? 500 : 100;
|
||||||
|
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
|
||||||
|
|
||||||
|
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
|
||||||
iterator.reset();
|
iterator.reset();
|
||||||
iterator.hasNext();
|
iterator.hasNext();
|
||||||
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
|
val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true);
|
||||||
|
@ -81,7 +76,13 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableTimeSeries2() throws Exception {
|
public void testVariableTimeSeries2() throws Exception {
|
||||||
val iterator = new VariableMultiTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10);
|
int numBatches = isIntegrationTests() ? 1000 : 100;
|
||||||
|
int batchSize = isIntegrationTests() ? 32 : 8;
|
||||||
|
int timeStepsMin = 10;
|
||||||
|
int timeStepsMax = isIntegrationTests() ? 500 : 100;
|
||||||
|
int valuesPerTimestep = isIntegrationTests() ? 128 : 16;
|
||||||
|
|
||||||
|
val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10);
|
||||||
|
|
||||||
for (int e = 0; e < 10; e++) {
|
for (int e = 0; e < 10; e++) {
|
||||||
iterator.reset();
|
iterator.reset();
|
||||||
|
|
|
@ -46,17 +46,17 @@ public class TestEmnistDataSetIterator extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testEmnistDataSetIterator() throws Exception {
|
public void testEmnistDataSetIterator() throws Exception {
|
||||||
|
|
||||||
// EmnistFetcher fetcher = new EmnistFetcher(EmnistDataSetIterator.Set.COMPLETE);
|
|
||||||
// File baseEmnistDir = fetcher.getFILE_DIR();
|
|
||||||
// if(baseEmnistDir.exists()){
|
|
||||||
// FileUtils.deleteDirectory(baseEmnistDir);
|
|
||||||
// }
|
|
||||||
// assertFalse(baseEmnistDir.exists());
|
|
||||||
|
|
||||||
|
|
||||||
int batchSize = 128;
|
int batchSize = 128;
|
||||||
|
|
||||||
for (EmnistDataSetIterator.Set s : EmnistDataSetIterator.Set.values()) {
|
EmnistDataSetIterator.Set[] sets;
|
||||||
|
if(isIntegrationTests()){
|
||||||
|
sets = EmnistDataSetIterator.Set.values();
|
||||||
|
} else {
|
||||||
|
sets = new EmnistDataSetIterator.Set[]{EmnistDataSetIterator.Set.MNIST, EmnistDataSetIterator.Set.LETTERS};
|
||||||
|
}
|
||||||
|
|
||||||
|
for (EmnistDataSetIterator.Set s : sets) {
|
||||||
boolean isBalanced = EmnistDataSetIterator.isBalanced(s);
|
boolean isBalanced = EmnistDataSetIterator.isBalanced(s);
|
||||||
int numLabels = EmnistDataSetIterator.numLabels(s);
|
int numLabels = EmnistDataSetIterator.numLabels(s);
|
||||||
INDArray labelCounts = null;
|
INDArray labelCounts = null;
|
||||||
|
|
|
@ -476,7 +476,7 @@ public class EvalTest extends BaseDL4JTest {
|
||||||
|
|
||||||
net.setListeners(new EvaluativeListener(iterTest, 3));
|
net.setListeners(new EvaluativeListener(iterTest, 3));
|
||||||
|
|
||||||
for( int i=0; i<10; i++ ){
|
for( int i=0; i<3; i++ ){
|
||||||
net.fit(iter);
|
net.fit(iter);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -339,9 +339,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
log.info(msg);
|
log.info(msg);
|
||||||
for (int j = 0; j < net.getnLayers(); j++) {
|
|
||||||
log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS,
|
||||||
|
@ -623,13 +620,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
log.info(msg);
|
log.info(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++) {
|
|
||||||
// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
|
||||||
.labels(labels).subset(true).maxPerParam(128));
|
.labels(labels).subset(true).maxPerParam(64));
|
||||||
|
|
||||||
assertTrue(msg, gradOK);
|
assertTrue(msg, gradOK);
|
||||||
|
|
||||||
|
|
|
@ -557,60 +557,52 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
int[] minibatchSizes = {2};
|
int[] minibatchSizes = {2};
|
||||||
int width = 5;
|
int width = 5;
|
||||||
int height = 5;
|
int height = 5;
|
||||||
int[] inputDepths = {1, 2, 4};
|
|
||||||
|
|
||||||
Activation[] activations = {Activation.SIGMOID, Activation.TANH};
|
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
for (int inputDepth : inputDepths) {
|
int[] inputDepths = new int[]{1, 2, 4};
|
||||||
for (Activation afn : activations) {
|
Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS};
|
||||||
for (int minibatchSize : minibatchSizes) {
|
int[] minibatch = {2, 1, 3};
|
||||||
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
|
||||||
for (int i = 0; i < minibatchSize; i++) {
|
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
|
for( int i=0; i<inputDepths.length; i++ ){
|
||||||
.dataType(DataType.DOUBLE)
|
int inputDepth = inputDepths[i];
|
||||||
.activation(afn)
|
Activation afn = activations[i];
|
||||||
.list()
|
int minibatchSize = minibatch[i];
|
||||||
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1)
|
INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth);
|
||||||
.padding(0, 0).nIn(inputDepth).nOut(2).build())//output: (5-2+0)/1+1 = 4
|
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
||||||
.layer(1, new LocallyConnected2D.Builder().nIn(2).nOut(7).kernelSize(2, 2)
|
|
||||||
.setInputSize(4, 4).convolutionMode(ConvolutionMode.Strict).hasBias(false)
|
|
||||||
.stride(1, 1).padding(0, 0).build()) //(4-2+0)/1+1 = 3
|
|
||||||
.layer(2, new ConvolutionLayer.Builder().nIn(7).nOut(2).kernelSize(2, 2)
|
|
||||||
.stride(1, 1).padding(0, 0).build()) //(3-2+0)/1+1 = 2
|
|
||||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
|
||||||
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
|
|
||||||
.build())
|
|
||||||
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
|
|
||||||
|
|
||||||
assertEquals(ConvolutionMode.Truncate,
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp())
|
||||||
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
|
.dataType(DataType.DOUBLE)
|
||||||
|
.activation(afn)
|
||||||
|
.list()
|
||||||
|
.layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1)
|
||||||
|
.padding(0, 0).nIn(inputDepth).nOut(2).build())//output: (5-2+0)/1+1 = 4
|
||||||
|
.layer(1, new LocallyConnected2D.Builder().nIn(2).nOut(7).kernelSize(2, 2)
|
||||||
|
.setInputSize(4, 4).convolutionMode(ConvolutionMode.Strict).hasBias(false)
|
||||||
|
.stride(1, 1).padding(0, 0).build()) //(4-2+0)/1+1 = 3
|
||||||
|
.layer(2, new ConvolutionLayer.Builder().nIn(7).nOut(2).kernelSize(2, 2)
|
||||||
|
.stride(1, 1).padding(0, 0).build()) //(3-2+0)/1+1 = 2
|
||||||
|
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut)
|
||||||
|
.build())
|
||||||
|
.setInputType(InputType.convolutionalFlat(height, width, inputDepth)).build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
assertEquals(ConvolutionMode.Truncate,
|
||||||
net.init();
|
((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode());
|
||||||
|
|
||||||
for (int i = 0; i < 4; i++) {
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
net.init();
|
||||||
}
|
|
||||||
|
|
||||||
String msg = "Minibatch=" + minibatchSize + ", activationFn="
|
String msg = "Minibatch=" + minibatchSize + ", activationFn="
|
||||||
+ afn;
|
+ afn;
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||||
|
|
||||||
assertTrue(msg, gradOK);
|
assertTrue(msg, gradOK);
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1106,71 +1098,64 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
public void testCropping2DLayer() {
|
public void testCropping2DLayer() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int nOut = 2;
|
int nOut = 2;
|
||||||
|
|
||||||
int[] minibatchSizes = {1, 3};
|
|
||||||
int width = 12;
|
int width = 12;
|
||||||
int height = 11;
|
int height = 11;
|
||||||
int[] inputDepths = {1, 3};
|
|
||||||
|
|
||||||
int[] kernel = {2, 2};
|
int[] kernel = {2, 2};
|
||||||
int[] stride = {1, 1};
|
int[] stride = {1, 1};
|
||||||
int[] padding = {0, 0};
|
int[] padding = {0, 0};
|
||||||
|
|
||||||
int[][] cropTestCases = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}, {1, 2, 3, 4}};
|
int[][] cropTestCases = new int[][]{{0, 0, 0, 0}, {1, 1, 0, 0}, {2, 2, 2, 2}, {1, 2, 3, 4}};
|
||||||
|
int[] inputDepths = {1, 2, 3, 2};
|
||||||
|
int[] minibatchSizes = {2, 1, 3, 2};
|
||||||
|
|
||||||
for (int inputDepth : inputDepths) {
|
for (int i = 0; i < cropTestCases.length; i++) {
|
||||||
for (int minibatchSize : minibatchSizes) {
|
int inputDepth = inputDepths[i];
|
||||||
INDArray input = Nd4j.rand(new int[]{minibatchSize, inputDepth, height, width});
|
int minibatchSize = minibatchSizes[i];
|
||||||
INDArray labels = Nd4j.zeros(minibatchSize, nOut);
|
int[] crop = cropTestCases[i];
|
||||||
for (int i = 0; i < minibatchSize; i++) {
|
INDArray input = Nd4j.rand(new int[]{minibatchSize, inputDepth, height, width});
|
||||||
labels.putScalar(new int[]{i, i % nOut}, 1.0);
|
INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut);
|
||||||
}
|
|
||||||
for (int[] crop : cropTestCases) {
|
|
||||||
|
|
||||||
MultiLayerConfiguration conf =
|
MultiLayerConfiguration conf =
|
||||||
new NeuralNetConfiguration.Builder()
|
new NeuralNetConfiguration.Builder()
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.weightInit(new NormalDistribution(0, 1)).list()
|
.weightInit(new NormalDistribution(0, 1)).list()
|
||||||
.layer(new ConvolutionLayer.Builder(kernel, stride, padding)
|
.layer(new ConvolutionLayer.Builder(kernel, stride, padding)
|
||||||
.nIn(inputDepth).nOut(2).build())//output: (6-2+0)/1+1 = 5
|
.nIn(inputDepth).nOut(2).build())//output: (6-2+0)/1+1 = 5
|
||||||
.layer(new Cropping2D(crop))
|
.layer(new Cropping2D(crop))
|
||||||
.layer(new ConvolutionLayer.Builder(kernel, stride, padding).nIn(2).nOut(2).build())
|
.layer(new ConvolutionLayer.Builder(kernel, stride, padding).nIn(2).nOut(2).build())
|
||||||
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build())
|
.layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG).kernelSize(3, 3).stride(3, 3).build())
|
||||||
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
.activation(Activation.SOFTMAX).nOut(nOut).build())
|
||||||
.setInputType(InputType.convolutional(height, width, inputDepth))
|
.setInputType(InputType.convolutional(height, width, inputDepth))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
//Check cropping activation shape
|
//Check cropping activation shape
|
||||||
org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl =
|
org.deeplearning4j.nn.layers.convolution.Cropping2DLayer cl =
|
||||||
(org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1);
|
(org.deeplearning4j.nn.layers.convolution.Cropping2DLayer) net.getLayer(1);
|
||||||
val expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
|
val expShape = new long[]{minibatchSize, inputDepth, height - crop[0] - crop[1],
|
||||||
width - crop[2] - crop[3]};
|
width - crop[2] - crop[3]};
|
||||||
INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
|
INDArray out = cl.activate(input, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertArrayEquals(expShape, out.shape());
|
assertArrayEquals(expShape, out.shape());
|
||||||
|
|
||||||
String msg = "minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
|
String msg = "minibatch=" + minibatchSize + ", channels=" + inputDepth + ", zeroPad = "
|
||||||
+ Arrays.toString(crop);
|
+ Arrays.toString(crop);
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
// for (int j = 0; j < net.getnLayers(); j++)
|
|
||||||
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
|
||||||
}
|
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
|
|
||||||
.labels(labels).subset(true).maxPerParam(160));
|
|
||||||
|
|
||||||
assertTrue(msg, gradOK);
|
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input)
|
||||||
|
.labels(labels).subset(true).maxPerParam(160));
|
||||||
|
|
||||||
|
assertTrue(msg, gradOK);
|
||||||
|
|
||||||
|
TestUtils.testModelSerialization(net);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -142,9 +142,9 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
|
||||||
Nd4j.getRandom().setSeed(12345L);
|
Nd4j.getRandom().setSeed(12345L);
|
||||||
|
|
||||||
int timeSeriesLength = 5;
|
int timeSeriesLength = 5;
|
||||||
int nIn = 5;
|
int nIn = 3;
|
||||||
int layerSize = 3;
|
int layerSize = 3;
|
||||||
int nOut = 3;
|
int nOut = 2;
|
||||||
|
|
||||||
int miniBatchSize = 2;
|
int miniBatchSize = 2;
|
||||||
|
|
||||||
|
@ -170,24 +170,16 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
|
||||||
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
|
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
|
||||||
mln.init();
|
mln.init();
|
||||||
|
|
||||||
Random r = new Random(12345L);
|
|
||||||
INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}, 'f').subi(0.5);
|
INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}, 'f').subi(0.5);
|
||||||
|
|
||||||
INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
|
INDArray labels = TestUtils.randomOneHotTimeSeries(miniBatchSize, nOut, timeSeriesLength);
|
||||||
for (int i = 0; i < miniBatchSize; i++) {
|
|
||||||
for (int j = 0; j < nIn; j++) {
|
|
||||||
labels.putScalar(i, r.nextInt(nOut), j, 1.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println("testBidirectionalLSTMMasking() - testNum = " + testNum++);
|
System.out.println("testBidirectionalLSTMMasking() - testNum = " + testNum++);
|
||||||
// for (int j = 0; j < mln.getnLayers(); j++)
|
|
||||||
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
|
||||||
.labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(16));
|
.labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(12));
|
||||||
|
|
||||||
assertTrue(gradOK);
|
assertTrue(gradOK);
|
||||||
TestUtils.testModelSerialization(mln);
|
TestUtils.testModelSerialization(mln);
|
||||||
|
|
|
@ -123,8 +123,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation="
|
System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation="
|
||||||
+ outputActivation + ", doLearningFirst=" + doLearningFirst);
|
+ outputActivation + ", doLearningFirst=" + doLearningFirst);
|
||||||
for (int j = 0; j < mln.getnLayers(); j++)
|
|
||||||
System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -214,8 +212,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
System.out.println(testName + "- activationFn=" + afn + ", lossFn=" + lf
|
System.out.println(testName + "- activationFn=" + afn + ", lossFn=" + lf
|
||||||
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
|
||||||
+ doLearningFirst);
|
+ doLearningFirst);
|
||||||
for (int j = 0; j < mln.getnLayers(); j++)
|
// for (int j = 0; j < mln.getnLayers(); j++)
|
||||||
System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -277,8 +275,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||||
|
@ -340,8 +338,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
|
||||||
|
@ -397,8 +395,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -468,8 +466,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -602,9 +600,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
for (int i = 0; i < 4; i++) {
|
// for (int i = 0; i < 4; i++) {
|
||||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||||
}
|
// }
|
||||||
|
|
||||||
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
|
String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn="
|
||||||
+ afn;
|
+ afn;
|
||||||
|
@ -663,9 +661,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
for (int j = 0; j < net.getLayers().length; j++) {
|
// for (int j = 0; j < net.getLayers().length; j++) {
|
||||||
System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
|
// System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
|
||||||
}
|
// }
|
||||||
|
|
||||||
String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height
|
String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height
|
||||||
+ ", kernelSize=" + k;
|
+ ", kernelSize=" + k;
|
||||||
|
@ -726,9 +724,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
for (int i = 0; i < net.getLayers().length; i++) {
|
// for (int i = 0; i < net.getLayers().length; i++) {
|
||||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||||
}
|
// }
|
||||||
|
|
||||||
String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height
|
String msg = "Minibatch=" + minibatchSize + ", inDepth=" + inputDepth + ", height=" + height
|
||||||
+ ", kernelSize=" + k + ", stride = " + stride + ", convLayer first = "
|
+ ", kernelSize=" + k + ", stride = " + stride + ", convLayer first = "
|
||||||
|
@ -806,8 +804,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
for (int j = 0; j < net.getnLayers(); j++)
|
// for (int j = 0; j < net.getnLayers(); j++)
|
||||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
|
@ -872,9 +870,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
for (int j = 0; j < net.getLayers().length; j++) {
|
// for (int j = 0; j < net.getLayers().length; j++) {
|
||||||
System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
|
// System.out.println("nParams, layer " + j + ": " + net.getLayer(j).numParams());
|
||||||
}
|
// }
|
||||||
|
|
||||||
String msg = " - mb=" + minibatchSize + ", k="
|
String msg = " - mb=" + minibatchSize + ", k="
|
||||||
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
|
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
|
||||||
|
@ -943,9 +941,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
for (int i = 0; i < net.getLayers().length; i++) {
|
// for (int i = 0; i < net.getLayers().length; i++) {
|
||||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||||
}
|
// }
|
||||||
|
|
||||||
String msg = " - mb=" + minibatchSize + ", k="
|
String msg = " - mb=" + minibatchSize + ", k="
|
||||||
+ k + ", nIn=" + nIn + ", depthMul=" + depthMultiplier + ", s=" + s + ", cm=" + cm;
|
+ k + ", nIn=" + nIn + ", depthMul=" + depthMultiplier + ", s=" + s + ", cm=" + cm;
|
||||||
|
@ -1018,9 +1016,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
for (int i = 0; i < net.getLayers().length; i++) {
|
// for (int i = 0; i < net.getLayers().length; i++) {
|
||||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||||
}
|
// }
|
||||||
|
|
||||||
String msg = " - mb=" + minibatchSize + ", k="
|
String msg = " - mb=" + minibatchSize + ", k="
|
||||||
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
|
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
|
||||||
|
@ -1104,9 +1102,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
for (int i = 0; i < net.getLayers().length; i++) {
|
// for (int i = 0; i < net.getLayers().length; i++) {
|
||||||
System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
// System.out.println("nParams, layer " + i + ": " + net.getLayer(i).numParams());
|
||||||
}
|
// }
|
||||||
|
|
||||||
String msg = (subsampling ? "subsampling" : "conv") + " - mb=" + minibatchSize + ", k="
|
String msg = (subsampling ? "subsampling" : "conv") + " - mb=" + minibatchSize + ", k="
|
||||||
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
|
+ k + ", s=" + s + ", d=" + d + ", cm=" + cm;
|
||||||
|
@ -1179,8 +1177,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
if (PRINT_RESULTS) {
|
if (PRINT_RESULTS) {
|
||||||
System.out.println(msg);
|
System.out.println(msg);
|
||||||
for (int j = 0; j < net.getnLayers(); j++)
|
|
||||||
System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(
|
boolean gradOK = GradientCheckUtil.checkGradients(
|
||||||
|
|
|
@ -177,7 +177,7 @@ public class KDTreeTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testKNN() {
|
public void testKNN() {
|
||||||
int dimensions = 512;
|
int dimensions = 512;
|
||||||
int vectorsNo = 50000;
|
int vectorsNo = isIntegrationTests() ? 50000 : 1000;
|
||||||
// make a KD-tree of dimension {#dimensions}
|
// make a KD-tree of dimension {#dimensions}
|
||||||
Stopwatch stopwatch = Stopwatch.createStarted();
|
Stopwatch stopwatch = Stopwatch.createStarted();
|
||||||
KDTree kdTree = new KDTree(dimensions);
|
KDTree kdTree = new KDTree(dimensions);
|
||||||
|
|
|
@ -92,13 +92,13 @@ public class SPTreeTest extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
//@Ignore
|
//@Ignore
|
||||||
public void testLargeTree() {
|
public void testLargeTree() {
|
||||||
int num = 100000;
|
int num = isIntegrationTests() ? 100000 : 1000;
|
||||||
StopWatch watch = new StopWatch();
|
StopWatch watch = new StopWatch();
|
||||||
watch.start();
|
watch.start();
|
||||||
INDArray arr = Nd4j.linspace(1, num, num, Nd4j.dataType()).reshape(num, 1);
|
INDArray arr = Nd4j.linspace(1, num, num, Nd4j.dataType()).reshape(num, 1);
|
||||||
SpTree tree = new SpTree(arr);
|
SpTree tree = new SpTree(arr);
|
||||||
watch.stop();
|
watch.stop();
|
||||||
System.out.println("Tree created in " + watch);
|
System.out.println("Tree of size " + num + " created in " + watch);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,19 +45,19 @@ public class RandomizedInputTest extends RandomizedTest {
|
||||||
private Tokenizer tokenizer = new Tokenizer();
|
private Tokenizer tokenizer = new Tokenizer();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Repeat(iterations = 50)
|
@Repeat(iterations = 10)
|
||||||
public void testRandomizedUnicodeInput() {
|
public void testRandomizedUnicodeInput() {
|
||||||
assertCanTokenizeString(randomUnicodeOfLength(LENGTH), tokenizer);
|
assertCanTokenizeString(randomUnicodeOfLength(LENGTH), tokenizer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Repeat(iterations = 50)
|
@Repeat(iterations = 10)
|
||||||
public void testRandomizedRealisticUnicodeInput() {
|
public void testRandomizedRealisticUnicodeInput() {
|
||||||
assertCanTokenizeString(randomRealisticUnicodeOfLength(LENGTH), tokenizer);
|
assertCanTokenizeString(randomRealisticUnicodeOfLength(LENGTH), tokenizer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Repeat(iterations = 50)
|
@Repeat(iterations = 10)
|
||||||
public void testRandomizedAsciiInput() {
|
public void testRandomizedAsciiInput() {
|
||||||
assertCanTokenizeString(randomAsciiOfLength(LENGTH), tokenizer);
|
assertCanTokenizeString(randomAsciiOfLength(LENGTH), tokenizer);
|
||||||
}
|
}
|
||||||
|
|
|
@ -406,11 +406,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
double simD = arraysSimilarity(day1, day2);
|
double simD = arraysSimilarity(day1, day2);
|
||||||
double simN = arraysSimilarity(night1, night2);
|
double simN = arraysSimilarity(night1, night2);
|
||||||
|
|
||||||
logger.info("Vec1 day: " + day1);
|
// logger.info("Vec1 day: " + day1);
|
||||||
logger.info("Vec2 day: " + day2);
|
// logger.info("Vec2 day: " + day2);
|
||||||
|
|
||||||
logger.info("Vec1 night: " + night1);
|
// logger.info("Vec1 night: " + night1);
|
||||||
logger.info("Vec2 night: " + night2);
|
// logger.info("Vec2 night: " + night2);
|
||||||
|
|
||||||
logger.info("Day/day cross-model similarity: " + simD);
|
logger.info("Day/day cross-model similarity: " + simD);
|
||||||
logger.info("Night/night cross-model similarity: " + simN);
|
logger.info("Night/night cross-model similarity: " + simN);
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.models.word2vec;
|
package org.deeplearning4j.models.word2vec;
|
||||||
|
|
||||||
|
import org.apache.commons.io.IOUtils;
|
||||||
|
import org.apache.commons.io.LineIterator;
|
||||||
|
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.rules.Timeout;
|
import org.junit.rules.Timeout;
|
||||||
import org.nd4j.shade.guava.primitives.Doubles;
|
import org.nd4j.shade.guava.primitives.Doubles;
|
||||||
|
@ -51,8 +54,8 @@ import org.nd4j.resources.Resources;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.*;
|
||||||
import java.io.IOException;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
@ -185,7 +188,12 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWord2VecMultiEpoch() throws Exception {
|
public void testWord2VecMultiEpoch() throws Exception {
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter;
|
||||||
|
if(isIntegrationTests()){
|
||||||
|
iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
|
} else {
|
||||||
|
iter = new CollectionSentenceIterator(firstNLines(inputFile, 50000));
|
||||||
|
}
|
||||||
|
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
@ -389,7 +397,12 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testW2VnegativeOnRestore() throws Exception {
|
public void testW2VnegativeOnRestore() throws Exception {
|
||||||
// Strip white space before and after for each line
|
// Strip white space before and after for each line
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter;
|
||||||
|
if(isIntegrationTests()){
|
||||||
|
iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
|
} else {
|
||||||
|
iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
|
||||||
|
}
|
||||||
// Split on white spaces in the line to get words
|
// Split on white spaces in the line to get words
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
@ -491,7 +504,12 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void orderIsCorrect_WhenParallelized() throws Exception {
|
public void orderIsCorrect_WhenParallelized() throws Exception {
|
||||||
// Strip white space before and after for each line
|
// Strip white space before and after for each line
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter;
|
||||||
|
if(isIntegrationTests()){
|
||||||
|
iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
|
} else {
|
||||||
|
iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
|
||||||
|
}
|
||||||
// Split on white spaces in the line to get words
|
// Split on white spaces in the line to get words
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
@ -510,9 +528,10 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
System.out.println(vec.getVocab().numWords());
|
System.out.println(vec.getVocab().numWords());
|
||||||
|
|
||||||
val words = vec.getVocab().words();
|
val words = vec.getVocab().words();
|
||||||
for (val word : words) {
|
assertTrue(words.size() > 0);
|
||||||
System.out.println(word);
|
// for (val word : words) {
|
||||||
}
|
// System.out.println(word);
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -755,7 +774,16 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void weightsNotUpdated_WhenLocked() throws Exception {
|
public void weightsNotUpdated_WhenLocked() throws Exception {
|
||||||
|
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
boolean isIntegration = isIntegrationTests();
|
||||||
|
SentenceIterator iter;
|
||||||
|
SentenceIterator iter2;
|
||||||
|
if(isIntegration){
|
||||||
|
iter = new BasicLineIterator(inputFile);
|
||||||
|
iter2 = new BasicLineIterator(inputFile2.getAbsolutePath());
|
||||||
|
} else {
|
||||||
|
iter = new CollectionSentenceIterator(firstNLines(inputFile, 300));
|
||||||
|
iter2 = new CollectionSentenceIterator(firstNLines(inputFile2, 300));
|
||||||
|
}
|
||||||
|
|
||||||
Word2Vec vec1 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(64).layerSize(100)
|
Word2Vec vec1 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(64).layerSize(100)
|
||||||
.stopWords(new ArrayList<String>()).seed(42).learningRate(0.025).minLearningRate(0.001)
|
.stopWords(new ArrayList<String>()).seed(42).learningRate(0.025).minLearningRate(0.001)
|
||||||
|
@ -767,13 +795,12 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
vec1.fit();
|
vec1.fit();
|
||||||
|
|
||||||
iter = new BasicLineIterator(inputFile2.getAbsolutePath());
|
|
||||||
Word2Vec vec2 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(32).layerSize(100)
|
Word2Vec vec2 = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(32).layerSize(100)
|
||||||
.stopWords(new ArrayList<String>()).seed(32).learningRate(0.021).minLearningRate(0.001)
|
.stopWords(new ArrayList<String>()).seed(32).learningRate(0.021).minLearningRate(0.001)
|
||||||
.sampling(0).elementsLearningAlgorithm(new SkipGram<VocabWord>())
|
.sampling(0).elementsLearningAlgorithm(new SkipGram<VocabWord>())
|
||||||
.epochs(1).windowSize(5).allowParallelTokenization(true)
|
.epochs(1).windowSize(5).allowParallelTokenization(true)
|
||||||
.workers(1)
|
.workers(1)
|
||||||
.iterate(iter)
|
.iterate(iter2)
|
||||||
.intersectModel(vec1, true)
|
.intersectModel(vec1, true)
|
||||||
.modelUtils(new BasicModelUtils<VocabWord>()).build();
|
.modelUtils(new BasicModelUtils<VocabWord>()).build();
|
||||||
|
|
||||||
|
@ -861,6 +888,22 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
System.out.print("\n");
|
System.out.print("\n");
|
||||||
}
|
}
|
||||||
//
|
|
||||||
|
public static List<String> firstNLines(File f, int n){
|
||||||
|
List<String> lines = new ArrayList<>();
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
|
||||||
|
try{
|
||||||
|
for( int i=0; i<n && lineIter.hasNext(); i++ ){
|
||||||
|
lines.add(lineIter.next());
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
lineIter.close();
|
||||||
|
}
|
||||||
|
return lines;
|
||||||
|
} catch (IOException e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ public class TsneTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 120000L;
|
return 60000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
|
@ -58,103 +58,102 @@ public class TsneTest extends BaseDL4JTest {
|
||||||
public void testSimple() throws Exception {
|
public void testSimple() throws Exception {
|
||||||
//Simple sanity check
|
//Simple sanity check
|
||||||
|
|
||||||
for (boolean syntheticData : new boolean[]{false, true}) {
|
for( int test=0; test <=1; test++){
|
||||||
for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
|
boolean syntheticData = test == 1;
|
||||||
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED;
|
||||||
|
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
||||||
|
|
||||||
//STEP 1: Initialization
|
//STEP 1: Initialization
|
||||||
int iterations = 50;
|
int iterations = 50;
|
||||||
//create an n-dimensional array of doubles
|
//create an n-dimensional array of doubles
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
||||||
|
|
||||||
//STEP 2: Turn text input into a list of words
|
//STEP 2: Turn text input into a list of words
|
||||||
INDArray weights;
|
INDArray weights;
|
||||||
if(syntheticData){
|
if(syntheticData){
|
||||||
weights = Nd4j.rand(1000, 200);
|
weights = Nd4j.rand(250, 200);
|
||||||
} else {
|
} else {
|
||||||
log.info("Load & Vectorize data....");
|
log.info("Load & Vectorize data....");
|
||||||
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
||||||
//Get the data of all unique word vectors
|
//Get the data of all unique word vectors
|
||||||
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
||||||
VocabCache cache = vectors.getSecond();
|
VocabCache cache = vectors.getSecond();
|
||||||
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
||||||
|
|
||||||
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
||||||
cacheList.add(cache.wordAtIndex(i));
|
cacheList.add(cache.wordAtIndex(i));
|
||||||
}
|
|
||||||
|
|
||||||
//STEP 3: build a dual-tree tsne to use later
|
|
||||||
log.info("Build model....");
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
|
||||||
.setMaxIter(iterations)
|
|
||||||
.theta(0.5)
|
|
||||||
.normalize(false)
|
|
||||||
.learningRate(500)
|
|
||||||
.useAdaGrad(false)
|
|
||||||
.workspaceMode(wsm)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
//STEP 4: establish the tsne values and save them to a file
|
|
||||||
log.info("Store TSNE Coordinates for Plotting....");
|
|
||||||
File outDir = testDir.newFolder();
|
|
||||||
tsne.fit(weights);
|
|
||||||
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//STEP 3: build a dual-tree tsne to use later
|
||||||
|
log.info("Build model....");
|
||||||
|
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
||||||
|
.setMaxIter(iterations)
|
||||||
|
.theta(0.5)
|
||||||
|
.normalize(false)
|
||||||
|
.learningRate(500)
|
||||||
|
.useAdaGrad(false)
|
||||||
|
.workspaceMode(wsm)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
|
||||||
|
//STEP 4: establish the tsne values and save them to a file
|
||||||
|
log.info("Store TSNE Coordinates for Plotting....");
|
||||||
|
File outDir = testDir.newFolder();
|
||||||
|
tsne.fit(weights);
|
||||||
|
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Elapsed time : 01:01:57.988
|
|
||||||
@Test
|
@Test
|
||||||
public void testPerformance() throws Exception {
|
public void testPerformance() throws Exception {
|
||||||
|
|
||||||
StopWatch watch = new StopWatch();
|
StopWatch watch = new StopWatch();
|
||||||
watch.start();
|
watch.start();
|
||||||
for (boolean syntheticData : new boolean[]{false, true}) {
|
for( int test=0; test <=1; test++){
|
||||||
for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) {
|
boolean syntheticData = test == 1;
|
||||||
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
WorkspaceMode wsm = test == 0 ? WorkspaceMode.NONE : WorkspaceMode.ENABLED;
|
||||||
|
log.info("Starting test: WSM={}, syntheticData={}", wsm, syntheticData);
|
||||||
|
|
||||||
//STEP 1: Initialization
|
//STEP 1: Initialization
|
||||||
int iterations = 100;
|
int iterations = 50;
|
||||||
//create an n-dimensional array of doubles
|
//create an n-dimensional array of doubles
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
||||||
|
|
||||||
//STEP 2: Turn text input into a list of words
|
//STEP 2: Turn text input into a list of words
|
||||||
INDArray weights;
|
INDArray weights;
|
||||||
if(syntheticData){
|
if(syntheticData){
|
||||||
weights = Nd4j.rand(5000, 20);
|
weights = Nd4j.rand(DataType.FLOAT, 250, 20);
|
||||||
} else {
|
} else {
|
||||||
log.info("Load & Vectorize data....");
|
log.info("Load & Vectorize data....");
|
||||||
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
File wordFile = new ClassPathResource("deeplearning4j-tsne/words.txt").getFile(); //Open the file
|
||||||
//Get the data of all unique word vectors
|
//Get the data of all unique word vectors
|
||||||
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
||||||
VocabCache cache = vectors.getSecond();
|
VocabCache cache = vectors.getSecond();
|
||||||
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
weights = vectors.getFirst().getSyn0(); //seperate weights of unique words into their own list
|
||||||
|
|
||||||
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
for (int i = 0; i < cache.numWords(); i++) //seperate strings of words into their own list
|
||||||
cacheList.add(cache.wordAtIndex(i));
|
cacheList.add(cache.wordAtIndex(i));
|
||||||
}
|
|
||||||
|
|
||||||
//STEP 3: build a dual-tree tsne to use later
|
|
||||||
log.info("Build model....");
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
|
||||||
.setMaxIter(iterations)
|
|
||||||
.theta(0.5)
|
|
||||||
.normalize(false)
|
|
||||||
.learningRate(500)
|
|
||||||
.useAdaGrad(false)
|
|
||||||
.workspaceMode(wsm)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
//STEP 4: establish the tsne values and save them to a file
|
|
||||||
log.info("Store TSNE Coordinates for Plotting....");
|
|
||||||
File outDir = testDir.newFolder();
|
|
||||||
tsne.fit(weights);
|
|
||||||
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//STEP 3: build a dual-tree tsne to use later
|
||||||
|
log.info("Build model....");
|
||||||
|
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
||||||
|
.setMaxIter(iterations)
|
||||||
|
.theta(0.5)
|
||||||
|
.normalize(false)
|
||||||
|
.learningRate(500)
|
||||||
|
.useAdaGrad(false)
|
||||||
|
.workspaceMode(wsm)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
|
||||||
|
//STEP 4: establish the tsne values and save them to a file
|
||||||
|
log.info("Store TSNE Coordinates for Plotting....");
|
||||||
|
File outDir = testDir.newFolder();
|
||||||
|
tsne.fit(weights);
|
||||||
|
tsne.saveAsFile(cacheList, new File(outDir, "out.txt").getAbsolutePath());
|
||||||
}
|
}
|
||||||
watch.stop();
|
watch.stop();
|
||||||
System.out.println("Elapsed time : " + watch);
|
System.out.println("Elapsed time : " + watch);
|
||||||
|
|
|
@ -20,6 +20,8 @@ package org.deeplearning4j.models.paragraphvectors;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.apache.commons.io.IOUtils;
|
||||||
|
import org.apache.commons.io.LineIterator;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
|
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
|
||||||
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
|
||||||
|
@ -27,6 +29,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
|
||||||
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
|
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
|
||||||
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator;
|
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator;
|
||||||
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.ParallelTransformerIterator;
|
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.ParallelTransformerIterator;
|
||||||
|
import org.deeplearning4j.text.sentenceiterator.*;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
@ -46,10 +49,6 @@ import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator;
|
||||||
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
|
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
|
||||||
import org.deeplearning4j.text.documentiterator.LabelledDocument;
|
import org.deeplearning4j.text.documentiterator.LabelledDocument;
|
||||||
import org.deeplearning4j.text.documentiterator.LabelsSource;
|
import org.deeplearning4j.text.documentiterator.LabelsSource;
|
||||||
import org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator;
|
|
||||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
|
||||||
import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator;
|
|
||||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
|
||||||
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
|
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||||
|
@ -66,8 +65,8 @@ import org.nd4j.resources.Resources;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.*;
|
||||||
import java.io.IOException;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
|
||||||
|
@ -372,7 +371,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
LabelsSource source = new LabelsSource("DOC_");
|
LabelsSource source = new LabelsSource("DOC_");
|
||||||
|
|
||||||
ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(3)
|
ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1)
|
||||||
.layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter)
|
.layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter)
|
||||||
.trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0)
|
.trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0)
|
||||||
.useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true)
|
.useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true)
|
||||||
|
@ -425,6 +424,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
@Test(timeout = 300000)
|
||||||
public void testParagraphVectorsDBOW() throws Exception {
|
public void testParagraphVectorsDBOW() throws Exception {
|
||||||
|
skipUnlessIntegrationTests();
|
||||||
|
|
||||||
File file = Resources.asFile("/big/raw_sentences.txt");
|
File file = Resources.asFile("/big/raw_sentences.txt");
|
||||||
SentenceIterator iter = new BasicLineIterator(file);
|
SentenceIterator iter = new BasicLineIterator(file);
|
||||||
|
|
||||||
|
@ -657,7 +658,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
@Test
|
||||||
public void testIterator() throws IOException {
|
public void testIterator() throws IOException {
|
||||||
val folder_labeled = testDir.newFolder();
|
val folder_labeled = testDir.newFolder();
|
||||||
val folder_unlabeled = testDir.newFolder();
|
val folder_unlabeled = testDir.newFolder();
|
||||||
|
@ -672,7 +673,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
SentenceIterator iter = new BasicLineIterator(resource_sentences);
|
SentenceIterator iter = new BasicLineIterator(resource_sentences);
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (; i < 10000; ++i) {
|
for (; i < 10; ++i) {
|
||||||
int j = 0;
|
int j = 0;
|
||||||
int labels = 0;
|
int labels = 0;
|
||||||
int words = 0;
|
int words = 0;
|
||||||
|
@ -721,7 +722,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
|
||||||
Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
|
Word2Vec wordVectors = new Word2Vec.Builder().seed(119).minWordFrequency(1).batchSize(250).iterations(1).epochs(1)
|
||||||
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
|
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
|
||||||
.elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
|
.elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
|
||||||
.allowParallelTokenization(true)
|
.allowParallelTokenization(true)
|
||||||
|
@ -1009,7 +1010,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
|
||||||
Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
|
Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(1)
|
||||||
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
|
.learningRate(0.025).layerSize(150).minLearningRate(0.001)
|
||||||
.elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
|
.elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
|
||||||
.iterate(iter).tokenizerFactory(t).build();
|
.iterate(iter).tokenizerFactory(t).build();
|
||||||
|
@ -1151,8 +1152,27 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
@Test(timeout = 300000)
|
||||||
public void testDoubleFit() throws Exception {
|
public void testDoubleFit() throws Exception {
|
||||||
|
boolean isIntegration = isIntegrationTests();
|
||||||
File resource = Resources.asFile("/big/raw_sentences.txt");
|
File resource = Resources.asFile("/big/raw_sentences.txt");
|
||||||
SentenceIterator iter = new BasicLineIterator(resource);
|
SentenceIterator iter;
|
||||||
|
if(isIntegration){
|
||||||
|
iter = new BasicLineIterator(resource);
|
||||||
|
} else {
|
||||||
|
List<String> lines = new ArrayList<>();
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){
|
||||||
|
LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
|
||||||
|
try{
|
||||||
|
for( int i=0; i<500 && lineIter.hasNext(); i++ ){
|
||||||
|
lines.add(lineIter.next());
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
lineIter.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
iter = new CollectionSentenceIterator(lines);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
|
|
@ -49,7 +49,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 240000L;
|
return 60000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -57,6 +57,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testIterator1() throws Exception {
|
public void testIterator1() throws Exception {
|
||||||
|
|
||||||
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
|
|
||||||
|
@ -77,10 +78,14 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1);
|
Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1);
|
||||||
INDArray array = iterator.next().getFeatures();
|
INDArray array = iterator.next().getFeatures();
|
||||||
|
int count = 0;
|
||||||
while (iterator.hasNext()) {
|
while (iterator.hasNext()) {
|
||||||
DataSet ds = iterator.next();
|
DataSet ds = iterator.next();
|
||||||
|
|
||||||
assertArrayEquals(array.shape(), ds.getFeatures().shape());
|
assertArrayEquals(array.shape(), ds.getFeatures().shape());
|
||||||
|
|
||||||
|
if(!isIntegrationTests() && count++ > 20)
|
||||||
|
break; //raw_sentences.txt is 2.81 MB, takes quite some time to process. We'll only first 20 minibatches when doing unit tests
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,9 +45,15 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testStore1() throws Exception {
|
public void testStore1() throws Exception {
|
||||||
int numParams = 100000;
|
int numParams;
|
||||||
|
int[] workers;
|
||||||
int workers[] = new int[] {2, 4, 8};
|
if(isIntegrationTests()){
|
||||||
|
numParams = 100000;
|
||||||
|
workers = new int[] {2, 4, 8};
|
||||||
|
} else {
|
||||||
|
numParams = 10000;
|
||||||
|
workers = new int[] {2, 3};
|
||||||
|
}
|
||||||
|
|
||||||
for (int numWorkers : workers) {
|
for (int numWorkers : workers) {
|
||||||
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3),null, null, false);
|
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3),null, null, false);
|
||||||
|
@ -77,7 +83,13 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testEncodingLimits1() throws Exception {
|
public void testEncodingLimits1() throws Exception {
|
||||||
int numParams = 100000;
|
int numParams;
|
||||||
|
if(isIntegrationTests()){
|
||||||
|
numParams = 100000;
|
||||||
|
} else {
|
||||||
|
numParams = 10000;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, null, false);
|
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, null, false);
|
||||||
for (int e = 10; e < numParams / 5; e++) {
|
for (int e = 10; e < numParams / 5; e++) {
|
||||||
|
|
|
@ -242,7 +242,7 @@ public class IndexedTailTest extends BaseDL4JTest {
|
||||||
final long[] sums = new long[numReaders];
|
final long[] sums = new long[numReaders];
|
||||||
val readers = new ArrayList<Thread>();
|
val readers = new ArrayList<Thread>();
|
||||||
for (int e = 0; e < numReaders; e++) {
|
for (int e = 0; e < numReaders; e++) {
|
||||||
val f = e;
|
final int f = e;
|
||||||
val t = new Thread(new Runnable() {
|
val t = new Thread(new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
@ -297,7 +297,7 @@ public class IndexedTailTest extends BaseDL4JTest {
|
||||||
final long[] sums = new long[numReaders];
|
final long[] sums = new long[numReaders];
|
||||||
val readers = new ArrayList<Thread>();
|
val readers = new ArrayList<Thread>();
|
||||||
for (int e = 0; e < numReaders; e++) {
|
for (int e = 0; e < numReaders; e++) {
|
||||||
val f = e;
|
final int f = e;
|
||||||
val t = new Thread(new Runnable() {
|
val t = new Thread(new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
@ -371,7 +371,7 @@ public class IndexedTailTest extends BaseDL4JTest {
|
||||||
final long[] sums = new long[numReaders];
|
final long[] sums = new long[numReaders];
|
||||||
val readers = new ArrayList<Thread>();
|
val readers = new ArrayList<Thread>();
|
||||||
for (int e = 0; e < numReaders; e++) {
|
for (int e = 0; e < numReaders; e++) {
|
||||||
val f = e;
|
final int f = e;
|
||||||
val t = new Thread(new Runnable() {
|
val t = new Thread(new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
|
|
|
@ -35,6 +35,7 @@ import org.deeplearning4j.remote.helpers.House;
|
||||||
import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
|
import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
|
||||||
import org.deeplearning4j.remote.helpers.PredictedPrice;
|
import org.deeplearning4j.remote.helpers.PredictedPrice;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.adapters.InferenceAdapter;
|
import org.nd4j.adapters.InferenceAdapter;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -58,6 +59,7 @@ import java.util.Collections;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
import java.util.concurrent.Future;
|
import java.util.concurrent.Future;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
|
import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
|
||||||
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
|
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
|
||||||
|
@ -66,7 +68,6 @@ import static org.junit.Assert.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class JsonModelServerTest extends BaseDL4JTest {
|
public class JsonModelServerTest extends BaseDL4JTest {
|
||||||
private static final MultiLayerNetwork model;
|
private static final MultiLayerNetwork model;
|
||||||
private final int PORT = 18080;
|
|
||||||
|
|
||||||
static {
|
static {
|
||||||
val conf = new NeuralNetConfiguration.Builder()
|
val conf = new NeuralNetConfiguration.Builder()
|
||||||
|
@ -84,10 +85,18 @@ public class JsonModelServerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void pause() throws Exception {
|
public void pause() throws Exception {
|
||||||
// TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
|
// Need to wait for server shutdown; without sleep, tests will fail if starting immediately after shutdown
|
||||||
TimeUnit.SECONDS.sleep(2);
|
TimeUnit.SECONDS.sleep(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private AtomicInteger portCount = new AtomicInteger(18080);
|
||||||
|
private int PORT;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setPort(){
|
||||||
|
PORT = portCount.getAndIncrement();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStartStopParallel() throws Exception {
|
public void testStartStopParallel() throws Exception {
|
||||||
|
@ -343,7 +352,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
|
||||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||||
.inputDeserializer(null)
|
.inputDeserializer(null)
|
||||||
.port(18080)
|
.port(PORT)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -382,7 +391,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.endpointAddress("http://localhost:18080/v1/serving")
|
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
int district = 2;
|
int district = 2;
|
||||||
|
|
|
@ -485,7 +485,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
||||||
List<INDArray> exp = new ArrayList<>();
|
List<INDArray> exp = new ArrayList<>();
|
||||||
|
|
||||||
Random r = new Random();
|
Random r = new Random();
|
||||||
for (int i = 0; i < 500; i++) {
|
int runs = isIntegrationTests() ? 500 : 30;
|
||||||
|
for (int i = 0; i < runs; i++) {
|
||||||
int[] shape = defaultSize;
|
int[] shape = defaultSize;
|
||||||
if (r.nextDouble() < 0.4) {
|
if (r.nextDouble() < 0.4) {
|
||||||
shape = new int[]{r.nextInt(5) + 1, 10, r.nextInt(10) + 1};
|
shape = new int[]{r.nextInt(5) + 1, 10, r.nextInt(10) + 1};
|
||||||
|
@ -597,7 +598,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
||||||
List<INDArray> arrs = new ArrayList<>();
|
List<INDArray> arrs = new ArrayList<>();
|
||||||
List<INDArray> exp = new ArrayList<>();
|
List<INDArray> exp = new ArrayList<>();
|
||||||
Random r = new Random();
|
Random r = new Random();
|
||||||
for( int i=0; i<500; i++ ){
|
int runs = isIntegrationTests() ? 500 : 20;
|
||||||
|
for( int i=0; i<runs; i++ ){
|
||||||
int[] shape = defaultShape;
|
int[] shape = defaultShape;
|
||||||
if(r.nextDouble() < 0.4){
|
if(r.nextDouble() < 0.4){
|
||||||
shape = new int[]{r.nextInt(5)+1, nIn, 10, r.nextInt(10)+1};
|
shape = new int[]{r.nextInt(5)+1, nIn, 10, r.nextInt(10)+1};
|
||||||
|
@ -679,8 +681,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 120000)
|
private void testInputMasking() throws Exception {
|
||||||
public void testInputMasking() throws Exception {
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
int nIn = 10;
|
int nIn = 10;
|
||||||
|
@ -698,12 +699,15 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
|
// InferenceMode[] inferenceModes = new InferenceMode[]{InferenceMode.SEQUENTIAL, InferenceMode.BATCHED, InferenceMode.INPLACE, InferenceMode.SEQUENTIAL};
|
||||||
|
// int[] workers = new int[]{2, 2, 2, 1};
|
||||||
|
// boolean[] randomTS = new boolean[]{true, false, true, false};
|
||||||
|
|
||||||
Random r = new Random();
|
Random r = new Random();
|
||||||
for( InferenceMode m : InferenceMode.values()) {
|
for( InferenceMode m : InferenceMode.values()) {
|
||||||
log.info("Testing inference mode: [{}]", m);
|
log.info("Testing inference mode: [{}]", m);
|
||||||
for( int w : new int[]{1,2}) {
|
for( int w : new int[]{1,2}) {
|
||||||
for (boolean randomTSLength : new boolean[]{false, true}) {
|
for (boolean randomTSLength : new boolean[]{false, true}) {
|
||||||
|
|
||||||
final ParallelInference inf =
|
final ParallelInference inf =
|
||||||
new ParallelInference.Builder(net)
|
new ParallelInference.Builder(net)
|
||||||
.inferenceMode(m)
|
.inferenceMode(m)
|
||||||
|
@ -714,7 +718,8 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
||||||
List<INDArray> in = new ArrayList<>();
|
List<INDArray> in = new ArrayList<>();
|
||||||
List<INDArray> inMasks = new ArrayList<>();
|
List<INDArray> inMasks = new ArrayList<>();
|
||||||
List<INDArray> exp = new ArrayList<>();
|
List<INDArray> exp = new ArrayList<>();
|
||||||
for (int i = 0; i < 100; i++) {
|
int nRuns = isIntegrationTests() ? 100 : 10;
|
||||||
|
for (int i = 0; i < nRuns; i++) {
|
||||||
int currTSLength = (randomTSLength ? 1 + r.nextInt(tsLength) : tsLength);
|
int currTSLength = (randomTSLength ? 1 + r.nextInt(tsLength) : tsLength);
|
||||||
int currNumEx = 1 + r.nextInt(3);
|
int currNumEx = 1 + r.nextInt(3);
|
||||||
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn, currTSLength});
|
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn, currTSLength});
|
||||||
|
@ -847,6 +852,7 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
||||||
|
|
||||||
List<INDArray[]> in = new ArrayList<>();
|
List<INDArray[]> in = new ArrayList<>();
|
||||||
List<INDArray[]> exp = new ArrayList<>();
|
List<INDArray[]> exp = new ArrayList<>();
|
||||||
|
int runs = isIntegrationTests() ? 100 : 20;
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
int currNumEx = 1 + r.nextInt(3);
|
int currNumEx = 1 + r.nextInt(3);
|
||||||
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn});
|
INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn});
|
||||||
|
|
|
@ -62,8 +62,8 @@ public class ParallelWrapperTest extends BaseDL4JTest {
|
||||||
int seed = 123;
|
int seed = 123;
|
||||||
|
|
||||||
log.info("Load data....");
|
log.info("Load data....");
|
||||||
DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 100);
|
DataSetIterator mnistTrain = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), 15);
|
||||||
DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 10);
|
DataSetIterator mnistTest = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), 4);
|
||||||
|
|
||||||
assertTrue(mnistTrain.hasNext());
|
assertTrue(mnistTrain.hasNext());
|
||||||
val t0 = mnistTrain.next();
|
val t0 = mnistTrain.next();
|
||||||
|
|
|
@ -47,6 +47,12 @@
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
|
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
|
||||||
<version>${nd4j.version}</version>
|
<version>${nd4j.version}</version>
|
||||||
|
<exclusions>
|
||||||
|
<exclusion>
|
||||||
|
<groupId>net.jpountz.lz4</groupId>
|
||||||
|
<artifactId>lz4</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>junit</groupId>
|
<groupId>junit</groupId>
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
################################################################################
|
||||||
|
# Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
#
|
||||||
|
# This program and the accompanying materials are made available under the
|
||||||
|
# terms of the Apache License, Version 2.0 which is available at
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
log4j.rootLogger=ERROR, Console
|
||||||
|
log4j.appender.Console=org.apache.log4j.ConsoleAppender
|
||||||
|
log4j.appender.Console.layout=org.apache.log4j.PatternLayout
|
||||||
|
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
|
||||||
|
|
||||||
|
log4j.appender.org.springframework=DEBUG
|
||||||
|
log4j.appender.org.deeplearning4j=DEBUG
|
||||||
|
log4j.appender.org.nd4j=DEBUG
|
||||||
|
|
||||||
|
log4j.logger.org.springframework=INFO
|
||||||
|
log4j.logger.org.deeplearning4j=DEBUG
|
||||||
|
log4j.logger.org.nd4j=DEBUG
|
||||||
|
log4j.logger.org.apache.spark=WARN
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
~
|
||||||
|
~ This program and the accompanying materials are made available under the
|
||||||
|
~ terms of the Apache License, Version 2.0 which is available at
|
||||||
|
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
~
|
||||||
|
~ Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
~ License for the specific language governing permissions and limitations
|
||||||
|
~ under the License.
|
||||||
|
~
|
||||||
|
~ SPDX-License-Identifier: Apache-2.0
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||||
|
|
||||||
|
<configuration>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||||
|
<file>logs/application.log</file>
|
||||||
|
<encoder>
|
||||||
|
<pattern>%date - [%level] - from %logger in %thread
|
||||||
|
%n%message%n%xException%n</pattern>
|
||||||
|
</encoder>
|
||||||
|
</appender>
|
||||||
|
|
||||||
|
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||||
|
<encoder>
|
||||||
|
<pattern> %logger{15} - %message%n%xException{5}
|
||||||
|
</pattern>
|
||||||
|
</encoder>
|
||||||
|
</appender>
|
||||||
|
|
||||||
|
<logger name="org.apache.catalina.core" level="DEBUG" />
|
||||||
|
<logger name="org.springframework" level="DEBUG" />
|
||||||
|
<logger name="org.deeplearning4j" level="DEBUG" />
|
||||||
|
<logger name="org.datavec" level="INFO" />
|
||||||
|
<logger name="org.nd4j" level="INFO" />
|
||||||
|
<logger name="opennlp.uima.util" level="OFF" />
|
||||||
|
<logger name="org.apache.uima" level="OFF" />
|
||||||
|
<logger name="org.cleartk" level="OFF" />
|
||||||
|
<logger name="org.apache.spark" level="WARN" />
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<root level="ERROR">
|
||||||
|
<appender-ref ref="STDOUT" />
|
||||||
|
<appender-ref ref="FILE" />
|
||||||
|
</root>
|
||||||
|
|
||||||
|
</configuration>
|
|
@ -0,0 +1,31 @@
|
||||||
|
################################################################################
|
||||||
|
# Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
#
|
||||||
|
# This program and the accompanying materials are made available under the
|
||||||
|
# terms of the Apache License, Version 2.0 which is available at
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
log4j.rootLogger=ERROR, Console
|
||||||
|
log4j.appender.Console=org.apache.log4j.ConsoleAppender
|
||||||
|
log4j.appender.Console.layout=org.apache.log4j.PatternLayout
|
||||||
|
log4j.appender.Console.layout.ConversionPattern=%d{ABSOLUTE} %-5p ~ %m%n
|
||||||
|
|
||||||
|
log4j.appender.org.springframework=DEBUG
|
||||||
|
log4j.appender.org.deeplearning4j=DEBUG
|
||||||
|
log4j.appender.org.nd4j=DEBUG
|
||||||
|
|
||||||
|
log4j.logger.org.springframework=INFO
|
||||||
|
log4j.logger.org.deeplearning4j=DEBUG
|
||||||
|
log4j.logger.org.nd4j=DEBUG
|
||||||
|
log4j.logger.org.apache.spark=WARN
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
~
|
||||||
|
~ This program and the accompanying materials are made available under the
|
||||||
|
~ terms of the Apache License, Version 2.0 which is available at
|
||||||
|
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
~
|
||||||
|
~ Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
~ License for the specific language governing permissions and limitations
|
||||||
|
~ under the License.
|
||||||
|
~
|
||||||
|
~ SPDX-License-Identifier: Apache-2.0
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||||
|
|
||||||
|
<configuration>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||||
|
<file>logs/application.log</file>
|
||||||
|
<encoder>
|
||||||
|
<pattern>%date - [%level] - from %logger in %thread
|
||||||
|
%n%message%n%xException%n</pattern>
|
||||||
|
</encoder>
|
||||||
|
</appender>
|
||||||
|
|
||||||
|
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||||
|
<encoder>
|
||||||
|
<pattern> %logger{15} - %message%n%xException{5}
|
||||||
|
</pattern>
|
||||||
|
</encoder>
|
||||||
|
</appender>
|
||||||
|
|
||||||
|
<logger name="org.apache.catalina.core" level="DEBUG" />
|
||||||
|
<logger name="org.springframework" level="DEBUG" />
|
||||||
|
<logger name="org.deeplearning4j" level="DEBUG" />
|
||||||
|
<logger name="org.datavec" level="INFO" />
|
||||||
|
<logger name="org.nd4j" level="INFO" />
|
||||||
|
<logger name="opennlp.uima.util" level="OFF" />
|
||||||
|
<logger name="org.apache.uima" level="OFF" />
|
||||||
|
<logger name="org.cleartk" level="OFF" />
|
||||||
|
<logger name="org.apache.spark" level="WARN" />
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<root level="ERROR">
|
||||||
|
<appender-ref ref="STDOUT" />
|
||||||
|
<appender-ref ref="FILE" />
|
||||||
|
</root>
|
||||||
|
|
||||||
|
</configuration>
|
|
@ -25,6 +25,7 @@ import org.datavec.spark.transform.misc.StringToWritablesFunction;
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
|
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
|
||||||
import org.deeplearning4j.spark.BaseSparkTest;
|
import org.deeplearning4j.spark.BaseSparkTest;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
|
||||||
|
@ -35,6 +36,15 @@ import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
public class TestIteratorUtils extends BaseSparkTest {
|
public class TestIteratorUtils extends BaseSparkTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDefaultFPDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIrisRRMDSI() throws Exception {
|
public void testIrisRRMDSI() throws Exception {
|
||||||
|
|
|
@ -453,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
tempDirF.deleteOnExit();
|
tempDirF.deleteOnExit();
|
||||||
|
|
||||||
int dataSetObjSize = 1;
|
int dataSetObjSize = 1;
|
||||||
int batchSizePerExecutor = 16;
|
int batchSizePerExecutor = 4;
|
||||||
int numSplits = 5;
|
int numSplits = 3;
|
||||||
int averagingFrequency = 3;
|
int averagingFrequency = 3;
|
||||||
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
|
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
|
||||||
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
|
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
|
||||||
|
@ -506,7 +506,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
INDArray paramsAfter = sparkNet.getNetwork().params().dup();
|
INDArray paramsAfter = sparkNet.getNetwork().params().dup();
|
||||||
assertNotEquals(paramsBefore, paramsAfter);
|
assertNotEquals(paramsBefore, paramsAfter);
|
||||||
|
|
||||||
Thread.sleep(2000);
|
Thread.sleep(200);
|
||||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||||
|
|
||||||
//Expect
|
//Expect
|
||||||
|
@ -517,7 +517,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
assertEquals(numSplits * numExecutors() * averagingFrequency, list.size());
|
assertEquals(numSplits * numExecutors() * averagingFrequency, list.size());
|
||||||
for (EventStats es : list) {
|
for (EventStats es : list) {
|
||||||
ExampleCountEventStats e = (ExampleCountEventStats) es;
|
ExampleCountEventStats e = (ExampleCountEventStats) es;
|
||||||
assertTrue(batchSizePerExecutor * averagingFrequency - 10 >= e.getTotalExampleCount());
|
assertTrue(batchSizePerExecutor * averagingFrequency >= e.getTotalExampleCount());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -535,9 +535,9 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
tempDirF.deleteOnExit();
|
tempDirF.deleteOnExit();
|
||||||
tempDirF2.deleteOnExit();
|
tempDirF2.deleteOnExit();
|
||||||
|
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 4;
|
||||||
int batchSizePerExecutor = 25;
|
int batchSizePerExecutor = 8;
|
||||||
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 1000, false);
|
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, 128, false);
|
||||||
int i = 0;
|
int i = 0;
|
||||||
while (iter.hasNext()) {
|
while (iter.hasNext()) {
|
||||||
File nextFile = new File(tempDirF, i + ".bin");
|
File nextFile = new File(tempDirF, i + ".bin");
|
||||||
|
@ -981,7 +981,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
.setOutputs("out")
|
.setOutputs("out")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
DataSetIterator iter = new IrisDataSetIterator(1, 150);
|
DataSetIterator iter = new IrisDataSetIterator(1, 50);
|
||||||
|
|
||||||
List<DataSet> l = new ArrayList<>();
|
List<DataSet> l = new ArrayList<>();
|
||||||
while(iter.hasNext()){
|
while(iter.hasNext()){
|
||||||
|
@ -992,9 +992,10 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
||||||
int rddDataSetNumExamples = 1;
|
int rddDataSetNumExamples = 1;
|
||||||
int averagingFrequency = 3;
|
int averagingFrequency = 2;
|
||||||
|
int batch = 2;
|
||||||
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(rddDataSetNumExamples)
|
ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(rddDataSetNumExamples)
|
||||||
.averagingFrequency(averagingFrequency).batchSizePerWorker(rddDataSetNumExamples)
|
.averagingFrequency(averagingFrequency).batchSizePerWorker(batch)
|
||||||
.saveUpdater(true).workerPrefetchNumBatches(0).build();
|
.saveUpdater(true).workerPrefetchNumBatches(0).build();
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -1003,7 +1004,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
SparkComputationGraph sn2 = new SparkComputationGraph(sc, conf2.clone(), tm);
|
SparkComputationGraph sn2 = new SparkComputationGraph(sc, conf2.clone(), tm);
|
||||||
|
|
||||||
|
|
||||||
for(int i=0; i<4; i++ ){
|
for(int i=0; i<3; i++ ){
|
||||||
assertEquals(i, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount());
|
assertEquals(i, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount());
|
||||||
assertEquals(i, sn2.getNetwork().getConfiguration().getEpochCount());
|
assertEquals(i, sn2.getNetwork().getConfiguration().getEpochCount());
|
||||||
sn1.fit(rdd);
|
sn1.fit(rdd);
|
||||||
|
|
|
@ -42,6 +42,11 @@ import static org.junit.Assert.assertTrue;
|
||||||
*/
|
*/
|
||||||
public class TestRepartitioning extends BaseSparkTest {
|
public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return isIntegrationTests() ? 240000 : 60000;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRepartitioning() {
|
public void testRepartitioning() {
|
||||||
List<String> list = new ArrayList<>();
|
List<String> list = new ArrayList<>();
|
||||||
|
@ -66,7 +71,12 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testRepartitioning2() throws Exception {
|
public void testRepartitioning2() throws Exception {
|
||||||
|
|
||||||
int[] ns = {320, 321, 25600, 25601, 25615};
|
int[] ns;
|
||||||
|
if(isIntegrationTests()){
|
||||||
|
ns = new int[]{320, 321, 25600, 25601, 25615};
|
||||||
|
} else {
|
||||||
|
ns = new int[]{320, 2561};
|
||||||
|
}
|
||||||
|
|
||||||
for (int n : ns) {
|
for (int n : ns) {
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,11 @@ import java.io.File;
|
||||||
|
|
||||||
public class MiscTests extends BaseDL4JTest {
|
public class MiscTests extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 120000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTransferVGG() throws Exception {
|
public void testTransferVGG() throws Exception {
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/issues/5167
|
//https://github.com/deeplearning4j/deeplearning4j/issues/5167
|
||||||
|
|
|
@ -48,6 +48,11 @@ import static org.junit.Assert.assertEquals;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestDownload extends BaseDL4JTest {
|
public class TestDownload extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return isIntegrationTests() ? 480000L : 60000L;
|
||||||
|
}
|
||||||
|
|
||||||
@ClassRule
|
@ClassRule
|
||||||
public static TemporaryFolder testDir = new TemporaryFolder();
|
public static TemporaryFolder testDir = new TemporaryFolder();
|
||||||
private static File f;
|
private static File f;
|
||||||
|
@ -67,12 +72,20 @@ public class TestDownload extends BaseDL4JTest {
|
||||||
public void testDownloadAllModels() throws Exception {
|
public void testDownloadAllModels() throws Exception {
|
||||||
|
|
||||||
// iterate through each available model
|
// iterate through each available model
|
||||||
ZooModel[] models = new ZooModel[]{
|
ZooModel[] models;
|
||||||
LeNet.builder().build(),
|
|
||||||
SimpleCNN.builder().build(),
|
if(isIntegrationTests()){
|
||||||
UNet.builder().build(),
|
models = new ZooModel[]{
|
||||||
NASNet.builder().build()
|
LeNet.builder().build(),
|
||||||
};
|
SimpleCNN.builder().build(),
|
||||||
|
UNet.builder().build(),
|
||||||
|
NASNet.builder().build()};
|
||||||
|
} else {
|
||||||
|
models = new ZooModel[]{
|
||||||
|
LeNet.builder().build(),
|
||||||
|
SimpleCNN.builder().build()};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for (int i = 0; i < models.length; i++) {
|
for (int i = 0; i < models.length; i++) {
|
||||||
|
|
|
@ -57,6 +57,11 @@ import static org.junit.Assert.assertTrue;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestImageNet extends BaseDL4JTest {
|
public class TestImageNet extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -63,10 +63,15 @@ public class DistributionUniform extends DynamicCustomOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public DistributionUniform(INDArray shape, INDArray out, double min, double max){
|
public DistributionUniform(INDArray shape, INDArray out, double min, double max) {
|
||||||
|
this(shape, out, min, max, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public DistributionUniform(INDArray shape, INDArray out, double min, double max, DataType dataType){
|
||||||
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
|
super(null, new INDArray[]{shape}, new INDArray[]{out}, Arrays.asList(min, max), (List<Integer>)null);
|
||||||
this.min = min;
|
this.min = min;
|
||||||
this.max = max;
|
this.max = max;
|
||||||
|
this.dataType = dataType;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -310,6 +310,13 @@
|
||||||
</exclusion>
|
</exclusion>
|
||||||
</exclusions>
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-common-tests</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.jita.allocator;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
import org.nd4j.linalg.util.DeviceLocalNDArray;
|
||||||
|
@ -29,7 +30,7 @@ import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class DeviceLocalNDArrayTests {
|
public class DeviceLocalNDArrayTests extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDeviceLocalArray_1() throws Exception{
|
public void testDeviceLocalArray_1() throws Exception{
|
||||||
|
|
|
@ -19,13 +19,14 @@ package org.nd4j.jita.allocator.impl;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MemoryTrackerTest {
|
public class MemoryTrackerTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAllocatedDelta() {
|
public void testAllocatedDelta() {
|
||||||
|
|
|
@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.BaseND4JTest;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.jita.workspace.CudaWorkspace;
|
import org.nd4j.jita.workspace.CudaWorkspace;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -20,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class BaseCudaDataBufferTest {
|
public class BaseCudaDataBufferTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
|
|
|
@ -87,6 +87,13 @@
|
||||||
<version>${logback.version}</version>
|
<version>${logback.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-common-tests</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -20,6 +20,7 @@ package org.nd4j.tensorflow.conversion;
|
||||||
import junit.framework.TestCase;
|
import junit.framework.TestCase;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.bytedeco.tensorflow.TF_Tensor;
|
import org.bytedeco.tensorflow.TF_Tensor;
|
||||||
|
import org.nd4j.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.resources.Resources;
|
import org.nd4j.resources.Resources;
|
||||||
import org.nd4j.shade.protobuf.Descriptors;
|
import org.nd4j.shade.protobuf.Descriptors;
|
||||||
|
@ -46,7 +47,17 @@ import java.util.Map;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
|
||||||
public class GraphRunnerTest {
|
public class GraphRunnerTest extends BaseND4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDefaultFPDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
public static ConfigProto getConfig(){
|
public static ConfigProto getConfig(){
|
||||||
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.tensorflow.conversion;
|
||||||
|
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.assertNotNull;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
public class TensorflowConversionTest {
|
public class TensorflowConversionTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testView() {
|
public void testView() {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.nd4j.tensorflow.conversion;
|
package org.nd4j.tensorflow.conversion;
|
||||||
|
|
||||||
|
import org.nd4j.BaseND4JTest;
|
||||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -37,7 +38,7 @@ import java.util.Map;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
|
||||||
public class GpuGraphRunnerTest {
|
public class GpuGraphRunnerTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGraphRunner() throws Exception {
|
public void testGraphRunner() throws Exception {
|
||||||
|
|
|
@ -127,6 +127,13 @@
|
||||||
</exclusion>
|
</exclusion>
|
||||||
</exclusions>
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-common-tests</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<reporting>
|
<reporting>
|
||||||
|
|
|
@ -16,9 +16,6 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.opvalidation;
|
package org.nd4j.autodiff.opvalidation;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.junit.Assert.assertNull;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
@ -46,6 +43,8 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LayerOpValidation extends BaseOpValidation {
|
public class LayerOpValidation extends BaseOpValidation {
|
||||||
public LayerOpValidation(Nd4jBackend backend) {
|
public LayerOpValidation(Nd4jBackend backend) {
|
||||||
|
|
|
@ -45,7 +45,7 @@ public class LossOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long testTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L;
|
return 90000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -54,8 +54,7 @@ import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertNull;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@RunWith(Parameterized.class)
|
@RunWith(Parameterized.class)
|
||||||
|
|
|
@ -2434,8 +2434,6 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPermute4(){
|
public void testPermute4(){
|
||||||
Nd4j.getExecutioner().enableDebugMode(true);
|
|
||||||
Nd4j.getExecutioner().enableVerboseMode(true);
|
|
||||||
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
|
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
|
||||||
INDArray permute = Nd4j.createFromArray(1,0);
|
INDArray permute = Nd4j.createFromArray(1,0);
|
||||||
|
|
||||||
|
|
|
@ -63,8 +63,12 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
||||||
return sd;
|
return sd;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static DataSetIterator getIter(){
|
public static DataSetIterator getIter() {
|
||||||
return new IrisDataSetIterator(15, 150);
|
return getIter(15, 150);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static DataSetIterator getIter(int batch, int totalExamples){
|
||||||
|
return new IrisDataSetIterator(batch, totalExamples);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -148,15 +152,15 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
||||||
|
|
||||||
CheckpointListener l = new CheckpointListener.Builder(dir)
|
CheckpointListener l = new CheckpointListener.Builder(dir)
|
||||||
.keepLast(2)
|
.keepLast(2)
|
||||||
.saveEvery(3, TimeUnit.SECONDS)
|
.saveEvery(1, TimeUnit.SECONDS)
|
||||||
.build();
|
.build();
|
||||||
sd.setListeners(l);
|
sd.setListeners(l);
|
||||||
|
|
||||||
DataSetIterator iter = getIter();
|
DataSetIterator iter = getIter(15, 150);
|
||||||
|
|
||||||
for(int i=0; i<5; i++ ){ //10 iterations total
|
for(int i=0; i<5; i++ ){ //10 iterations total
|
||||||
sd.fit(iter, 1);
|
sd.fit(iter, 1);
|
||||||
Thread.sleep(4000);
|
Thread.sleep(1000);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Expect models saved at iterations: 10, 20, 30, 40
|
//Expect models saved at iterations: 10, 20, 30, 40
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -34,8 +34,7 @@ import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 05/07/2017.
|
* Created by Alex on 05/07/2017.
|
||||||
|
|
|
@ -33,8 +33,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 04/11/2016.
|
* Created by Alex on 04/11/2016.
|
||||||
|
|
|
@ -33,6 +33,7 @@ import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
|
|
@ -121,7 +121,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
"fused_batch_norm/.*",
|
"fused_batch_norm/.*",
|
||||||
|
|
||||||
// AB 2020/01/04 - https://github.com/eclipse/deeplearning4j/issues/8592
|
// AB 2020/01/04 - https://github.com/eclipse/deeplearning4j/issues/8592
|
||||||
"emptyArrayTests/reshape/rank2_shape2-0_2-0--1"
|
"emptyArrayTests/reshape/rank2_shape2-0_2-0--1",
|
||||||
|
|
||||||
|
//AB 2020/01/07 - Known issues
|
||||||
|
"bitcast/from_float64_to_int64",
|
||||||
|
"bitcast/from_rank2_float64_to_int64",
|
||||||
|
"bitcast/from_float64_to_uint64"
|
||||||
};
|
};
|
||||||
|
|
||||||
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
|
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
|
||||||
|
|
|
@ -252,7 +252,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
|
||||||
System.out.println(Arrays.toString(shape));
|
System.out.println(Arrays.toString(shape));
|
||||||
|
|
||||||
// this is NHWC weights. will be changed soon.
|
// this is NHWC weights. will be changed soon.
|
||||||
assertArrayEquals(new int[]{5,5,1,32}, shape);
|
assertArrayEquals(new long[]{5,5,1,32}, shape);
|
||||||
System.out.println(convNode);
|
System.out.println(convNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg;
|
package org.nd4j.linalg;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -26,6 +27,7 @@ import org.junit.rules.TestName;
|
||||||
import org.junit.rules.Timeout;
|
import org.junit.rules.Timeout;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.Parameterized;
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.BaseND4JTest;
|
||||||
import org.nd4j.config.ND4JSystemProperties;
|
import org.nd4j.config.ND4JSystemProperties;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
|
@ -40,30 +42,16 @@ import org.slf4j.LoggerFactory;
|
||||||
import java.lang.management.ManagementFactory;
|
import java.lang.management.ManagementFactory;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
import static org.junit.Assume.assumeTrue;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base Nd4j test
|
* Base Nd4j test
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
@RunWith(Parameterized.class)
|
@RunWith(Parameterized.class)
|
||||||
public abstract class BaseNd4jTest {
|
@Slf4j
|
||||||
private static Logger log = LoggerFactory.getLogger(BaseNd4jTest.class);
|
public abstract class BaseNd4jTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TestName testName = new TestName();
|
|
||||||
|
|
||||||
@Rule
|
|
||||||
public Timeout timeout = Timeout.seconds(testTimeoutMilliseconds());
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Override this method to set the default timeout for methods in the class
|
|
||||||
*/
|
|
||||||
public long testTimeoutMilliseconds(){
|
|
||||||
return 30000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected long startTime;
|
|
||||||
protected int threadCountBefore;
|
|
||||||
|
|
||||||
protected Nd4jBackend backend;
|
protected Nd4jBackend backend;
|
||||||
protected String name;
|
protected String name;
|
||||||
|
@ -80,16 +68,10 @@ public abstract class BaseNd4jTest {
|
||||||
public BaseNd4jTest(String name, Nd4jBackend backend) {
|
public BaseNd4jTest(String name, Nd4jBackend backend) {
|
||||||
this.backend = backend;
|
this.backend = backend;
|
||||||
this.name = name;
|
this.name = name;
|
||||||
|
|
||||||
//Suppress ND4J initialization - don't need this logged for every test...
|
|
||||||
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
|
|
||||||
System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
|
|
||||||
System.gc();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public BaseNd4jTest(Nd4jBackend backend) {
|
public BaseNd4jTest(Nd4jBackend backend) {
|
||||||
this(backend.getClass().getName() + UUID.randomUUID().toString(), backend);
|
this(backend.getClass().getName() + UUID.randomUUID().toString(), backend);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Nd4jBackend> backends;
|
private static List<Nd4jBackend> backends;
|
||||||
|
@ -104,79 +86,6 @@ public abstract class BaseNd4jTest {
|
||||||
if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) || backendsToRun.isEmpty())
|
if (backend.canRun() && backendsToRun.contains(backend.getClass().getName()) || backendsToRun.isEmpty())
|
||||||
backends.add(backend);
|
backends.add(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
public static void assertArrayEquals(String string, Object[] expecteds, Object[] actuals) {
|
|
||||||
org.junit.Assert.assertArrayEquals(string, expecteds, actuals);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(Object[] expecteds, Object[] actuals) {
|
|
||||||
org.junit.Assert.assertArrayEquals(expecteds, actuals);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(String string, long[] shapeA, long[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(String string, byte[] shapeA, byte[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(byte[] shapeA, byte[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(long[] shapeA, long[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(String string, int[] shapeA, long[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(string, ArrayUtil.toLongArray(shapeA), shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(int[] shapeA, long[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(ArrayUtil.toLongArray(shapeA), shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(String string, long[] shapeA, int[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(string, shapeA, ArrayUtil.toLongArray(shapeB));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(long[] shapeA, int[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(shapeA, ArrayUtil.toLongArray(shapeB));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(String string, int[] shapeA, int[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(int[] shapeA, int[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(String string, boolean[] shapeA, boolean[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(boolean[] shapeA, boolean[] shapeB) {
|
|
||||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public static void assertArrayEquals(float[] shapeA, float[] shapeB, float delta) {
|
|
||||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(double[] shapeA, double[] shapeB, double delta) {
|
|
||||||
org.junit.Assert.assertArrayEquals(shapeA, shapeB, delta);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(String string, float[] shapeA, float[] shapeB, float delta) {
|
|
||||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void assertArrayEquals(String string, double[] shapeA, double[] shapeB, double delta) {
|
|
||||||
org.junit.Assert.assertArrayEquals(string, shapeA, shapeB, delta);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Parameterized.Parameters(name = "{index}: backend({0})={1}")
|
@Parameterized.Parameters(name = "{index}: backend({0})={1}")
|
||||||
|
@ -187,6 +96,13 @@ public abstract class BaseNd4jTest {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@Before
|
||||||
|
public void beforeTest(){
|
||||||
|
super.beforeTest();
|
||||||
|
Nd4j.factory().setOrder(ordering());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the default backend (jblas)
|
* Get the default backend (jblas)
|
||||||
* The default backend can be overridden by also passing:
|
* The default backend can be overridden by also passing:
|
||||||
|
@ -207,106 +123,6 @@ public abstract class BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Before
|
|
||||||
public void before() throws Exception {
|
|
||||||
//
|
|
||||||
log.info("Running {}.{} on {}", getClass().getName(), testName.getMethodName(), backend.getClass().getSimpleName());
|
|
||||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
|
||||||
Nd4j nd4j = new Nd4j();
|
|
||||||
nd4j.initWithBackend(backend);
|
|
||||||
Nd4j.factory().setOrder(ordering());
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
|
||||||
Nd4j.getExecutioner().enableDebugMode(false);
|
|
||||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
|
||||||
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
|
||||||
startTime = System.currentTimeMillis();
|
|
||||||
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
|
|
||||||
}
|
|
||||||
|
|
||||||
@After
|
|
||||||
public void after() throws Exception {
|
|
||||||
long totalTime = System.currentTimeMillis() - startTime;
|
|
||||||
Nd4j.getMemoryManager().purgeCaches();
|
|
||||||
|
|
||||||
logTestCompletion(totalTime);
|
|
||||||
if (System.getProperties().getProperty("backends") != null
|
|
||||||
&& !System.getProperty("backends").contains(backend.getClass().getName()))
|
|
||||||
return;
|
|
||||||
Nd4j nd4j = new Nd4j();
|
|
||||||
nd4j.initWithBackend(backend);
|
|
||||||
Nd4j.factory().setOrder(ordering());
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
|
||||||
Nd4j.getExecutioner().enableDebugMode(false);
|
|
||||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
|
||||||
|
|
||||||
//Attempt to keep workspaces isolated between tests
|
|
||||||
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
|
|
||||||
val currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
|
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(null);
|
|
||||||
if(currWS != null){
|
|
||||||
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
|
|
||||||
// errors that are hard to track back to this
|
|
||||||
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
|
|
||||||
System.exit(1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void logTestCompletion( long totalTime){
|
|
||||||
StringBuilder sb = new StringBuilder();
|
|
||||||
long maxPhys = Pointer.maxPhysicalBytes();
|
|
||||||
long maxBytes = Pointer.maxBytes();
|
|
||||||
long currPhys = Pointer.physicalBytes();
|
|
||||||
long currBytes = Pointer.totalBytes();
|
|
||||||
|
|
||||||
long jvmTotal = Runtime.getRuntime().totalMemory();
|
|
||||||
long jvmMax = Runtime.getRuntime().maxMemory();
|
|
||||||
|
|
||||||
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
|
|
||||||
sb.append(getClass().getSimpleName()).append(".").append(testName.getMethodName())
|
|
||||||
.append(": ").append(totalTime).append(" ms")
|
|
||||||
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
|
|
||||||
.append(", jvmTotal=").append(jvmTotal)
|
|
||||||
.append(", jvmMax=").append(jvmMax)
|
|
||||||
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
|
|
||||||
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
|
|
||||||
|
|
||||||
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
|
|
||||||
if(ws != null && ws.size() > 0){
|
|
||||||
long currSize = 0;
|
|
||||||
for(MemoryWorkspace w : ws){
|
|
||||||
currSize += w.getCurrentSize();
|
|
||||||
}
|
|
||||||
if(currSize > 0){
|
|
||||||
sb.append(", threadWSSize=").append(currSize)
|
|
||||||
.append(" (").append(ws.size()).append(" WSs)");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
|
|
||||||
Object o = p.get("cuda.devicesInformation");
|
|
||||||
if(o instanceof List){
|
|
||||||
List<Map<String,Object>> l = (List<Map<String, Object>>) o;
|
|
||||||
if(l.size() > 0) {
|
|
||||||
|
|
||||||
sb.append(" [").append(l.size())
|
|
||||||
.append(" GPUs: ");
|
|
||||||
|
|
||||||
for (int i = 0; i < l.size(); i++) {
|
|
||||||
Map<String,Object> m = l.get(i);
|
|
||||||
if(i > 0)
|
|
||||||
sb.append(",");
|
|
||||||
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
|
|
||||||
.append(m.get("cuda.totalMemory")).append(" total)");
|
|
||||||
}
|
|
||||||
sb.append("]");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.info(sb.toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The ordering for this test
|
* The ordering for this test
|
||||||
* This test will only be invoked for
|
* This test will only be invoked for
|
||||||
|
@ -315,15 +131,10 @@ public abstract class BaseNd4jTest {
|
||||||
* @return the ordering for this test
|
* @return the ordering for this test
|
||||||
*/
|
*/
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'a';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public String getFailureMessage() {
|
public String getFailureMessage() {
|
||||||
return "Failed with backend " + backend.getClass().getName() + " and ordering " + ordering();
|
return "Failed with backend " + backend.getClass().getName() + " and ordering " + ordering();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,16 +65,6 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
||||||
super(backend);
|
super(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Before
|
|
||||||
public void before() throws Exception {
|
|
||||||
super.before();
|
|
||||||
}
|
|
||||||
|
|
||||||
@After
|
|
||||||
public void after() throws Exception {
|
|
||||||
super.after();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -133,13 +133,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long testTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000;
|
return 90000;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void before() throws Exception {
|
public void before() throws Exception {
|
||||||
super.before();
|
super.beforeTest();
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
Nd4j.getRandom().setSeed(123);
|
Nd4j.getRandom().setSeed(123);
|
||||||
Nd4j.getExecutioner().enableDebugMode(false);
|
Nd4j.getExecutioner().enableDebugMode(false);
|
||||||
|
@ -148,7 +148,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void after() throws Exception {
|
public void after() throws Exception {
|
||||||
super.after();
|
super.afterTest();
|
||||||
Nd4j.setDataType(initialType);
|
Nd4j.setDataType(initialType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5331,7 +5331,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNativeSort3() {
|
public void testNativeSort3() {
|
||||||
INDArray array = Nd4j.linspace(1, 1048576, 1048576, DataType.DOUBLE).reshape(1, -1);
|
int length = isIntegrationTests() ? 1048576 : 16484;
|
||||||
|
INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray exp = array.dup();
|
INDArray exp = array.dup();
|
||||||
Nd4j.shuffle(array, 0);
|
Nd4j.shuffle(array, 0);
|
||||||
|
|
||||||
|
@ -7196,19 +7197,19 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
for( int i=-3; i<3; i++ ){
|
for( int i=-3; i<3; i++ ){
|
||||||
INDArray out = Nd4j.stack(i, in, in2);
|
INDArray out = Nd4j.stack(i, in, in2);
|
||||||
int[] expShape;
|
long[] expShape;
|
||||||
switch (i){
|
switch (i){
|
||||||
case -3:
|
case -3:
|
||||||
case 0:
|
case 0:
|
||||||
expShape = new int[]{2,3,4};
|
expShape = new long[]{2,3,4};
|
||||||
break;
|
break;
|
||||||
case -2:
|
case -2:
|
||||||
case 1:
|
case 1:
|
||||||
expShape = new int[]{3,2,4};
|
expShape = new long[]{3,2,4};
|
||||||
break;
|
break;
|
||||||
case -1:
|
case -1:
|
||||||
case 2:
|
case 2:
|
||||||
expShape = new int[]{3,4,2};
|
expShape = new long[]{3,4,2};
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException(String.valueOf(i));
|
throw new RuntimeException(String.valueOf(i));
|
||||||
|
@ -7602,6 +7603,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
String wsName = "testRollingMeanWs";
|
String wsName = "testRollingMeanWs";
|
||||||
try {
|
try {
|
||||||
System.gc();
|
System.gc();
|
||||||
|
int iterations1 = isIntegrationTests() ? 5 : 2;
|
||||||
for (int e = 0; e < 5; e++) {
|
for (int e = 0; e < 5; e++) {
|
||||||
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) {
|
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) {
|
||||||
val array = Nd4j.create(DataType.FLOAT, 32, 128, 256, 256);
|
val array = Nd4j.create(DataType.FLOAT, 32, 128, 256, 256);
|
||||||
|
@ -7609,7 +7611,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int iterations = 20;
|
int iterations = isIntegrationTests() ? 20 : 3;
|
||||||
val timeStart = System.nanoTime();
|
val timeStart = System.nanoTime();
|
||||||
for (int e = 0; e < iterations; e++) {
|
for (int e = 0; e < iterations; e++) {
|
||||||
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) {
|
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, wsName)) {
|
||||||
|
|
|
@ -57,13 +57,13 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void before() throws Exception {
|
public void before() throws Exception {
|
||||||
super.before();
|
super.beforeTest();
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void after() throws Exception {
|
public void after() throws Exception {
|
||||||
super.after();
|
super.afterTest();
|
||||||
DataTypeUtil.setDTypeForContext(initialType);
|
DataTypeUtil.setDTypeForContext(initialType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,8 +37,7 @@ import org.slf4j.LoggerFactory;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests comparing Nd4j ops to other libraries
|
* Tests comparing Nd4j ops to other libraries
|
||||||
|
@ -59,7 +58,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void before() throws Exception {
|
public void before() throws Exception {
|
||||||
super.before();
|
super.beforeTest();
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||||
Nd4j.getRandom().setSeed(SEED);
|
Nd4j.getRandom().setSeed(SEED);
|
||||||
|
|
||||||
|
@ -67,7 +66,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void after() throws Exception {
|
public void after() throws Exception {
|
||||||
super.after();
|
super.afterTest();
|
||||||
DataTypeUtil.setDTypeForContext(initialType);
|
DataTypeUtil.setDTypeForContext(initialType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -197,7 +196,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
||||||
INDArray gemv = m.mmul(v);
|
INDArray gemv = m.mmul(v);
|
||||||
RealMatrix gemv2 = rm.multiply(rv);
|
RealMatrix gemv2 = rm.multiply(rv);
|
||||||
|
|
||||||
assertArrayEquals(new int[] {rows, 1}, gemv.shape());
|
assertArrayEquals(new long[] {rows, 1}, gemv.shape());
|
||||||
assertArrayEquals(new int[] {rows, 1},
|
assertArrayEquals(new int[] {rows, 1},
|
||||||
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});
|
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 30/04/2016.
|
* Created by Alex on 30/04/2016.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -38,8 +38,7 @@ import org.nd4j.linalg.util.SerializationUtils;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Double data buffer tests
|
* Double data buffer tests
|
||||||
|
|
|
@ -37,8 +37,7 @@ import org.nd4j.linalg.util.SerializationUtils;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Float data buffer tests
|
* Float data buffer tests
|
||||||
|
|
|
@ -31,8 +31,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for INT INDArrays and DataBuffers serialization
|
* Tests for INT INDArrays and DataBuffers serialization
|
||||||
|
|
|
@ -33,8 +33,7 @@ import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue