commit
97c7dd2c94
|
@ -84,6 +84,13 @@
|
|||
<artifactId>jackson</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||
|
@ -32,7 +33,7 @@ import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusLis
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestGeneticSearch {
|
||||
public class TestGeneticSearch extends BaseDL4JTest {
|
||||
public class TestSelectionOperator extends SelectionOperator {
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
||||
|
@ -26,7 +27,7 @@ import java.util.Map;
|
|||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class TestGridSearch {
|
||||
public class TestGridSearch extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testIndexing() {
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.arbiter.optimize;
|
|||
import org.apache.commons.math3.distribution.LogNormalDistribution;
|
||||
import org.apache.commons.math3.distribution.NormalDistribution;
|
||||
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||
|
@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals;
|
|||
/**
|
||||
* Created by Alex on 02/02/2017.
|
||||
*/
|
||||
public class TestJson {
|
||||
public class TestJson extends BaseDL4JTest {
|
||||
|
||||
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
|
||||
ObjectMapper om = new ObjectMapper(factory);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
||||
|
@ -34,7 +35,7 @@ import java.util.Map;
|
|||
* Test random search on the Branin Function:
|
||||
* http://www.sfu.ca/~ssurjano/branin.html
|
||||
*/
|
||||
public class TestRandomSearch {
|
||||
public class TestRandomSearch extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void test() throws Exception {
|
||||
|
|
|
@ -17,12 +17,13 @@
|
|||
package org.deeplearning4j.arbiter.optimize.distribution;
|
||||
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class TestLogUniform {
|
||||
public class TestLogUniform extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testSimple(){
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||
|
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class ArithmeticCrossoverTests {
|
||||
public class ArithmeticCrossoverTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator;
|
||||
|
@ -23,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class CrossoverOperatorTests {
|
||||
public class CrossoverOperatorTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||
import org.junit.Assert;
|
||||
|
@ -24,7 +25,7 @@ import org.junit.Test;
|
|||
|
||||
import java.util.Deque;
|
||||
|
||||
public class CrossoverPointsGeneratorTests {
|
||||
public class CrossoverPointsGeneratorTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void CrossoverPointsGenerator_FixedNumberCrossovers() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
|
@ -25,7 +26,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class KPointCrossoverTests {
|
||||
public class KPointCrossoverTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||
import org.junit.Assert;
|
||||
|
@ -24,7 +25,7 @@ import org.junit.Test;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ParentSelectionTests {
|
||||
public class ParentSelectionTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void ParentSelection_InitializeInstance_ShouldInitPopulation() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||
|
@ -26,7 +27,7 @@ import org.junit.Test;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class RandomTwoParentSelectionTests {
|
||||
public class RandomTwoParentSelectionTests extends BaseDL4JTest {
|
||||
@Test
|
||||
public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() {
|
||||
RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||
|
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class SinglePointCrossoverTests {
|
||||
public class SinglePointCrossoverTests extends BaseDL4JTest {
|
||||
@Test
|
||||
public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
||||
RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0});
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
||||
|
@ -27,7 +28,7 @@ import org.junit.Assert;
|
|||
import org.junit.Test;
|
||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
||||
|
||||
public class TwoParentsCrossoverOperatorTests {
|
||||
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
|
||||
|
||||
class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator {
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||
|
@ -24,7 +25,7 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
|||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class UniformCrossoverTests {
|
||||
public class UniformCrossoverTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
|
@ -27,7 +28,7 @@ import org.junit.Test;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class LeastFitCullOperatorTests {
|
||||
public class LeastFitCullOperatorTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void LeastFitCullingOperation_ShouldCullLastElements() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
|
@ -27,7 +28,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
public class RatioCullOperatorTests {
|
||||
public class RatioCullOperatorTests extends BaseDL4JTest {
|
||||
|
||||
class TestRatioCullOperator extends RatioCullOperator {
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.mutation;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||
import org.junit.Assert;
|
||||
|
@ -24,7 +25,7 @@ import org.junit.Test;
|
|||
import java.lang.reflect.Field;
|
||||
import java.util.Arrays;
|
||||
|
||||
public class RandomMutationOperatorTests {
|
||||
public class RandomMutationOperatorTests extends BaseDL4JTest {
|
||||
@Test
|
||||
public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() {
|
||||
RandomMutationOperator sut = new RandomMutationOperator.Builder().build();
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.population;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
|
@ -27,7 +28,7 @@ import org.junit.Test;
|
|||
|
||||
import java.util.List;
|
||||
|
||||
public class PopulationModelTests {
|
||||
public class PopulationModelTests extends BaseDL4JTest {
|
||||
|
||||
private class TestCullOperator implements CullOperator {
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
||||
|
||||
import org.apache.commons.math3.random.RandomGenerator;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||
|
@ -36,7 +37,7 @@ import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
|||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
||||
public class GeneticSelectionOperatorTests {
|
||||
public class GeneticSelectionOperatorTests extends BaseDL4JTest {
|
||||
|
||||
private class TestCullOperator implements CullOperator {
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
||||
|
@ -25,7 +26,7 @@ import org.junit.Assert;
|
|||
import org.junit.Test;
|
||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
||||
|
||||
public class SelectionOperatorTests {
|
||||
public class SelectionOperatorTests extends BaseDL4JTest {
|
||||
private class TestSelectionOperator extends SelectionOperator {
|
||||
|
||||
public PopulationModel getPopulationModel() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.optimize.parameter;
|
||||
|
||||
import org.apache.commons.math3.distribution.NormalDistribution;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
||||
|
@ -25,7 +26,7 @@ import org.junit.Test;
|
|||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class TestParameterSpaces {
|
||||
public class TestParameterSpaces extends BaseDL4JTest {
|
||||
|
||||
|
||||
@Test
|
||||
|
|
|
@ -63,6 +63,13 @@
|
|||
<artifactId>gson</artifactId>
|
||||
<version>${gson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.computationgraph;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.TestUtils;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
|
@ -44,7 +45,7 @@ import java.util.Random;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class TestComputationGraphSpace {
|
||||
public class TestComputationGraphSpace extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testBasic() {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.computationgraph;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
|
@ -85,7 +86,7 @@ import static org.junit.Assert.assertEquals;
|
|||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
public class TestGraphLocalExecution {
|
||||
public class TestGraphLocalExecution extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
@ -126,7 +127,7 @@ public class TestGraphLocalExecution {
|
|||
if(dataApproach == 0){
|
||||
ds = TestDL4JLocalExecution.MnistDataSource.class;
|
||||
dsP = new Properties();
|
||||
dsP.setProperty("minibatch", "8");
|
||||
dsP.setProperty("minibatch", "2");
|
||||
candidateGenerator = new RandomSearchGenerator(mls);
|
||||
} else if(dataApproach == 1) {
|
||||
//DataProvider approach
|
||||
|
@ -150,8 +151,8 @@ public class TestGraphLocalExecution {
|
|||
.dataSource(ds, dsP)
|
||||
.modelSaver(new FileModelSaver(modelSave))
|
||||
.scoreFunction(new TestSetLossScoreFunction())
|
||||
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
|
||||
new MaxCandidatesCondition(5))
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.build();
|
||||
|
||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator()));
|
||||
|
@ -159,7 +160,7 @@ public class TestGraphLocalExecution {
|
|||
runner.execute();
|
||||
|
||||
List<ResultReference> results = runner.getResults();
|
||||
assertEquals(5, results.size());
|
||||
assertTrue(results.size() > 0);
|
||||
|
||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||
}
|
||||
|
@ -203,8 +204,8 @@ public class TestGraphLocalExecution {
|
|||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
||||
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
||||
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(10))
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.build();
|
||||
|
||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration,
|
||||
|
@ -223,7 +224,7 @@ public class TestGraphLocalExecution {
|
|||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1)))
|
||||
.l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in")
|
||||
.setInputTypes(InputType.feedForward(4))
|
||||
.setInputTypes(InputType.feedForward(784))
|
||||
.addLayer("layer0",
|
||||
new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10))
|
||||
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
||||
|
@ -250,8 +251,8 @@ public class TestGraphLocalExecution {
|
|||
.candidateGenerator(candidateGenerator)
|
||||
.dataProvider(new TestMdsDataProvider(1, 32))
|
||||
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true))
|
||||
.terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(10))
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.scoreFunction(ScoreFunctions.testSetAccuracy())
|
||||
.build();
|
||||
|
||||
|
@ -279,7 +280,7 @@ public class TestGraphLocalExecution {
|
|||
@Override
|
||||
public Object trainData(Map<String, Object> dataParameters) {
|
||||
try {
|
||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 10 * batchSize), false, true, true, 12345);
|
||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 3 * batchSize), false, true, true, 12345);
|
||||
return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
|
@ -289,7 +290,7 @@ public class TestGraphLocalExecution {
|
|||
@Override
|
||||
public Object testData(Map<String, Object> dataParameters) {
|
||||
try {
|
||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 5 * batchSize), false, false, false, 12345);
|
||||
DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 2 * batchSize), false, false, false, 12345);
|
||||
return new MultiDataSetIteratorAdapter(underlying);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
|
@ -305,7 +306,7 @@ public class TestGraphLocalExecution {
|
|||
@Test
|
||||
public void testLocalExecutionEarlyStopping() throws Exception {
|
||||
EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
|
||||
.epochTerminationConditions(new MaxEpochsTerminationCondition(4))
|
||||
.epochTerminationConditions(new MaxEpochsTerminationCondition(2))
|
||||
.scoreCalculator(new ScoreProvider())
|
||||
.modelSaver(new InMemoryModelSaver()).build();
|
||||
Map<String, Object> commands = new HashMap<>();
|
||||
|
@ -348,8 +349,8 @@ public class TestGraphLocalExecution {
|
|||
.dataProvider(dataProvider)
|
||||
.scoreFunction(ScoreFunctions.testSetF1())
|
||||
.modelSaver(new FileModelSaver(modelSavePath))
|
||||
.terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(10))
|
||||
.terminationConditions(new MaxTimeCondition(15, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(3))
|
||||
.build();
|
||||
|
||||
|
||||
|
@ -364,7 +365,7 @@ public class TestGraphLocalExecution {
|
|||
@Override
|
||||
public ScoreCalculator get() {
|
||||
try {
|
||||
return new DataSetLossCalculatorCG(new MnistDataSetIterator(128, 1280), true);
|
||||
return new DataSetLossCalculatorCG(new MnistDataSetIterator(4, 8), true);
|
||||
} catch (Exception e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.computationgraph;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
|
@ -79,11 +80,16 @@ import static org.junit.Assert.assertEquals;
|
|||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
public class TestGraphLocalExecutionGenetic {
|
||||
public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 45000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLocalExecutionDataSources() throws Exception {
|
||||
for (int dataApproach = 0; dataApproach < 3; dataApproach++) {
|
||||
|
@ -115,7 +121,7 @@ public class TestGraphLocalExecutionGenetic {
|
|||
if (dataApproach == 0) {
|
||||
ds = TestDL4JLocalExecution.MnistDataSource.class;
|
||||
dsP = new Properties();
|
||||
dsP.setProperty("minibatch", "8");
|
||||
dsP.setProperty("minibatch", "2");
|
||||
|
||||
candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction)
|
||||
.populationModel(new PopulationModel.Builder().populationSize(5).build())
|
||||
|
@ -148,7 +154,7 @@ public class TestGraphLocalExecutionGenetic {
|
|||
.dataSource(ds, dsP)
|
||||
.modelSaver(new FileModelSaver(modelSave))
|
||||
.scoreFunction(new TestSetLossScoreFunction())
|
||||
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(10))
|
||||
.build();
|
||||
|
||||
|
@ -157,7 +163,7 @@ public class TestGraphLocalExecutionGenetic {
|
|||
runner.execute();
|
||||
|
||||
List<ResultReference> results = runner.getResults();
|
||||
assertEquals(10, results.size());
|
||||
assertTrue(results.size() > 0);
|
||||
|
||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.json;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace;
|
||||
|
@ -71,7 +72,7 @@ import static org.junit.Assert.assertNotNull;
|
|||
/**
|
||||
* Created by Alex on 14/02/2017.
|
||||
*/
|
||||
public class TestJson {
|
||||
public class TestJson extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testMultiLayerSpaceJson() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace;
|
||||
|
@ -59,7 +60,7 @@ import java.util.concurrent.TimeUnit;
|
|||
// import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener;
|
||||
|
||||
/** Not strictly a unit test. Rather: part example, part debugging on MNIST */
|
||||
public class MNISTOptimizationTest {
|
||||
public class MNISTOptimizationTest extends BaseDL4JTest {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator;
|
||||
|
@ -72,9 +73,10 @@ import java.util.Properties;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
public class TestDL4JLocalExecution {
|
||||
public class TestDL4JLocalExecution extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
@ -112,7 +114,7 @@ public class TestDL4JLocalExecution {
|
|||
if(dataApproach == 0){
|
||||
ds = MnistDataSource.class;
|
||||
dsP = new Properties();
|
||||
dsP.setProperty("minibatch", "8");
|
||||
dsP.setProperty("minibatch", "2");
|
||||
candidateGenerator = new RandomSearchGenerator(mls);
|
||||
} else if(dataApproach == 1) {
|
||||
//DataProvider approach
|
||||
|
@ -136,7 +138,7 @@ public class TestDL4JLocalExecution {
|
|||
.dataSource(ds, dsP)
|
||||
.modelSaver(new FileModelSaver(modelSave))
|
||||
.scoreFunction(new TestSetLossScoreFunction())
|
||||
.terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES),
|
||||
.terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(5))
|
||||
.build();
|
||||
|
||||
|
@ -146,7 +148,7 @@ public class TestDL4JLocalExecution {
|
|||
runner.execute();
|
||||
|
||||
List<ResultReference> results = runner.getResults();
|
||||
assertEquals(5, results.size());
|
||||
assertTrue(results.size() > 0);
|
||||
|
||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
||||
|
@ -39,7 +40,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|||
|
||||
import java.io.File;
|
||||
|
||||
public class TestErrors {
|
||||
public class TestErrors extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder temp = new TemporaryFolder();
|
||||
|
@ -60,7 +61,7 @@ public class TestErrors {
|
|||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(
|
||||
new MaxCandidatesCondition(5))
|
||||
|
@ -87,7 +88,7 @@ public class TestErrors {
|
|||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(
|
||||
new MaxCandidatesCondition(5))
|
||||
|
@ -116,7 +117,7 @@ public class TestErrors {
|
|||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(new MaxCandidatesCondition(5))
|
||||
.build();
|
||||
|
@ -143,7 +144,7 @@ public class TestErrors {
|
|||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 10))
|
||||
.candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3))
|
||||
.modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(
|
||||
new MaxCandidatesCondition(5))
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.TestUtils;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
import org.deeplearning4j.arbiter.layers.*;
|
||||
|
@ -44,7 +45,7 @@ import java.util.Random;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class TestLayerSpace {
|
||||
public class TestLayerSpace extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testBasic1() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.multilayernetwork;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.DL4JConfiguration;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.TestUtils;
|
||||
|
@ -86,7 +87,7 @@ import java.util.*;
|
|||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class TestMultiLayerSpace {
|
||||
public class TestMultiLayerSpace extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.multilayernetwork;
|
|||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
|
||||
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
|
||||
|
@ -60,7 +61,13 @@ import java.util.Map;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
public class TestScoreFunctions {
|
||||
public class TestScoreFunctions extends BaseDL4JTest {
|
||||
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 60000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testROCScoreFunctions() throws Exception {
|
||||
|
@ -107,7 +114,7 @@ public class TestScoreFunctions {
|
|||
List<ResultReference> list = runner.getResults();
|
||||
|
||||
for (ResultReference rr : list) {
|
||||
DataSetIterator testIter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
|
||||
DataSetIterator testIter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||
testIter.setPreProcessor(new PreProc(rocType));
|
||||
|
||||
OptimizationResult or = rr.getResult();
|
||||
|
@ -141,10 +148,10 @@ public class TestScoreFunctions {
|
|||
}
|
||||
|
||||
|
||||
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
|
||||
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||
iter.setPreProcessor(new PreProc(rocType));
|
||||
|
||||
assertEquals(msg, expScore, or.getScore(), 1e-5);
|
||||
assertEquals(msg, expScore, or.getScore(), 1e-4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -158,7 +165,7 @@ public class TestScoreFunctions {
|
|||
@Override
|
||||
public Object trainData(Map<String, Object> dataParameters) {
|
||||
try {
|
||||
DataSetIterator iter = new MnistDataSetIterator(32, 8000, false, true, true, 12345);
|
||||
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||
iter.setPreProcessor(new PreProc(rocType));
|
||||
return iter;
|
||||
} catch (IOException e){
|
||||
|
@ -169,7 +176,7 @@ public class TestScoreFunctions {
|
|||
@Override
|
||||
public Object testData(Map<String, Object> dataParameters) {
|
||||
try {
|
||||
DataSetIterator iter = new MnistDataSetIterator(32, 2000, false, false, true, 12345);
|
||||
DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345);
|
||||
iter.setPreProcessor(new PreProc(rocType));
|
||||
return iter;
|
||||
} catch (IOException e){
|
||||
|
|
|
@ -49,6 +49,13 @@
|
|||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.server;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
||||
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
||||
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
||||
|
@ -52,7 +53,7 @@ import static org.junit.Assert.assertEquals;
|
|||
* Created by agibsonccc on 3/12/17.
|
||||
*/
|
||||
@Slf4j
|
||||
public class ArbiterCLIRunnerTest {
|
||||
public class ArbiterCLIRunnerTest extends BaseDL4JTest {
|
||||
|
||||
|
||||
@Test
|
||||
|
|
|
@ -311,7 +311,11 @@ public class CSVRecordReaderTest {
|
|||
rr.reset();
|
||||
fail("Expected exception");
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
String msg = e.getMessage();
|
||||
String msg2 = e.getCause().getMessage();
|
||||
assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
|
||||
assertTrue(msg2, msg2.contains("Reset not supported from streams"));
|
||||
// e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -55,8 +55,7 @@ public class LineReaderTest {
|
|||
|
||||
@Test
|
||||
public void testLineReader() throws Exception {
|
||||
String tempDir = System.getProperty("java.io.tmpdir");
|
||||
File tmpdir = new File(tempDir, "tmpdir-testLineReader");
|
||||
File tmpdir = testDir.newFolder();
|
||||
if (tmpdir.exists())
|
||||
tmpdir.delete();
|
||||
tmpdir.mkdir();
|
||||
|
@ -84,12 +83,6 @@ public class LineReaderTest {
|
|||
}
|
||||
|
||||
assertEquals(9, count);
|
||||
|
||||
try {
|
||||
FileUtils.deleteDirectory(tmpdir);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -145,13 +138,6 @@ public class LineReaderTest {
|
|||
assertEquals(2, subset.size());
|
||||
assertEquals(out3.get(4), subset.get(0));
|
||||
assertEquals(out3.get(7), subset.get(1));
|
||||
|
||||
|
||||
try {
|
||||
FileUtils.deleteDirectory(tmpdir);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -177,11 +163,5 @@ public class LineReaderTest {
|
|||
}
|
||||
|
||||
assertEquals(9, count);
|
||||
|
||||
try {
|
||||
FileUtils.deleteDirectory(tmpdir);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,116 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<parent>
|
||||
<artifactId>datavec-parent</artifactId>
|
||||
<groupId>org.datavec</groupId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>datavec-camel</artifactId>
|
||||
|
||||
<name>DataVec Camel Component</name>
|
||||
<url>http://deeplearning4j.org</url>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.camel</groupId>
|
||||
<artifactId>camel-csv</artifactId>
|
||||
<version>${camel.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.camel</groupId>
|
||||
<artifactId>camel-core</artifactId>
|
||||
<version>${camel.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- support camel documentation -->
|
||||
<!-- logging -->
|
||||
<!-- testing -->
|
||||
<dependency>
|
||||
<groupId>org.apache.camel</groupId>
|
||||
<artifactId>camel-test</artifactId>
|
||||
<version>${camel.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<defaultGoal>install</defaultGoal>
|
||||
|
||||
<plugins>
|
||||
|
||||
<plugin>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<configuration>
|
||||
<source>1.7</source>
|
||||
<target>1.7</target>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<artifactId>maven-resources-plugin</artifactId>
|
||||
<version>3.0.1</version>
|
||||
<configuration>
|
||||
<encoding>UTF-8</encoding>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
<!-- generate components meta-data and validate component includes documentation etc -->
|
||||
<plugin>
|
||||
<groupId>org.apache.camel</groupId>
|
||||
<artifactId>camel-package-maven-plugin</artifactId>
|
||||
<version>${camel.version}</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>prepare</id>
|
||||
<goals>
|
||||
<goal>prepare-components</goal>
|
||||
</goals>
|
||||
<phase>generate-resources</phase>
|
||||
</execution>
|
||||
<!-- <execution>
|
||||
<id>validate</id>
|
||||
<goals>
|
||||
<goal>validate-components</goal>
|
||||
</goals>
|
||||
<phase>prepare-package</phase>
|
||||
</execution>-->
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-10.2</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -1,45 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.camel.component;
|
||||
|
||||
import org.apache.camel.CamelContext;
|
||||
import org.apache.camel.Endpoint;
|
||||
import org.apache.camel.impl.UriEndpointComponent;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Represents the component that manages {@link DataVecEndpoint}.
|
||||
*/
|
||||
public class DataVecComponent extends UriEndpointComponent {
|
||||
|
||||
public DataVecComponent() {
|
||||
super(DataVecEndpoint.class);
|
||||
}
|
||||
|
||||
public DataVecComponent(CamelContext context) {
|
||||
super(context, DataVecEndpoint.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Endpoint createEndpoint(String uri, String remaining, Map<String, Object> parameters) throws Exception {
|
||||
DataVecEndpoint endpoint = new DataVecEndpoint(uri, this);
|
||||
setProperties(endpoint, parameters);
|
||||
endpoint.setInputFormat(remaining);
|
||||
return endpoint;
|
||||
}
|
||||
}
|
|
@ -1,93 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.camel.component;
|
||||
|
||||
|
||||
import org.apache.camel.Exchange;
|
||||
import org.apache.camel.Processor;
|
||||
import org.apache.camel.impl.ScheduledPollConsumer;
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.formats.input.InputFormat;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
|
||||
/**
|
||||
* The DataVec consumer.
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class DataVecConsumer extends ScheduledPollConsumer {
|
||||
private final DataVecEndpoint endpoint;
|
||||
private Class<? extends InputFormat> inputFormatClazz;
|
||||
private Class<? extends DataVecMarshaller> marshallerClazz;
|
||||
private InputFormat inputFormat;
|
||||
private Configuration configuration;
|
||||
private DataVecMarshaller marshaller;
|
||||
|
||||
|
||||
public DataVecConsumer(DataVecEndpoint endpoint, Processor processor) {
|
||||
super(endpoint, processor);
|
||||
this.endpoint = endpoint;
|
||||
|
||||
try {
|
||||
inputFormatClazz = (Class<? extends InputFormat>) Class.forName(endpoint.getInputFormat());
|
||||
inputFormat = inputFormatClazz.newInstance();
|
||||
marshallerClazz = (Class<? extends DataVecMarshaller>) Class.forName(endpoint.getInputMarshaller());
|
||||
marshaller = marshallerClazz.newInstance();
|
||||
configuration = new Configuration();
|
||||
for (String prop : endpoint.getConsumerProperties().keySet())
|
||||
configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString());
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
//stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split
|
||||
protected InputSplit inputFromExchange(Exchange exchange) {
|
||||
return marshaller.getSplit(exchange);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int poll() throws Exception {
|
||||
Exchange exchange = endpoint.createExchange();
|
||||
InputSplit split = inputFromExchange(exchange);
|
||||
RecordReader reader = inputFormat.createReader(split, configuration);
|
||||
int numMessagesPolled = 0;
|
||||
while (reader.hasNext()) {
|
||||
// create a message body
|
||||
while (reader.hasNext()) {
|
||||
exchange.getIn().setBody(reader.next());
|
||||
|
||||
try {
|
||||
// send message to next processor in the route
|
||||
getProcessor().process(exchange);
|
||||
numMessagesPolled++; // number of messages polled
|
||||
} finally {
|
||||
// log exception if an exception occurred and was not handled
|
||||
if (exchange.getException() != null) {
|
||||
getExceptionHandler().handleException("Error processing exchange", exchange,
|
||||
exchange.getException());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
return numMessagesPolled;
|
||||
}
|
||||
}
|
|
@ -1,68 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.camel.component;
|
||||
|
||||
import lombok.Data;
|
||||
import org.apache.camel.Consumer;
|
||||
import org.apache.camel.Processor;
|
||||
import org.apache.camel.Producer;
|
||||
import org.apache.camel.impl.DefaultEndpoint;
|
||||
import org.apache.camel.spi.Metadata;
|
||||
import org.apache.camel.spi.UriEndpoint;
|
||||
import org.apache.camel.spi.UriParam;
|
||||
import org.apache.camel.spi.UriPath;
|
||||
|
||||
/**
|
||||
* Represents a DataVec endpoint.
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@UriEndpoint(scheme = "datavec", title = "datavec", syntax = "datavec:inputFormat/?outputFormat=?&inputMarshaller=?",
|
||||
consumerClass = DataVecConsumer.class, label = "datavec")
|
||||
@Data
|
||||
public class DataVecEndpoint extends DefaultEndpoint {
|
||||
@UriPath
|
||||
@Metadata(required = "true")
|
||||
private String inputFormat;
|
||||
@UriParam(defaultValue = "")
|
||||
private String outputFormat;
|
||||
@UriParam
|
||||
@Metadata(required = "true")
|
||||
private String inputMarshaller;
|
||||
@UriParam(defaultValue = "org.datavec.api.io.converters.SelfWritableConverter")
|
||||
private String writableConverter;
|
||||
|
||||
public DataVecEndpoint(String uri, DataVecComponent component) {
|
||||
super(uri, component);
|
||||
}
|
||||
|
||||
public DataVecEndpoint(String endpointUri) {
|
||||
super(endpointUri);
|
||||
}
|
||||
|
||||
public Producer createProducer() throws Exception {
|
||||
return new DataVecProducer(this);
|
||||
}
|
||||
|
||||
public Consumer createConsumer(Processor processor) throws Exception {
|
||||
return new DataVecConsumer(this, processor);
|
||||
}
|
||||
|
||||
public boolean isSingleton() {
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,109 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.camel.component;
|
||||
|
||||
import org.apache.camel.Exchange;
|
||||
import org.apache.camel.impl.DefaultProducer;
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.formats.input.InputFormat;
|
||||
import org.datavec.api.io.WritableConverter;
|
||||
import org.datavec.api.io.converters.SelfWritableConverter;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
|
||||
|
||||
/**
|
||||
* The DataVec producer.
|
||||
* Converts input records in to their final form
|
||||
* based on the input split generated from
|
||||
* the given exchange.
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class DataVecProducer extends DefaultProducer {
|
||||
private Class<? extends InputFormat> inputFormatClazz;
|
||||
private Class<? extends DataVecMarshaller> marshallerClazz;
|
||||
private InputFormat inputFormat;
|
||||
private Configuration configuration;
|
||||
private WritableConverter writableConverter;
|
||||
private DataVecMarshaller marshaller;
|
||||
|
||||
|
||||
public DataVecProducer(DataVecEndpoint endpoint) {
|
||||
super(endpoint);
|
||||
if (endpoint.getInputFormat() != null) {
|
||||
try {
|
||||
inputFormatClazz = (Class<? extends InputFormat>) Class.forName(endpoint.getInputFormat());
|
||||
inputFormat = inputFormatClazz.newInstance();
|
||||
marshallerClazz = (Class<? extends DataVecMarshaller>) Class.forName(endpoint.getInputMarshaller());
|
||||
Class<? extends WritableConverter> converterClazz =
|
||||
(Class<? extends WritableConverter>) Class.forName(endpoint.getWritableConverter());
|
||||
writableConverter = converterClazz.newInstance();
|
||||
marshaller = marshallerClazz.newInstance();
|
||||
configuration = new Configuration();
|
||||
for (String prop : endpoint.getConsumerProperties().keySet())
|
||||
configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString());
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
//stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split
|
||||
protected InputSplit inputFromExchange(Exchange exchange) {
|
||||
return marshaller.getSplit(exchange);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void process(Exchange exchange) throws Exception {
|
||||
InputSplit split = inputFromExchange(exchange);
|
||||
RecordReader reader = inputFormat.createReader(split, configuration);
|
||||
Collection<Collection<Writable>> newRecord = new ArrayList<>();
|
||||
if (!(writableConverter instanceof SelfWritableConverter)) {
|
||||
newRecord = new ArrayList<>();
|
||||
while (reader.hasNext()) {
|
||||
Collection<Writable> newRecordAdd = new ArrayList<>();
|
||||
// create a message body
|
||||
Collection<Writable> next = reader.next();
|
||||
for (Writable writable : next) {
|
||||
newRecordAdd.add(writableConverter.convert(writable));
|
||||
}
|
||||
|
||||
|
||||
newRecord.add(newRecordAdd);
|
||||
}
|
||||
} else {
|
||||
while (reader.hasNext()) {
|
||||
// create a message body
|
||||
Collection<Writable> next = reader.next();
|
||||
newRecord.add(next);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
exchange.getIn().setBody(newRecord);
|
||||
exchange.getOut().setBody(newRecord);
|
||||
}
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.camel.component.csv.marshaller;
|
||||
|
||||
import org.apache.camel.Exchange;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.api.split.ListStringSplit;
|
||||
import org.datavec.camel.component.DataVecMarshaller;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Marshals List<List<String>>
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class ListStringInputMarshaller implements DataVecMarshaller {
|
||||
/**
|
||||
* @param exchange
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public InputSplit getSplit(Exchange exchange) {
|
||||
List<List<String>> data = (List<List<String>>) exchange.getIn().getBody();
|
||||
InputSplit listSplit = new ListStringSplit(data);
|
||||
return listSplit;
|
||||
}
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
class=org.datavec.camel.component.DataVecComponent
|
|
@ -1,82 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.camel.component;
|
||||
|
||||
import org.apache.camel.builder.RouteBuilder;
|
||||
import org.apache.camel.component.mock.MockEndpoint;
|
||||
import org.apache.camel.test.junit4.CamelTestSupport;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.ClassRule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
|
||||
public class DataVecComponentTest extends CamelTestSupport {
|
||||
|
||||
@ClassRule
|
||||
public static TemporaryFolder testDir = new TemporaryFolder();
|
||||
private static File dir;
|
||||
private static File irisFile;
|
||||
|
||||
|
||||
@BeforeClass
|
||||
public static void before() throws Exception {
|
||||
dir = testDir.newFolder();
|
||||
File iris = new ClassPathResource("iris.dat").getFile();
|
||||
irisFile = new File(dir, "iris.dat");
|
||||
FileUtils.copyFile(iris, irisFile );
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testDataVec() throws Exception {
|
||||
MockEndpoint mock = getMockEndpoint("mock:result");
|
||||
//1
|
||||
mock.expectedMessageCount(1);
|
||||
|
||||
RecordReader reader = new CSVRecordReader();
|
||||
reader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));
|
||||
Collection<Collection<Writable>> recordAssertion = new ArrayList<>();
|
||||
while (reader.hasNext())
|
||||
recordAssertion.add(reader.next());
|
||||
mock.expectedBodiesReceived(recordAssertion);
|
||||
assertMockEndpointsSatisfied();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RouteBuilder createRouteBuilder() throws Exception {
|
||||
|
||||
|
||||
return new RouteBuilder() {
|
||||
public void configure() {
|
||||
from("file:" + dir.getAbsolutePath() + "?fileName=iris.dat&noop=true").unmarshal().csv()
|
||||
.to("datavec://org.datavec.api.formats.input.impl.ListStringInputFormat?inputMarshaller=org.datavec.camel.component.ListStringInputMarshaller&writableConverter=org.datavec.api.io.converters.SelfWritableConverter")
|
||||
.to("mock:result");
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -37,11 +37,6 @@
|
|||
<version>${logback.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-buffer</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.jai-imageio</groupId>
|
||||
<artifactId>jai-imageio-core</artifactId>
|
||||
|
|
|
@ -570,6 +570,7 @@ public class TestNativeImageLoader {
|
|||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
nil.asMatrix(is);
|
||||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
|
@ -577,6 +578,7 @@ public class TestNativeImageLoader {
|
|||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
nil.asImageMatrix(is);
|
||||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
|
@ -584,6 +586,7 @@ public class TestNativeImageLoader {
|
|||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
nil.asRowVector(is);
|
||||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
|
@ -592,6 +595,7 @@ public class TestNativeImageLoader {
|
|||
try(InputStream is = new FileInputStream(f)){
|
||||
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
|
||||
nil.asMatrixView(is, arr);
|
||||
fail("Expected exception");
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
|
|
|
@ -66,9 +66,9 @@ public class JsonYamlTest {
|
|||
String asJson = itp.toJson();
|
||||
String asYaml = itp.toYaml();
|
||||
|
||||
System.out.println(asJson);
|
||||
System.out.println("\n\n\n");
|
||||
System.out.println(asYaml);
|
||||
// System.out.println(asJson);
|
||||
// System.out.println("\n\n\n");
|
||||
// System.out.println(asYaml);
|
||||
|
||||
ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
|
||||
ImageWritable imgJson = new ImageWritable(img.getFrame().clone());
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<parent>
|
||||
<artifactId>datavec-parent</artifactId>
|
||||
<groupId>org.datavec</groupId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>datavec-perf</artifactId>
|
||||
|
||||
<name>datavec-perf</name>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>1.7</maven.compiler.source>
|
||||
<maven.compiler.target>1.7</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>${slf4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-data-image</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-10.2</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -1,112 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.perf.timing;
|
||||
|
||||
import lombok.val;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.InputStreamInputSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
|
||||
import org.nd4j.linalg.memory.MemcpyDirection;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Timing components of a data vec pipeline
|
||||
* consisting of:
|
||||
* {@link RecordReader}, {@link InputStreamInputSplit}
|
||||
* (note that this uses input stream input split,
|
||||
* the record reader must support {@link InputStreamInputSplit} for this to work)
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class IOTiming {
|
||||
|
||||
|
||||
/**
|
||||
* Returns statistics for components of a datavec pipeline
|
||||
* averaged over the specified number of times
|
||||
* @param nTimes the number of times to run the pipeline for averaging
|
||||
* @param recordReader the record reader
|
||||
* @param file the file to read
|
||||
* @param function the function
|
||||
* @return the averaged {@link TimingStatistics} for input/output on a record
|
||||
* reader and ndarray creation (based on the given function
|
||||
* @throws Exception
|
||||
*/
|
||||
public static TimingStatistics averageFileRead(long nTimes, RecordReader recordReader, File file, INDArrayCreationFunction function) throws Exception {
|
||||
TimingStatistics timingStatistics = null;
|
||||
for(int i = 0; i < nTimes; i++) {
|
||||
TimingStatistics timingStatistics1 = timeNDArrayCreation(recordReader,new BufferedInputStream(new FileInputStream(file)),function);
|
||||
if(timingStatistics == null)
|
||||
timingStatistics = timingStatistics1;
|
||||
else {
|
||||
timingStatistics = timingStatistics.add(timingStatistics1);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return timingStatistics.average(nTimes);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param reader
|
||||
* @param inputStream
|
||||
* @param function
|
||||
* @return
|
||||
* @throws Exception
|
||||
*/
|
||||
public static TimingStatistics timeNDArrayCreation(RecordReader reader,
|
||||
InputStream inputStream,
|
||||
INDArrayCreationFunction function) throws Exception {
|
||||
|
||||
|
||||
reader.initialize(new InputStreamInputSplit(inputStream));
|
||||
long longNanos = System.nanoTime();
|
||||
List<Writable> next = reader.next();
|
||||
long endNanos = System.nanoTime();
|
||||
long etlDiff = endNanos - longNanos;
|
||||
long startArrCreation = System.nanoTime();
|
||||
INDArray arr = function.createFromRecord(next);
|
||||
long endArrCreation = System.nanoTime();
|
||||
long endCreationDiff = endArrCreation - startArrCreation;
|
||||
Map<Integer, Map<MemcpyDirection, Long>> currentBandwidth = PerformanceTracker.getInstance().getCurrentBandwidth();
|
||||
val bw = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE);
|
||||
val deviceToHost = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE);
|
||||
|
||||
return TimingStatistics.builder()
|
||||
.diskReadingTimeNanos(etlDiff)
|
||||
.bandwidthNanosHostToDevice(bw)
|
||||
.bandwidthDeviceToHost(deviceToHost)
|
||||
.ndarrayCreationTimeNanos(endCreationDiff)
|
||||
.build();
|
||||
}
|
||||
|
||||
public interface INDArrayCreationFunction {
|
||||
INDArray createFromRecord(List<Writable> record);
|
||||
}
|
||||
|
||||
}
|
|
@ -1,74 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.perf.timing;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
/**
|
||||
* The timing statistics for a data pipeline including:
|
||||
* ndarray creation time in nanoseconds
|
||||
* disk reading time in nanoseconds
|
||||
* bandwidth used in host to device in nano seconds
|
||||
* bandwidth device to host in nanoseconds
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@Builder
|
||||
@Data
|
||||
public class TimingStatistics {
|
||||
|
||||
private long ndarrayCreationTimeNanos;
|
||||
private long diskReadingTimeNanos;
|
||||
private long bandwidthNanosHostToDevice;
|
||||
private long bandwidthDeviceToHost;
|
||||
|
||||
|
||||
/**
|
||||
* Accumulate the given statistics
|
||||
* @param timingStatistics the statistics to add
|
||||
* @return the added statistics
|
||||
*/
|
||||
public TimingStatistics add(TimingStatistics timingStatistics) {
|
||||
return TimingStatistics.builder()
|
||||
.ndarrayCreationTimeNanos(ndarrayCreationTimeNanos + timingStatistics.ndarrayCreationTimeNanos)
|
||||
.bandwidthNanosHostToDevice(bandwidthNanosHostToDevice + timingStatistics.bandwidthNanosHostToDevice)
|
||||
.diskReadingTimeNanos(diskReadingTimeNanos + timingStatistics.diskReadingTimeNanos)
|
||||
.bandwidthDeviceToHost(bandwidthDeviceToHost + timingStatistics.bandwidthDeviceToHost)
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Average the results relative to the number of n.
|
||||
* This method is meant to be used alongside
|
||||
* {@link #add(TimingStatistics)}
|
||||
* accumulated a number of times
|
||||
* @param n n the number of elements
|
||||
* @return the averaged results
|
||||
*/
|
||||
public TimingStatistics average(long n) {
|
||||
return TimingStatistics.builder()
|
||||
.ndarrayCreationTimeNanos(ndarrayCreationTimeNanos / n)
|
||||
.bandwidthDeviceToHost(bandwidthDeviceToHost / n)
|
||||
.diskReadingTimeNanos(diskReadingTimeNanos / n)
|
||||
.bandwidthNanosHostToDevice(bandwidthNanosHostToDevice / n)
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.datavec.datavec.timing;
|
||||
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
import org.datavec.image.recordreader.ImageRecordReader;
|
||||
import org.datavec.perf.timing.IOTiming;
|
||||
import org.datavec.perf.timing.TimingStatistics;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class IOTimingTest {
|
||||
|
||||
@Test
|
||||
public void testTiming() throws Exception {
|
||||
final RecordReader image = new ImageRecordReader(28,28);
|
||||
final NativeImageLoader nativeImageLoader = new NativeImageLoader(28,28);
|
||||
|
||||
TimingStatistics timingStatistics = IOTiming.timeNDArrayCreation(image, new ClassPathResource("datavec-perf/largestblobtest.jpg").getInputStream(), new IOTiming.INDArrayCreationFunction() {
|
||||
@Override
|
||||
public INDArray createFromRecord(List<Writable> record) {
|
||||
NDArrayWritable imageWritable = (NDArrayWritable) record.get(0);
|
||||
return imageWritable.get();
|
||||
}
|
||||
});
|
||||
|
||||
System.out.println(timingStatistics);
|
||||
|
||||
TimingStatistics timingStatistics1 = IOTiming.averageFileRead(1000,image,new ClassPathResource("datavec-perf/largestblobtest.jpg").getFile(), new IOTiming.INDArrayCreationFunction() {
|
||||
@Override
|
||||
public INDArray createFromRecord(List<Writable> record) {
|
||||
NDArrayWritable imageWritable = (NDArrayWritable) record.get(0);
|
||||
return imageWritable.get();
|
||||
}
|
||||
});
|
||||
|
||||
System.out.println(timingStatistics1);
|
||||
}
|
||||
|
||||
}
|
|
@ -60,7 +60,7 @@ public class CSVSparkTransformTest {
|
|||
Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values));
|
||||
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
|
||||
assertTrue(fromBase64.isVector());
|
||||
System.out.println("Base 64ed array " + fromBase64);
|
||||
// System.out.println("Base 64ed array " + fromBase64);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -125,7 +125,7 @@ public class CSVSparkTransformTest {
|
|||
|
||||
SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord);
|
||||
assertNotNull(transformed.getRecords());
|
||||
System.out.println(transformed);
|
||||
// System.out.println(transformed);
|
||||
|
||||
|
||||
}
|
||||
|
@ -153,7 +153,8 @@ public class CSVSparkTransformTest {
|
|||
new SingleCSVRecord(data2)));
|
||||
|
||||
final CSVSparkTransform transform = new CSVSparkTransform(transformProcess);
|
||||
System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
|
||||
// System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
|
||||
transform.transformSequenceIncremental(batchCsvRecord);
|
||||
assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank());
|
||||
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ public class ImageSparkTransformTest {
|
|||
Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord);
|
||||
|
||||
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
|
||||
System.out.println("Base 64ed array " + fromBase64);
|
||||
// System.out.println("Base 64ed array " + fromBase64);
|
||||
assertEquals(1, fromBase64.size(0));
|
||||
}
|
||||
|
||||
|
@ -78,7 +78,7 @@ public class ImageSparkTransformTest {
|
|||
Base64NDArrayBody body = imgSparkTransform.toArray(batch);
|
||||
|
||||
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
|
||||
System.out.println("Base 64ed array " + fromBase64);
|
||||
// System.out.println("Base 64ed array " + fromBase64);
|
||||
assertEquals(3, fromBase64.size(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ public class ImageSparkTransformServerTest {
|
|||
INDArray batchResult = getNDArray(jsonNodeBatch);
|
||||
assertEquals(3, batchResult.size(0));
|
||||
|
||||
System.out.println(array);
|
||||
// System.out.println(array);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -136,7 +136,7 @@ public class ImageSparkTransformServerTest {
|
|||
INDArray batchResult = getNDArray(jsonNode);
|
||||
assertEquals(3, batchResult.size(0));
|
||||
|
||||
System.out.println(batchResult);
|
||||
// System.out.println(batchResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -153,7 +153,7 @@ public class ImageSparkTransformServerTest {
|
|||
INDArray result = getNDArray(jsonNode);
|
||||
assertEquals(1, result.size(0));
|
||||
|
||||
System.out.println(result);
|
||||
// System.out.println(result);
|
||||
}
|
||||
|
||||
public INDArray getNDArray(JsonNode node) throws IOException {
|
||||
|
|
|
@ -72,7 +72,9 @@ public class TestAnalysis extends BaseSparkTest {
|
|||
DataAnalysis da = AnalyzeSpark.analyze(schema, rdd);
|
||||
String daString = da.toString();
|
||||
|
||||
System.out.println(da);
|
||||
// System.out.println(da);
|
||||
da.toJson();
|
||||
da.toString();
|
||||
|
||||
List<ColumnAnalysis> ca = da.getColumnAnalysis();
|
||||
assertEquals(5, ca.size());
|
||||
|
@ -151,7 +153,7 @@ public class TestAnalysis extends BaseSparkTest {
|
|||
assertEquals(1, countD[countD.length - 1]);
|
||||
|
||||
File f = Files.createTempFile("datavec_spark_analysis_UITest", ".html").toFile();
|
||||
System.out.println(f.getAbsolutePath());
|
||||
// System.out.println(f.getAbsolutePath());
|
||||
f.deleteOnExit();
|
||||
HtmlAnalysis.createHtmlAnalysisFile(da, f);
|
||||
}
|
||||
|
@ -210,7 +212,7 @@ public class TestAnalysis extends BaseSparkTest {
|
|||
for( int i=1; i<10; i++ ){
|
||||
counter.merge(counters.get(i));
|
||||
sparkCounter.merge(sparkCounters.get(i));
|
||||
System.out.println();
|
||||
// System.out.println();
|
||||
}
|
||||
assertEquals(sc1.sampleStdev(), counter.getStddev(false), 1e-6);
|
||||
assertEquals(sparkCounter.sampleStdev(), counter.getStddev(false), 1e-6);
|
||||
|
@ -356,7 +358,9 @@ public class TestAnalysis extends BaseSparkTest {
|
|||
|
||||
JavaRDD<List<Writable>> rdd = sc.parallelize(data);
|
||||
DataAnalysis da = AnalyzeSpark.analyze(s, rdd);
|
||||
System.out.println(da);
|
||||
// System.out.println(da);
|
||||
da.toString();
|
||||
da.toJson();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -59,16 +59,12 @@
|
|||
<modules>
|
||||
<module>datavec-api</module>
|
||||
<module>datavec-data</module>
|
||||
<module>datavec-geo</module>
|
||||
<module>datavec-hadoop</module>
|
||||
<module>datavec-spark</module>
|
||||
<module>datavec-camel</module>
|
||||
<module>datavec-local</module>
|
||||
<module>datavec-spark-inference-parent</module>
|
||||
<module>datavec-jdbc</module>
|
||||
<module>datavec-excel</module>
|
||||
<module>datavec-arrow</module>
|
||||
<module>datavec-perf</module>
|
||||
<module>datavec-python</module>
|
||||
</modules>
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~ Copyright (c) 2020 Konduit K.K.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -13,8 +14,8 @@
|
|||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<parent>
|
||||
<artifactId>deeplearning4j-parent</artifactId>
|
||||
|
@ -23,36 +24,45 @@
|
|||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>deeplearning4j-util</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>deeplearning4j-util</name>
|
||||
<url>http://maven.apache.org</url>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-common</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-native</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</profile>
|
||||
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-10.2</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
||||
</project>
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -22,7 +23,9 @@ import org.junit.After;
|
|||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.TestName;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.config.ND4JSystemProperties;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||
|
@ -30,22 +33,41 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||
|
||||
import java.lang.management.ManagementFactory;
|
||||
import java.lang.management.ThreadMXBean;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Properties;
|
||||
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assume.assumeTrue;
|
||||
|
||||
@Slf4j
|
||||
public class BaseDL4JTest {
|
||||
public abstract class BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TestName name = new TestName();
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.millis(getTimeoutMilliseconds());
|
||||
|
||||
protected long startTime;
|
||||
protected int threadCountBefore;
|
||||
|
||||
private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
|
||||
|
||||
/**
|
||||
* Override this to specify the number of threads for C++ execution, via
|
||||
* {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
|
||||
* @return Number of threads to use for C++ op execution
|
||||
*/
|
||||
public int numThreads(){
|
||||
return DEFAULT_THREADS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Override this method to set the default timeout for methods in the test class
|
||||
*/
|
||||
public long getTimeoutMilliseconds(){
|
||||
return 30000;
|
||||
}
|
||||
|
||||
/**
|
||||
* Override this to set the profiling mode for the tests defined in the child class
|
||||
*/
|
||||
|
@ -64,12 +86,45 @@ public class BaseDL4JTest {
|
|||
return getDataType();
|
||||
}
|
||||
|
||||
protected Boolean integrationTest;
|
||||
|
||||
/**
|
||||
* @return True if integration tests maven profile is enabled, false otherwise.
|
||||
*/
|
||||
public boolean isIntegrationTests(){
|
||||
if(integrationTest == null){
|
||||
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
|
||||
integrationTest = Boolean.parseBoolean(prop);
|
||||
}
|
||||
return integrationTest;
|
||||
}
|
||||
|
||||
/**
|
||||
* Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled.
|
||||
* This can be used to dynamically skip integration tests when the integration test profile is not enabled.
|
||||
* Note that the integration test profile is not enabled by default - "integration-tests" profile
|
||||
*/
|
||||
public void skipUnlessIntegrationTests(){
|
||||
assumeTrue("Skipping integration test - integration profile is not enabled", isIntegrationTests());
|
||||
}
|
||||
|
||||
@Before
|
||||
public void beforeTest(){
|
||||
log.info("{}.{}", getClass().getSimpleName(), name.getMethodName());
|
||||
//Suppress ND4J initialization - don't need this logged for every test...
|
||||
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
|
||||
System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
|
||||
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
|
||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
||||
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
|
||||
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
||||
Nd4j.getExecutioner().enableDebugMode(false);
|
||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
||||
int numThreads = numThreads();
|
||||
Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
|
||||
if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
|
||||
Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
|
||||
}
|
||||
startTime = System.currentTimeMillis();
|
||||
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
|
||||
}
|
|
@ -95,6 +95,12 @@
|
|||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
@ -147,6 +153,17 @@
|
|||
<version>${jaxb.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<!-- Dependencies for dl4j-perf subproject -->
|
||||
<dependency>
|
||||
<groupId>com.github.oshi</groupId>
|
||||
<artifactId>oshi-json</artifactId>
|
||||
<version>${oshi.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.oshi</groupId>
|
||||
<artifactId>oshi-core</artifactId>
|
||||
<version>${oshi.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -158,7 +158,7 @@ public class LayerHelperValidationUtil {
|
|||
double d2 = arr2.dup('c').getDouble(idx);
|
||||
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
|
||||
}
|
||||
assertTrue(s + layerName + "activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
|
||||
assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
|
||||
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
|
||||
}
|
||||
|
||||
|
|
|
@ -75,7 +75,6 @@ public class TestUtils {
|
|||
}
|
||||
|
||||
public static ComputationGraph testModelSerialization(ComputationGraph net){
|
||||
|
||||
ComputationGraph restored;
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
|
|
|
@ -20,11 +20,9 @@ import org.deeplearning4j.BaseDL4JTest;
|
|||
import org.deeplearning4j.base.MnistFetcher;
|
||||
import org.deeplearning4j.common.resources.DL4JResources;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.ClassRule;
|
||||
import org.junit.Test;
|
||||
import org.junit.*;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
@ -47,6 +45,8 @@ public class MnistFetcherTest extends BaseDL4JTest {
|
|||
|
||||
@ClassRule
|
||||
public static TemporaryFolder testDir = new TemporaryFolder();
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@BeforeClass
|
||||
public static void setup() throws Exception {
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue