Merge pull request #8641 from KonduitAI/master

Update Eclipse master from working repo
Alex Black 2020-01-27 16:01:20 +11:00 committed by GitHub
commit 97c7dd2c94
1284 changed files with 35023 additions and 26492 deletions


@@ -84,6 +84,13 @@
<artifactId>jackson</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
@@ -32,7 +33,7 @@ import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusLis
import org.junit.Assert;
import org.junit.Test;
public class TestGeneticSearch {
public class TestGeneticSearch extends BaseDL4JTest {
public class TestSelectionOperator extends SelectionOperator {
@Override
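A pattern worth noting before the remaining test diffs: every test class in this commit now extends org.deeplearning4j.BaseDL4JTest, supplied by the deeplearning4j-common-tests test-scope dependency added in the pom above, and the slower suites additionally override getTimeoutMilliseconds() (the 45000L and 60000L overrides appear later in this diff). Below is a minimal sketch of how a base class can enforce such a per-test budget with a JUnit 4 rule; the class name, default budget, and rule wiring are assumptions, not the actual BaseDL4JTest implementation.

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestRule;
import org.junit.rules.Timeout;

// Hypothetical stand-in for BaseDL4JTest; the real class lives in
// deeplearning4j-common-tests and may do considerably more.
abstract class BaseTestSketch {
    // Default per-test budget (assumed value); slow suites override this.
    public long getTimeoutMilliseconds() {
        return 30_000L;
    }

    // JUnit 4 fails any @Test method that exceeds the budget.
    @Rule
    public TestRule timeout = Timeout.millis(getTimeoutMilliseconds());
}

public class SlowSuiteExample extends BaseTestSketch {
    @Override
    public long getTimeoutMilliseconds() {
        return 45_000L; // same override style as TestGraphLocalExecutionGenetic below
    }

    @Test
    public void finishesWithinBudget() {
        // test body goes here
    }
}

Because the @Rule field is initialized in the superclass, an overridden getTimeoutMilliseconds() must not depend on subclass state; returning a constant, as the overrides in this commit do, is safe.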


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
@@ -26,7 +27,7 @@ import java.util.Map;
import static org.junit.Assert.*;
public class TestGridSearch {
public class TestGridSearch extends BaseDL4JTest {
@Test
public void testIndexing() {


@@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize;
import org.apache.commons.math3.distribution.LogNormalDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
@@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
/**
* Created by Alex on 02/02/2017.
*/
public class TestJson {
public class TestJson extends BaseDL4JTest {
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
ObjectMapper om = new ObjectMapper(factory);


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
@@ -34,7 +35,7 @@ import java.util.Map;
* Test random search on the Branin Function:
* http://www.sfu.ca/~ssurjano/branin.html
*/
public class TestRandomSearch {
public class TestRandomSearch extends BaseDL4JTest {
@Test
public void test() throws Exception {


@@ -17,12 +17,13 @@
package org.deeplearning4j.arbiter.optimize.distribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestLogUniform {
public class TestLogUniform extends BaseDL4JTest {
@Test
public void testSimple(){


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
@@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
public class ArithmeticCrossoverTests {
public class ArithmeticCrossoverTests extends BaseDL4JTest {
@Test
public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() {


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator;
@@ -23,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
public class CrossoverOperatorTests {
public class CrossoverOperatorTests extends BaseDL4JTest {
@Test
public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException {


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
@@ -24,7 +25,7 @@ import org.junit.Test;
import java.util.Deque;
public class CrossoverPointsGeneratorTests {
public class CrossoverPointsGeneratorTests extends BaseDL4JTest {
@Test
public void CrossoverPointsGenerator_FixedNumberCrossovers() {


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
@@ -25,7 +26,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
public class KPointCrossoverTests {
public class KPointCrossoverTests extends BaseDL4JTest {
@Test
public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() {


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
import org.junit.Assert;
@@ -24,7 +25,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class ParentSelectionTests {
public class ParentSelectionTests extends BaseDL4JTest {
@Test
public void ParentSelection_InitializeInstance_ShouldInitPopulation() {


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
@@ -26,7 +27,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class RandomTwoParentSelectionTests {
public class RandomTwoParentSelectionTests extends BaseDL4JTest {
@Test
public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() {
RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null);


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
@@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
public class SinglePointCrossoverTests {
public class SinglePointCrossoverTests extends BaseDL4JTest {
@Test
public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0});


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
@@ -27,7 +28,7 @@ import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class TwoParentsCrossoverOperatorTests {
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator {


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover;
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
@@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
public class UniformCrossoverTests {
public class UniformCrossoverTests extends BaseDL4JTest {
@Test
public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() {


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.culling;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@@ -27,7 +28,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class LeastFitCullOperatorTests {
public class LeastFitCullOperatorTests extends BaseDL4JTest {
@Test
public void LeastFitCullingOperation_ShouldCullLastElements() {


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.culling;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@@ -27,7 +28,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import java.util.List;
public class RatioCullOperatorTests {
public class RatioCullOperatorTests extends BaseDL4JTest {
class TestRatioCullOperator extends RatioCullOperator {


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.mutation;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
@@ -24,7 +25,7 @@ import org.junit.Test;
import java.lang.reflect.Field;
import java.util.Arrays;
public class RandomMutationOperatorTests {
public class RandomMutationOperatorTests extends BaseDL4JTest {
@Test
public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() {
RandomMutationOperator sut = new RandomMutationOperator.Builder().build();


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.population;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@@ -27,7 +28,7 @@ import org.junit.Test;
import java.util.List;
public class PopulationModelTests {
public class PopulationModelTests extends BaseDL4JTest {
private class TestCullOperator implements CullOperator {


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
@@ -36,7 +37,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import static org.junit.Assert.assertArrayEquals;
public class GeneticSelectionOperatorTests {
public class GeneticSelectionOperatorTests extends BaseDL4JTest {
private class TestCullOperator implements CullOperator {


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
@@ -25,7 +26,7 @@ import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class SelectionOperatorTests {
public class SelectionOperatorTests extends BaseDL4JTest {
private class TestSelectionOperator extends SelectionOperator {
public PopulationModel getPopulationModel() {


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.optimize.parameter;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
@@ -25,7 +26,7 @@ import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class TestParameterSpaces {
public class TestParameterSpaces extends BaseDL4JTest {
@Test


@@ -63,6 +63,13 @@
<artifactId>gson</artifactId>
<version>${gson.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.computationgraph;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.TestUtils;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
@@ -44,7 +45,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestComputationGraphSpace {
public class TestComputationGraphSpace extends BaseDL4JTest {
@Test
public void testBasic() {


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.computationgraph;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
@@ -85,7 +86,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
public class TestGraphLocalExecution {
public class TestGraphLocalExecution extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@@ -126,7 +127,7 @@ public class TestGraphLocalExecution {
if(dataApproach == 0){
ds = TestDL4JLocalExecution.MnistDataSource.class;
dsP = new Properties();
dsP.setProperty("minibatch", "8");
dsP.setProperty("minibatch", "2");
candidateGenerator = new RandomSearchGenerator(mls);
} else if(dataApproach == 1) {
//DataProvider approach
@@ -150,8 +151,8 @@
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
new MaxCandidatesCondition(5))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator()));
@@ -159,7 +160,7 @@
runner.execute();
List<ResultReference> results = runner.getResults();
assertEquals(5, results.size());
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
}
@@ -203,8 +204,8 @@
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,
@@ -223,7 +224,7 @@
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1)))
.l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in")
.setInputTypes(InputType.feedForward(4))
.setInputTypes(InputType.feedForward(784))
.addLayer("layer0",
new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10))
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
@@ -250,8 +251,8 @@
.candidateGenerator(candidateGenerator)
.dataProvider(new TestMdsDataProvider(1, 32))
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.scoreFunction(ScoreFunctions.testSetAccuracy())
.build();
@@ -279,7 +280,7 @@
@Override
public Object trainData(Map<String, Object> dataParameters) {
try {
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 10 * batchSize), false, true, true, 12345);
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 3 * batchSize), false, true, true, 12345);
return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying));
} catch (IOException e) {
throw new RuntimeException(e);
@@ -289,7 +290,7 @@
@Override
public Object testData(Map<String, Object> dataParameters) {
try {
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 5 * batchSize), false, false, false, 12345);
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 2 * batchSize), false, false, false, 12345);
return new MultiDataSetIteratorAdapter(underlying);
} catch (IOException e) {
throw new RuntimeException(e);
@@ -305,7 +306,7 @@
@Test
public void testLocalExecutionEarlyStopping() throws Exception {
EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(4))
.epochTerminationConditions(new MaxEpochsTerminationCondition(2))
.scoreCalculator(new ScoreProvider())
.modelSaver(new InMemoryModelSaver()).build();
Map<String, Object> commands = new HashMap<>();
@@ -348,8 +349,8 @@
.dataProvider(dataProvider)
.scoreFunction(ScoreFunctions.testSetF1())
.modelSaver(new FileModelSaver(modelSavePath))
.terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.terminationConditions(new MaxTimeCondition(15, TimeUnit.SECONDS),
new MaxCandidatesCondition(3))
.build();
@@ -364,7 +365,7 @@
@Override
public ScoreCalculator get() {
try {
return new DataSetLossCalculatorCG(new MnistDataSetIterator(128, 1280), true);
return new DataSetLossCalculatorCG(new MnistDataSetIterator(4, 8), true);
} catch (Exception e){
throw new RuntimeException(e);
}


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.computationgraph;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
@@ -79,11 +80,16 @@
import static org.junit.Assert.assertTrue;
@Slf4j
public class TestGraphLocalExecutionGenetic {
public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 45000L;
}
@Test
public void testLocalExecutionDataSources() throws Exception {
for (int dataApproach = 0; dataApproach < 3; dataApproach++) {
@@ -115,7 +121,7 @@ public class TestGraphLocalExecutionGenetic {
if (dataApproach == 0) {
ds = TestDL4JLocalExecution.MnistDataSource.class;
dsP = new Properties();
dsP.setProperty("minibatch", "8");
dsP.setProperty("minibatch", "2");
candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction)
.populationModel(new PopulationModel.Builder().populationSize(5).build())
@@ -148,7 +154,7 @@
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(10))
.build();
@@ -157,7 +163,7 @@
runner.execute();
List<ResultReference> results = runner.getResults();
assertEquals(10, results.size());
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
}


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.json;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace;
@@ -71,7 +72,7 @@ import static org.junit.Assert.assertNotNull;
/**
* Created by Alex on 14/02/2017.
*/
public class TestJson {
public class TestJson extends BaseDL4JTest {
@Test
public void testMultiLayerSpaceJson() {


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace;
@@ -59,7 +60,7 @@ import java.util.concurrent.TimeUnit;
// import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener;
/** Not strictly a unit test. Rather: part example, part debugging on MNIST */
public class MNISTOptimizationTest {
public class MNISTOptimizationTest extends BaseDL4JTest {
public static void main(String[] args) throws Exception {
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator;
@@ -72,9 +73,10 @@ import java.util.Properties;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
public class TestDL4JLocalExecution {
public class TestDL4JLocalExecution extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@@ -112,7 +114,7 @@ public class TestDL4JLocalExecution {
if(dataApproach == 0){
ds = MnistDataSource.class;
dsP = new Properties();
dsP.setProperty("minibatch", "8");
dsP.setProperty("minibatch", "2");
candidateGenerator = new RandomSearchGenerator(mls);
} else if(dataApproach == 1) {
//DataProvider approach
@@ -136,7 +138,7 @@
.dataSource(ds, dsP)
.modelSaver(new FileModelSaver(modelSave))
.scoreFunction(new TestSetLossScoreFunction())
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
new MaxCandidatesCondition(5))
.build();
@@ -146,7 +148,7 @@
runner.execute();
List<ResultReference> results = runner.getResults();
assertEquals(5, results.size());
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
}


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
@@ -39,7 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
public class TestErrors {
public class TestErrors extends BaseDL4JTest {
@Rule
public TemporaryFolder temp = new TemporaryFolder();
@@ -60,7 +61,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
@@ -87,7 +88,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))
@@ -116,7 +117,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxCandidatesCondition(5))
.build();
@@ -143,7 +144,7 @@ public class TestErrors {
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(
new MaxCandidatesCondition(5))


@@ -17,6 +17,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.TestUtils;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.*;
@@ -44,7 +45,7 @@ import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestLayerSpace {
public class TestLayerSpace extends BaseDL4JTest {
@Test
public void testBasic1() {


@@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.DL4JConfiguration;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.TestUtils;
@@ -86,7 +87,7 @@ import java.util.*;
import static org.junit.Assert.*;
public class TestMultiLayerSpace {
public class TestMultiLayerSpace extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();


@@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.multilayernetwork;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
@@ -60,7 +61,13 @@ import java.util.Map;
import static org.junit.Assert.assertEquals;
@Slf4j
public class TestScoreFunctions {
public class TestScoreFunctions extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 60000L;
}
@Test
public void testROCScoreFunctions() throws Exception {
@@ -107,7 +114,7 @@ public class TestScoreFunctions {
List<ResultReference> list = runner.getResults();
for (ResultReference rr : list) {
DataSetIterator testIter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
DataSetIterator testIter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
testIter.setPreProcessor(new PreProc(rocType));
OptimizationResult or = rr.getResult();
@@ -141,10 +148,10 @@
}
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
iter.setPreProcessor(new PreProc(rocType));
assertEquals(msg, expScore, or.getScore(), 1e-5);
assertEquals(msg, expScore, or.getScore(), 1e-4);
}
}
}
@@ -158,7 +165,7 @@ public class TestScoreFunctions {
@Override
public Object trainData(Map<String, Object> dataParameters) {
try {
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
iter.setPreProcessor(new PreProc(rocType));
return iter;
} catch (IOException e){
@@ -169,7 +176,7 @@
@Override
public Object testData(Map<String, Object> dataParameters) {
try {
DataSetIterator iter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
iter.setPreProcessor(new PreProc(rocType));
return iter;
} catch (IOException e){


@@ -49,6 +49,13 @@
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>


@@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.server;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
@@ -52,7 +53,7 @@ import static org.junit.Assert.assertEquals;
* Created by agibsonccc on 3/12/17.
*/
@Slf4j
public class ArbiterCLIRunnerTest {
public class ArbiterCLIRunnerTest extends BaseDL4JTest {
@Test


@@ -311,7 +311,11 @@ public class CSVRecordReaderTest {
rr.reset();
fail("Expected exception");
} catch (Exception e){
e.printStackTrace();
String msg = e.getMessage();
String msg2 = e.getCause().getMessage();
assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
assertTrue(msg2, msg2.contains("Reset not supported from streams"));
// e.printStackTrace();
}
}
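The rewritten catch block above asserts on the exception text and on its cause instead of just printing the stack trace, so a behavioural regression now fails the build. A self-contained sketch of the same assertion style follows; the resetFromStream() helper is hypothetical, standing in for rr.reset() on a stream-backed reader:

import org.junit.Test;

import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

public class ExceptionMessageAssertionSketch {

    // Hypothetical stand-in for rr.reset() on a non-resettable input stream.
    private void resetFromStream() {
        throw new IllegalStateException("Error during LineRecordReader reset",
                new UnsupportedOperationException("Reset not supported from streams"));
    }

    @Test
    public void failsWithDescriptiveMessages() {
        try {
            resetFromStream();
            fail("Expected exception");
        } catch (Exception e) {
            String msg = e.getMessage();
            String msg2 = e.getCause().getMessage();
            assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
            assertTrue(msg2, msg2.contains("Reset not supported from streams"));
        }
    }
}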


@@ -55,8 +55,7 @@ public class LineReaderTest {
@Test
public void testLineReader() throws Exception {
String tempDir = System.getProperty("java.io.tmpdir");
File tmpdir = new File(tempDir, "tmpdir-testLineReader");
File tmpdir = testDir.newFolder();
if (tmpdir.exists())
tmpdir.delete();
tmpdir.mkdir();
@@ -84,12 +83,6 @@
}
assertEquals(9, count);
try {
FileUtils.deleteDirectory(tmpdir);
} catch (Exception e) {
e.printStackTrace();
}
}
@Test
@@ -145,13 +138,6 @@
assertEquals(2, subset.size());
assertEquals(out3.get(4), subset.get(0));
assertEquals(out3.get(7), subset.get(1));
try {
FileUtils.deleteDirectory(tmpdir);
} catch (Exception e) {
e.printStackTrace();
}
}
@Test
@@ -177,11 +163,5 @@
}
assertEquals(9, count);
try {
FileUtils.deleteDirectory(tmpdir);
} catch (Exception e) {
e.printStackTrace();
}
}
}
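The testLineReader changes above swap a hand-managed directory under java.io.tmpdir for JUnit 4's TemporaryFolder rule (testDir.newFolder()), which is also why the three FileUtils.deleteDirectory(tmpdir) cleanup blocks disappear: the rule deletes everything it created after each test. A small sketch of the pattern, with illustrative class and file names:

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.File;
import java.nio.file.Files;

public class TemporaryFolderSketch {

    // JUnit creates a fresh root folder before each test and recursively
    // deletes it afterwards; no manual cleanup code is required.
    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    public void writesIntoManagedFolder() throws Exception {
        File tmpdir = testDir.newFolder();          // as in the updated test
        File data = new File(tmpdir, "lines.txt");  // illustrative fixture
        Files.write(data.toPath(), "line1\nline2\n".getBytes());
        // ... exercise a line reader against 'data' ...
    }
}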


@@ -1,116 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<parent>
<artifactId>datavec-parent</artifactId>
<groupId>org.datavec</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>datavec-camel</artifactId>
<name>DataVec Camel Component</name>
<url>http://deeplearning4j.org</url>
<dependencies>
<dependency>
<groupId>org.apache.camel</groupId>
<artifactId>camel-csv</artifactId>
<version>${camel.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.camel</groupId>
<artifactId>camel-core</artifactId>
<version>${camel.version}</version>
</dependency>
<!-- support camel documentation -->
<!-- logging -->
<!-- testing -->
<dependency>
<groupId>org.apache.camel</groupId>
<artifactId>camel-test</artifactId>
<version>${camel.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<defaultGoal>install</defaultGoal>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.7</source>
<target>1.7</target>
</configuration>
</plugin>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.0.1</version>
<configuration>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<!-- generate components meta-data and validate component includes documentation etc -->
<plugin>
<groupId>org.apache.camel</groupId>
<artifactId>camel-package-maven-plugin</artifactId>
<version>${camel.version}</version>
<executions>
<execution>
<id>prepare</id>
<goals>
<goal>prepare-components</goal>
</goals>
<phase>generate-resources</phase>
</execution>
<!-- <execution>
<id>validate</id>
<goals>
<goal>validate-components</goal>
</goals>
<phase>prepare-package</phase>
</execution>-->
</executions>
</plugin>
</plugins>
</build>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>


@@ -1,45 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.CamelContext;
import org.apache.camel.Endpoint;
import org.apache.camel.impl.UriEndpointComponent;
import java.util.Map;
/**
* Represents the component that manages {@link DataVecEndpoint}.
*/
public class DataVecComponent extends UriEndpointComponent {
public DataVecComponent() {
super(DataVecEndpoint.class);
}
public DataVecComponent(CamelContext context) {
super(context, DataVecEndpoint.class);
}
@Override
protected Endpoint createEndpoint(String uri, String remaining, Map<String, Object> parameters) throws Exception {
DataVecEndpoint endpoint = new DataVecEndpoint(uri, this);
setProperties(endpoint, parameters);
endpoint.setInputFormat(remaining);
return endpoint;
}
}


@@ -1,93 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.Exchange;
import org.apache.camel.Processor;
import org.apache.camel.impl.ScheduledPollConsumer;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.InputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
/**
* The DataVec consumer.
* @author Adam Gibson
*/
public class DataVecConsumer extends ScheduledPollConsumer {
private final DataVecEndpoint endpoint;
private Class<? extends InputFormat> inputFormatClazz;
private Class<? extends DataVecMarshaller> marshallerClazz;
private InputFormat inputFormat;
private Configuration configuration;
private DataVecMarshaller marshaller;
public DataVecConsumer(DataVecEndpoint endpoint, Processor processor) {
super(endpoint, processor);
this.endpoint = endpoint;
try {
inputFormatClazz = (Class<? extends InputFormat>) Class.forName(endpoint.getInputFormat());
inputFormat = inputFormatClazz.newInstance();
marshallerClazz = (Class<? extends DataVecMarshaller>) Class.forName(endpoint.getInputMarshaller());
marshaller = marshallerClazz.newInstance();
configuration = new Configuration();
for (String prop : endpoint.getConsumerProperties().keySet())
configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
//stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split
protected InputSplit inputFromExchange(Exchange exchange) {
return marshaller.getSplit(exchange);
}
@Override
protected int poll() throws Exception {
Exchange exchange = endpoint.createExchange();
InputSplit split = inputFromExchange(exchange);
RecordReader reader = inputFormat.createReader(split, configuration);
int numMessagesPolled = 0;
while (reader.hasNext()) {
// create a message body
while (reader.hasNext()) {
exchange.getIn().setBody(reader.next());
try {
// send message to next processor in the route
getProcessor().process(exchange);
numMessagesPolled++; // number of messages polled
} finally {
// log exception if an exception occurred and was not handled
if (exchange.getException() != null) {
getExceptionHandler().handleException("Error processing exchange", exchange,
exchange.getException());
}
}
}
}
return numMessagesPolled;
}
}


@@ -1,68 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import lombok.Data;
import org.apache.camel.Consumer;
import org.apache.camel.Processor;
import org.apache.camel.Producer;
import org.apache.camel.impl.DefaultEndpoint;
import org.apache.camel.spi.Metadata;
import org.apache.camel.spi.UriEndpoint;
import org.apache.camel.spi.UriParam;
import org.apache.camel.spi.UriPath;
/**
* Represents a DataVec endpoint.
* @author Adam Gibson
*/
@UriEndpoint(scheme = "datavec", title = "datavec", syntax = "datavec:inputFormat/?outputFormat=?&inputMarshaller=?",
consumerClass = DataVecConsumer.class, label = "datavec")
@Data
public class DataVecEndpoint extends DefaultEndpoint {
@UriPath
@Metadata(required = "true")
private String inputFormat;
@UriParam(defaultValue = "")
private String outputFormat;
@UriParam
@Metadata(required = "true")
private String inputMarshaller;
@UriParam(defaultValue = "org.datavec.api.io.converters.SelfWritableConverter")
private String writableConverter;
public DataVecEndpoint(String uri, DataVecComponent component) {
super(uri, component);
}
public DataVecEndpoint(String endpointUri) {
super(endpointUri);
}
public Producer createProducer() throws Exception {
return new DataVecProducer(this);
}
public Consumer createConsumer(Processor processor) throws Exception {
return new DataVecConsumer(this, processor);
}
public boolean isSingleton() {
return true;
}
}


@@ -1,109 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.Exchange;
import org.apache.camel.impl.DefaultProducer;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.InputFormat;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.SelfWritableConverter;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
import java.util.ArrayList;
import java.util.Collection;
/**
* The DataVec producer.
* Converts input records in to their final form
* based on the input split generated from
* the given exchange.
*
* @author Adam Gibson
*/
public class DataVecProducer extends DefaultProducer {
private Class<? extends InputFormat> inputFormatClazz;
private Class<? extends DataVecMarshaller> marshallerClazz;
private InputFormat inputFormat;
private Configuration configuration;
private WritableConverter writableConverter;
private DataVecMarshaller marshaller;
public DataVecProducer(DataVecEndpoint endpoint) {
super(endpoint);
if (endpoint.getInputFormat() != null) {
try {
inputFormatClazz = (Class<? extends InputFormat>) Class.forName(endpoint.getInputFormat());
inputFormat = inputFormatClazz.newInstance();
marshallerClazz = (Class<? extends DataVecMarshaller>) Class.forName(endpoint.getInputMarshaller());
Class<? extends WritableConverter> converterClazz =
(Class<? extends WritableConverter>) Class.forName(endpoint.getWritableConverter());
writableConverter = converterClazz.newInstance();
marshaller = marshallerClazz.newInstance();
configuration = new Configuration();
for (String prop : endpoint.getConsumerProperties().keySet())
configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
//stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split
protected InputSplit inputFromExchange(Exchange exchange) {
return marshaller.getSplit(exchange);
}
@Override
public void process(Exchange exchange) throws Exception {
InputSplit split = inputFromExchange(exchange);
RecordReader reader = inputFormat.createReader(split, configuration);
Collection<Collection<Writable>> newRecord = new ArrayList<>();
if (!(writableConverter instanceof SelfWritableConverter)) {
newRecord = new ArrayList<>();
while (reader.hasNext()) {
Collection<Writable> newRecordAdd = new ArrayList<>();
// create a message body
Collection<Writable> next = reader.next();
for (Writable writable : next) {
newRecordAdd.add(writableConverter.convert(writable));
}
newRecord.add(newRecordAdd);
}
} else {
while (reader.hasNext()) {
// create a message body
Collection<Writable> next = reader.next();
newRecord.add(next);
}
}
exchange.getIn().setBody(newRecord);
exchange.getOut().setBody(newRecord);
}
}


@@ -1,42 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component.csv.marshaller;
import org.apache.camel.Exchange;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.ListStringSplit;
import org.datavec.camel.component.DataVecMarshaller;
import java.util.List;
/**
* Marshals List<List<String>>
*
* @author Adam Gibson
*/
public class ListStringInputMarshaller implements DataVecMarshaller {
/**
* @param exchange
* @return
*/
@Override
public InputSplit getSplit(Exchange exchange) {
List<List<String>> data = (List<List<String>>) exchange.getIn().getBody();
InputSplit listSplit = new ListStringSplit(data);
return listSplit;
}
}


@@ -1 +0,0 @@
class=org.datavec.camel.component.DataVecComponent


@@ -1,82 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.builder.RouteBuilder;
import org.apache.camel.component.mock.MockEndpoint;
import org.apache.camel.test.junit4.CamelTestSupport;
import org.apache.commons.io.FileUtils;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
public class DataVecComponentTest extends CamelTestSupport {
@ClassRule
public static TemporaryFolder testDir = new TemporaryFolder();
private static File dir;
private static File irisFile;
@BeforeClass
public static void before() throws Exception {
dir = testDir.newFolder();
File iris = new ClassPathResource("iris.dat").getFile();
irisFile = new File(dir, "iris.dat");
FileUtils.copyFile(iris, irisFile );
}
@Test
public void testDataVec() throws Exception {
MockEndpoint mock = getMockEndpoint("mock:result");
//1
mock.expectedMessageCount(1);
RecordReader reader = new CSVRecordReader();
reader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));
Collection<Collection<Writable>> recordAssertion = new ArrayList<>();
while (reader.hasNext())
recordAssertion.add(reader.next());
mock.expectedBodiesReceived(recordAssertion);
assertMockEndpointsSatisfied();
}
@Override
protected RouteBuilder createRouteBuilder() throws Exception {
return new RouteBuilder() {
public void configure() {
from("file:" + dir.getAbsolutePath() + "?fileName=iris.dat&noop=true").unmarshal().csv()
.to("datavec://org.datavec.api.formats.input.impl.ListStringInputFormat?inputMarshaller=org.datavec.camel.component.ListStringInputMarshaller&writableConverter=org.datavec.api.io.converters.SelfWritableConverter")
.to("mock:result");
}
};
}
}


@@ -37,11 +37,6 @@
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-buffer</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>com.github.jai-imageio</groupId>
<artifactId>jai-imageio-core</artifactId>


@@ -570,6 +570,7 @@ public class TestNativeImageLoader {
try(InputStream is = new FileInputStream(f)){
nil.asMatrix(is);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
@@ -577,6 +578,7 @@
try(InputStream is = new FileInputStream(f)){
nil.asImageMatrix(is);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
@@ -584,6 +586,7 @@
try(InputStream is = new FileInputStream(f)){
nil.asRowVector(is);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
@@ -592,6 +595,7 @@
try(InputStream is = new FileInputStream(f)){
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
nil.asMatrixView(is, arr);
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));


@@ -66,9 +66,9 @@ public class JsonYamlTest {
String asJson = itp.toJson();
String asYaml = itp.toYaml();
System.out.println(asJson);
System.out.println("\n\n\n");
System.out.println(asYaml);
// System.out.println(asJson);
// System.out.println("\n\n\n");
// System.out.println(asYaml);
ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
ImageWritable imgJson = new ImageWritable(img.getFrame().clone());


@@ -1,65 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>datavec-parent</artifactId>
<groupId>org.datavec</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>datavec-perf</artifactId>
<name>datavec-perf</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>


@@ -1,112 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.perf.timing;
import lombok.val;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.memory.MemcpyDirection;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
/**
* Times the components of a DataVec pipeline
* consisting of a {@link RecordReader} and an {@link InputStreamInputSplit}.
* (Note that because this uses an input stream input split,
* the record reader must support {@link InputStreamInputSplit} for this to work.)
*
* @author Adam Gibson
*/
public class IOTiming {
/**
* Returns statistics for the components of a datavec pipeline,
* averaged over the specified number of runs
* @param nTimes the number of times to run the pipeline for averaging
* @param recordReader the record reader
* @param file the file to read
* @param function the function used to create an {@link INDArray} from each record
* @return the averaged {@link TimingStatistics} for input/output on a record
* reader and ndarray creation (based on the given function)
* @throws Exception if initializing the record reader or reading the file fails
*/
public static TimingStatistics averageFileRead(long nTimes, RecordReader recordReader, File file, INDArrayCreationFunction function) throws Exception {
TimingStatistics timingStatistics = null;
for(int i = 0; i < nTimes; i++) {
TimingStatistics timingStatistics1 = timeNDArrayCreation(recordReader,new BufferedInputStream(new FileInputStream(file)),function);
if(timingStatistics == null)
timingStatistics = timingStatistics1;
else {
timingStatistics = timingStatistics.add(timingStatistics1);
}
}
return timingStatistics.average(nTimes);
}
/**
* Times a single record read from the given input stream, plus the
* creation of an ndarray from the resulting record
* @param reader the record reader to initialize with the input stream
* @param inputStream the input stream to read a record from
* @param function the function used to create an {@link INDArray} from the record
* @return the {@link TimingStatistics} for this single read
* @throws Exception if initializing the reader or reading the record fails
*/
public static TimingStatistics timeNDArrayCreation(RecordReader reader,
InputStream inputStream,
INDArrayCreationFunction function) throws Exception {
reader.initialize(new InputStreamInputSplit(inputStream));
long longNanos = System.nanoTime();
List<Writable> next = reader.next();
long endNanos = System.nanoTime();
long etlDiff = endNanos - longNanos;
long startArrCreation = System.nanoTime();
INDArray arr = function.createFromRecord(next);
long endArrCreation = System.nanoTime();
long endCreationDiff = endArrCreation - startArrCreation;
Map<Integer, Map<MemcpyDirection, Long>> currentBandwidth = PerformanceTracker.getInstance().getCurrentBandwidth();
val bw = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE);
val deviceToHost = currentBandwidth.get(0).get(MemcpyDirection.DEVICE_TO_HOST);
return TimingStatistics.builder()
.diskReadingTimeNanos(etlDiff)
.bandwidthNanosHostToDevice(bw)
.bandwidthDeviceToHost(deviceToHost)
.ndarrayCreationTimeNanos(endCreationDiff)
.build();
}
public interface INDArrayCreationFunction {
INDArray createFromRecord(List<Writable> record);
}
}
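A minimal usage sketch of the API above (the 28x28 reader, file name, and class name are placeholders; IOTimingTest further down in this diff exercises the same calls):
import java.io.File;
import java.util.List;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.perf.timing.IOTiming;
import org.datavec.perf.timing.TimingStatistics;
import org.nd4j.linalg.api.ndarray.INDArray;
public class IOTimingExample {
public static void main(String[] args) throws Exception {
RecordReader reader = new ImageRecordReader(28, 28);
TimingStatistics stats = IOTiming.averageFileRead(100, reader, new File("image.jpg"),
new IOTiming.INDArrayCreationFunction() {
@Override
public INDArray createFromRecord(List<Writable> record) {
// extract the image ndarray from the first writable, as IOTimingTest does
return ((NDArrayWritable) record.get(0)).get();
}
});
System.out.println(stats); // averaged disk-read, ndarray-creation and bandwidth timings
}
}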

View File

@ -1,74 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.perf.timing;
import lombok.Builder;
import lombok.Data;
/**
* The timing statistics for a data pipeline, including:
* ndarray creation time in nanoseconds,
* disk reading time in nanoseconds,
* host-to-device bandwidth in nanoseconds, and
* device-to-host bandwidth in nanoseconds
*
* @author Adam Gibson
*/
@Builder
@Data
public class TimingStatistics {
private long ndarrayCreationTimeNanos;
private long diskReadingTimeNanos;
private long bandwidthNanosHostToDevice;
private long bandwidthDeviceToHost;
/**
* Accumulate the given statistics
* @param timingStatistics the statistics to add
* @return a new {@link TimingStatistics} instance containing the summed statistics
*/
public TimingStatistics add(TimingStatistics timingStatistics) {
return TimingStatistics.builder()
.ndarrayCreationTimeNanos(ndarrayCreationTimeNanos + timingStatistics.ndarrayCreationTimeNanos)
.bandwidthNanosHostToDevice(bandwidthNanosHostToDevice + timingStatistics.bandwidthNanosHostToDevice)
.diskReadingTimeNanos(diskReadingTimeNanos + timingStatistics.diskReadingTimeNanos)
.bandwidthDeviceToHost(bandwidthDeviceToHost + timingStatistics.bandwidthDeviceToHost)
.build();
}
/**
* Average the accumulated results over n runs.
* This method is meant to be used alongside
* {@link #add(TimingStatistics)} after statistics
* have been accumulated a number of times
* @param n the number of runs to average over
* @return the averaged results
*/
public TimingStatistics average(long n) {
return TimingStatistics.builder()
.ndarrayCreationTimeNanos(ndarrayCreationTimeNanos / n)
.bandwidthDeviceToHost(bandwidthDeviceToHost / n)
.diskReadingTimeNanos(diskReadingTimeNanos / n)
.bandwidthNanosHostToDevice(bandwidthNanosHostToDevice / n)
.build();
}
}
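A minimal sketch of the accumulate-then-average pattern these two methods describe (the per-run values are stand-ins; in practice they come from IOTiming):
import org.datavec.perf.timing.TimingStatistics;
public class TimingAccumulationExample {
public static void main(String[] args) {
TimingStatistics total = null;
for (int i = 0; i < 10; i++) {
// stand-in timings for one pipeline run
TimingStatistics run = TimingStatistics.builder()
.diskReadingTimeNanos(1_000_000L)
.ndarrayCreationTimeNanos(500_000L)
.build();
total = (total == null) ? run : total.add(run);
}
TimingStatistics averaged = total.average(10); // each accumulated field divided by 10
System.out.println(averaged);
}
}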

View File

@ -1,60 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.datavec.timing;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.perf.timing.IOTiming;
import org.datavec.perf.timing.TimingStatistics;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.io.ClassPathResource;
import java.util.List;
public class IOTimingTest {
@Test
public void testTiming() throws Exception {
final RecordReader image = new ImageRecordReader(28,28);
final NativeImageLoader nativeImageLoader = new NativeImageLoader(28,28);
TimingStatistics timingStatistics = IOTiming.timeNDArrayCreation(image, new ClassPathResource("datavec-perf/largestblobtest.jpg").getInputStream(), new IOTiming.INDArrayCreationFunction() {
@Override
public INDArray createFromRecord(List<Writable> record) {
NDArrayWritable imageWritable = (NDArrayWritable) record.get(0);
return imageWritable.get();
}
});
System.out.println(timingStatistics);
TimingStatistics timingStatistics1 = IOTiming.averageFileRead(1000,image,new ClassPathResource("datavec-perf/largestblobtest.jpg").getFile(), new IOTiming.INDArrayCreationFunction() {
@Override
public INDArray createFromRecord(List<Writable> record) {
NDArrayWritable imageWritable = (NDArrayWritable) record.get(0);
return imageWritable.get();
}
});
System.out.println(timingStatistics1);
}
}

View File

@ -60,7 +60,7 @@ public class CSVSparkTransformTest {
Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values));
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
assertTrue(fromBase64.isVector());
System.out.println("Base 64ed array " + fromBase64);
// System.out.println("Base 64ed array " + fromBase64);
}
@Test
@ -125,7 +125,7 @@ public class CSVSparkTransformTest {
SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord);
assertNotNull(transformed.getRecords());
System.out.println(transformed);
// System.out.println(transformed);
}
@ -153,7 +153,8 @@ public class CSVSparkTransformTest {
new SingleCSVRecord(data2)));
final CSVSparkTransform transform = new CSVSparkTransform(transformProcess);
System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
// System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
transform.transformSequenceIncremental(batchCsvRecord);
assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank());
}

View File

@ -54,7 +54,7 @@ public class ImageSparkTransformTest {
Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord);
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
System.out.println("Base 64ed array " + fromBase64);
// System.out.println("Base 64ed array " + fromBase64);
assertEquals(1, fromBase64.size(0));
}
@ -78,7 +78,7 @@ public class ImageSparkTransformTest {
Base64NDArrayBody body = imgSparkTransform.toArray(batch);
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
System.out.println("Base 64ed array " + fromBase64);
// System.out.println("Base 64ed array " + fromBase64);
assertEquals(3, fromBase64.size(0));
}
}

View File

@ -120,7 +120,7 @@ public class ImageSparkTransformServerTest {
INDArray batchResult = getNDArray(jsonNodeBatch);
assertEquals(3, batchResult.size(0));
System.out.println(array);
// System.out.println(array);
}
@Test
@ -136,7 +136,7 @@ public class ImageSparkTransformServerTest {
INDArray batchResult = getNDArray(jsonNode);
assertEquals(3, batchResult.size(0));
System.out.println(batchResult);
// System.out.println(batchResult);
}
@Test
@ -153,7 +153,7 @@ public class ImageSparkTransformServerTest {
INDArray result = getNDArray(jsonNode);
assertEquals(1, result.size(0));
System.out.println(result);
// System.out.println(result);
}
public INDArray getNDArray(JsonNode node) throws IOException {

View File

@ -72,7 +72,9 @@ public class TestAnalysis extends BaseSparkTest {
DataAnalysis da = AnalyzeSpark.analyze(schema, rdd);
String daString = da.toString();
System.out.println(da);
// System.out.println(da);
da.toJson();
da.toString();
List<ColumnAnalysis> ca = da.getColumnAnalysis();
assertEquals(5, ca.size());
@ -151,7 +153,7 @@ public class TestAnalysis extends BaseSparkTest {
assertEquals(1, countD[countD.length - 1]);
File f = Files.createTempFile("datavec_spark_analysis_UITest", ".html").toFile();
System.out.println(f.getAbsolutePath());
// System.out.println(f.getAbsolutePath());
f.deleteOnExit();
HtmlAnalysis.createHtmlAnalysisFile(da, f);
}
@ -210,7 +212,7 @@ public class TestAnalysis extends BaseSparkTest {
for( int i=1; i<10; i++ ){
counter.merge(counters.get(i));
sparkCounter.merge(sparkCounters.get(i));
System.out.println();
// System.out.println();
}
assertEquals(sc1.sampleStdev(), counter.getStddev(false), 1e-6);
assertEquals(sparkCounter.sampleStdev(), counter.getStddev(false), 1e-6);
@ -356,7 +358,9 @@ public class TestAnalysis extends BaseSparkTest {
JavaRDD<List<Writable>> rdd = sc.parallelize(data);
DataAnalysis da = AnalyzeSpark.analyze(s, rdd);
System.out.println(da);
// System.out.println(da);
da.toString();
da.toJson();
}
}

View File

@ -59,16 +59,12 @@
<modules>
<module>datavec-api</module>
<module>datavec-data</module>
<module>datavec-geo</module>
<module>datavec-hadoop</module>
<module>datavec-spark</module>
<module>datavec-camel</module>
<module>datavec-local</module>
<module>datavec-spark-inference-parent</module>
<module>datavec-jdbc</module>
<module>datavec-excel</module>
<module>datavec-arrow</module>
<module>datavec-perf</module>
<module>datavec-python</module>
</modules>

View File

@ -1,5 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~ Copyright (c) 2020 Konduit K.K.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
@ -13,8 +14,8 @@
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>deeplearning4j-parent</artifactId>
@ -23,36 +24,45 @@
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>deeplearning4j-util</artifactId>
<packaging>jar</packaging>
<name>deeplearning4j-util</name>
<url>http://maven.apache.org</url>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<artifactId>deeplearning4j-common-tests</artifactId>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common</artifactId>
<version>${nd4j.version}</version>
<artifactId>nd4j-api</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.2</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>
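For context: the test-nd4j-native and test-nd4j-cuda-10.2 profiles above select which ND4J backend is on the test classpath; assuming the standard Maven profile mechanism, a backend would be chosen at build time with e.g. mvn test -Ptest-nd4j-native.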

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -22,7 +23,9 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.junit.rules.Timeout;
import org.nd4j.base.Preconditions;
import org.nd4j.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
@ -30,22 +33,41 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import static org.junit.Assert.assertNull;
import static org.junit.Assume.assumeTrue;
@Slf4j
public class BaseDL4JTest {
public abstract class BaseDL4JTest {
@Rule
public TestName name = new TestName();
@Rule
public Timeout timeout = Timeout.millis(getTimeoutMilliseconds());
protected long startTime;
protected int threadCountBefore;
private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
/**
* Override this to specify the number of threads for C++ execution, via
* {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
* @return Number of threads to use for C++ op execution
*/
public int numThreads(){
return DEFAULT_THREADS;
}
/**
* Override this method to set the default timeout for methods in the test class
*/
public long getTimeoutMilliseconds(){
return 30000;
}
/**
* Override this to set the profiling mode for the tests defined in the child class
*/
@ -64,12 +86,45 @@ public class BaseDL4JTest {
return getDataType();
}
protected Boolean integrationTest;
/**
* @return true if the integration tests Maven profile is enabled, false otherwise.
*/
public boolean isIntegrationTests(){
if(integrationTest == null){
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
integrationTest = Boolean.parseBoolean(prop);
}
return integrationTest;
}
/**
* Call this as the first line of a test in order to dynamically skip that test when the integration tests Maven profile is not enabled.
* Note that the "integration-tests" profile is not enabled by default.
*/
public void skipUnlessIntegrationTests(){
assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
}
@Before
public void beforeTest(){
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
//Suppress ND4J initialization - don't need this logged for every test...
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
Nd4j.getExecutioner().enableDebugMode(false);
Nd4j.getExecutioner().enableVerboseMode(false);
int numThreads = numThreads();
Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
}
startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
}
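A minimal sketch of how a test class might use the hooks added above (the class and test names are hypothetical):
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
public class MyCustomLayerTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 120_000L; // raise the 30 second default for slow tests
}
@Override
public int numThreads() {
return 1; // force single-threaded C++ op execution
}
@Test
public void testExpensiveEndToEndRun() {
// runs only when the DL4J_INTEGRATION_TESTS environment variable is set to true
skipUnlessIntegrationTests();
// ... integration test body ...
}
}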

View File

@ -95,6 +95,12 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
@ -147,6 +153,17 @@
<version>${jaxb.version}</version>
<scope>provided</scope>
</dependency>
<!-- Dependencies for dl4j-perf subproject -->
<dependency>
<groupId>com.github.oshi</groupId>
<artifactId>oshi-json</artifactId>
<version>${oshi.version}</version>
</dependency>
<dependency>
<groupId>com.github.oshi</groupId>
<artifactId>oshi-core</artifactId>
<version>${oshi.version}</version>
</dependency>
</dependencies>
<profiles>

View File

@ -158,7 +158,7 @@ public class LayerHelperValidationUtil {
double d2 = arr2.dup('c').getDouble(idx);
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
}
assertTrue(s + layerName + "activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
}

View File

@ -75,7 +75,6 @@ public class TestUtils {
}
public static ComputationGraph testModelSerialization(ComputationGraph net){
ComputationGraph restored;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();

View File

@ -20,11 +20,9 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.*;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.dataset.DataSet;
@ -47,6 +45,8 @@ public class MnistFetcherTest extends BaseDL4JTest {
@ClassRule
public static TemporaryFolder testDir = new TemporaryFolder();
@Rule
public Timeout timeout = Timeout.seconds(300);
@BeforeClass
public static void setup() throws Exception {

Some files were not shown because too many files have changed in this diff.