Test fixes + cleanup (#245)

* Test spam reduction

Signed-off-by: Alex Black <blacka101@gmail.com>

* Arbiter bad import fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small spark test tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Arbiter test log spam reduction

Signed-off-by: Alex Black <blacka101@gmail.com>

* More test spam reduction

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-02-18 10:29:06 +11:00 committed by GitHub
parent 2698fbf541
commit c8882cbfa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 168 additions and 182 deletions

View File

@ -1,47 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.util;
import org.slf4j.Logger;
import java.awt.*;
import java.net.URI;
/**
* Various utilities for webpages and dealing with browsers
*/
public class WebUtils {
public static void tryOpenBrowser(String path, Logger log) {
try {
WebUtils.openBrowser(new URI(path));
} catch (Exception e) {
log.error("Could not open browser", e);
System.out.println("Browser could not be launched automatically.\nUI path: " + path);
}
}
public static void openBrowser(URI uri) throws Exception {
if (Desktop.isDesktopSupported()) {
Desktop.getDesktop().browse(uri);
} else {
throw new UnsupportedOperationException(
"Cannot open browser on this platform: Desktop.isDesktopSupported() == false");
}
}
}

View File

@ -127,7 +127,7 @@ public class BraninFunction {
BraninConfig candidate = (BraninConfig) c.getValue();
double score = scoreFunction.score(candidate, null, (Map) null);
System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
// System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
Thread.sleep(20);

View File

@ -54,7 +54,7 @@ public class TestRandomSearch extends BaseDL4JTest {
runner.execute();
System.out.println("----- Complete -----");
// System.out.println("----- Complete -----");
}

View File

@ -16,8 +16,8 @@
package org.deeplearning4j.arbiter.optimize.genetic;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class TestRandomGenerator implements RandomGenerator {
private final int[] intRandomNumbers;
@ -63,17 +63,17 @@ public class TestRandomGenerator implements RandomGenerator {
@Override
public long nextLong() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
@Override
public boolean nextBoolean() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
@Override
public float nextFloat() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
@Override
@ -83,6 +83,6 @@ public class TestRandomGenerator implements RandomGenerator {
@Override
public double nextGaussian() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
@ -26,7 +27,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
@ -42,7 +42,7 @@ public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
@Override
public CrossoverResult crossover() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
}

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.culling;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.population.Populati
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import java.util.List;
@ -46,7 +46,7 @@ public class RatioCullOperatorTests extends BaseDL4JTest {
@Override
public void cullPopulation() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
public double getCullRatio() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
@ -33,7 +34,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import static org.junit.Assert.assertArrayEquals;
@ -55,7 +55,7 @@ public class GeneticSelectionOperatorTests extends BaseDL4JTest {
@Override
public void cullPopulation() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
@Override

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.Selection
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert;
import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class SelectionOperatorTests extends BaseDL4JTest {
private class TestSelectionOperator extends SelectionOperator {
@ -39,7 +39,7 @@ public class SelectionOperatorTests extends BaseDL4JTest {
@Override
public double[] buildNextGenes() {
throw new NotImplementedException();
throw new NotImplementedException("Not implemented");
}
}

View File

@ -158,7 +158,7 @@ public class TestComputationGraphSpace extends BaseDL4JTest {
}
}
System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
assertTrue(reluCount > 0);
assertTrue(tanhCount > 0);

View File

@ -162,7 +162,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
List<ResultReference> results = runner.getResults();
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
// System.out.println("----- COMPLETE - " + results.size() + " results -----");
}
}

View File

@ -165,7 +165,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
List<ResultReference> results = runner.getResults();
assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----");
// System.out.println("----- COMPLETE - " + results.size() + " results -----");
}
}

View File

@ -101,7 +101,7 @@ public class TestLayerSpace extends BaseDL4JTest {
double l2 = TestUtils.getL2(l);
IActivation activation = l.getActivationFn();
System.out.println(lr + "\t" + l2 + "\t" + activation);
// System.out.println(lr + "\t" + l2 + "\t" + activation);
assertTrue(lr >= 0.3 && lr <= 0.4);
assertTrue(l2 >= 0.01 && l2 <= 0.1);
@ -190,7 +190,7 @@ public class TestLayerSpace extends BaseDL4JTest {
ActivationLayer al = als.getValue(d);
IActivation activation = al.getActivationFn();
System.out.println(activation);
// System.out.println(activation);
assertTrue(containsActivationFunction(actFns, activation));
}
@ -228,7 +228,7 @@ public class TestLayerSpace extends BaseDL4JTest {
IActivation activation = el.getActivationFn();
long nOut = el.getNOut();
System.out.println(activation + "\t" + nOut);
// System.out.println(activation + "\t" + nOut);
assertTrue(containsActivationFunction(actFns, activation));
assertTrue(nOut >= 10 && nOut <= 20);
@ -295,7 +295,7 @@ public class TestLayerSpace extends BaseDL4JTest {
long nOut = el.getNOut();
double forgetGate = el.getForgetGateBiasInit();
System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
// System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
assertTrue(containsActivationFunction(actFns, activation));
assertTrue(nOut >= 10 && nOut <= 20);

View File

@ -293,8 +293,8 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly
}
System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
// System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
}

View File

@ -98,7 +98,8 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson()));
FileUtils.writeStringToFile(new File(configPath),configuration.toJson());
System.out.println(configuration.toJson());
// System.out.println(configuration.toJson());
configuration.toJson();
log.info("Starting test");
cliRunner.runMain(

View File

@ -41,7 +41,7 @@ public class TestGraphLoading extends BaseDL4JTest {
IGraph<String, String> graph = GraphLoader
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ",");
System.out.println(graph);
// System.out.println(graph);
assertEquals(graph.numVertices(), 7);
int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}};
@ -66,7 +66,7 @@ public class TestGraphLoading extends BaseDL4JTest {
edgeLineProcessor, vertexFactory, 10, false);
System.out.println(graph);
// System.out.println(graph);
for (int i = 0; i < 10; i++) {
List<Edge<String>> edges = graph.getEdgesOut(i);
@ -111,7 +111,7 @@ public class TestGraphLoading extends BaseDL4JTest {
Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);
System.out.println(graph);
// System.out.println(graph);
for (int i = 0; i < 10; i++) {
List<Edge<String>> edges = graph.getEdgesOut(i);

View File

@ -71,7 +71,7 @@ public class TestGraphLoadingWeighted extends BaseDL4JTest {
}
}
System.out.println(graph);
// System.out.println(graph);
}

View File

@ -220,7 +220,7 @@ public class TestGraph extends BaseDL4JTest {
sum += transitionProb[i][j];
for (int j = 0; j < transitionProb[i].length; j++)
transitionProb[i][j] /= sum;
System.out.println(Arrays.toString(transitionProb[i]));
// System.out.println(Arrays.toString(transitionProb[i]));
}
//Check that transition probs are essentially correct (within bounds of random variation)

View File

@ -145,8 +145,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
fail(msg);
else
System.out.println(msg);
// else
// System.out.println(msg);
}
}
@ -333,10 +333,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
fail(msg);
else
System.out.println(msg);
// else
// System.out.println(msg);
}
System.out.println();
// System.out.println();
}
}

View File

@ -67,7 +67,7 @@ public class TestDeepWalk extends BaseDL4JTest {
for (int i = 0; i < 7; i++) {
INDArray vector = deepWalk.getVertexVector(i);
assertArrayEquals(new long[] {vectorSize}, vector.shape());
System.out.println(Arrays.toString(vector.dup().data().asFloat()));
// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
}
GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);
@ -77,11 +77,11 @@ public class TestDeepWalk extends BaseDL4JTest {
for (int t = 0; t < 5; t++) {
iter.reset();
deepWalk.fit(iter);
System.out.println("--------------------");
// System.out.println("--------------------");
for (int i = 0; i < 7; i++) {
INDArray vector = deepWalk.getVertexVector(i);
assertArrayEquals(new long[] {vectorSize}, vector.shape());
System.out.println(Arrays.toString(vector.dup().data().asFloat()));
// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
}
}
}
@ -160,7 +160,7 @@ public class TestDeepWalk extends BaseDL4JTest {
continue;
double sim = deepWalk.similarity(i, nearestTo);
System.out.println(i + "\t" + nearestTo + "\t" + sim);
// System.out.println(i + "\t" + nearestTo + "\t" + sim);
assertTrue(sim <= minSimNearest);
}
}
@ -211,7 +211,7 @@ public class TestDeepWalk extends BaseDL4JTest {
Graph<String, String> graph = GraphLoader
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ",");
System.out.println(graph);
// System.out.println(graph);
Nd4j.getRandom().setSeed(12345);
@ -229,11 +229,13 @@ public class TestDeepWalk extends BaseDL4JTest {
//Calculate similarity(0,i)
for (int i = 0; i < nVertices; i++) {
System.out.println(deepWalk.similarity(0, i));
// System.out.println(deepWalk.similarity(0, i));
deepWalk.similarity(0, i);
}
for (int i = 0; i < nVertices; i++)
System.out.println(deepWalk.getVertexVector(i));
// System.out.println(deepWalk.getVertexVector(i));
deepWalk.getVertexVector(i);
}
@Test(timeout = 60000L)

View File

@ -38,9 +38,11 @@ public class TestGraphHuffman extends BaseDL4JTest {
gh.buildTree(vertexDegrees);
for (int i = 0; i < 7; i++)
System.out.println(i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
+ "\t\t" + Arrays.toString(gh.getPathInnerNodes(i)));
for (int i = 0; i < 7; i++) {
String s = i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
+ "\t\t" + Arrays.toString(gh.getPathInnerNodes(i));
// System.out.println(s);
}
int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5};
for (int i = 0; i < vertexDegrees.length; i++) {

View File

@ -79,8 +79,9 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest {
model.init();
ParallelWrapper parameterServerParallelWrapper =
new ParallelWrapper.Builder(model).trainerFactory(new ParameterServerTrainerContext())
.workers(Runtime.getRuntime().availableProcessors())
new ParallelWrapper.Builder(model)
.workers(Math.min(4, Runtime.getRuntime().availableProcessors()))
.trainerFactory(new ParameterServerTrainerContext())
.reportScoreAfterAveraging(true).prefetchBuffer(3).build();
parameterServerParallelWrapper.fit(mnistTrain);

View File

@ -104,7 +104,7 @@ public class SparkWord2VecTest extends BaseDL4JTest {
public void call(ExportContainer<VocabWord> v) throws Exception {
assertNotNull(v.getElement());
assertNotNull(v.getArray());
System.out.println(v.getElement() + " - " + v.getArray());
// System.out.println(v.getElement() + " - " + v.getArray());
}
}
}

View File

@ -66,7 +66,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -119,7 +119,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@ -155,7 +155,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -198,7 +198,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -231,7 +231,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();

View File

@ -69,7 +69,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -120,7 +120,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@ -158,7 +158,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -203,7 +203,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();
@ -238,7 +238,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1));
net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris();

View File

@ -59,7 +59,7 @@ public class TestShuffleExamples extends BaseSparkTest {
int totalExampleCount = 0;
for (DataSet ds : shuffledList) {
totalExampleCount += ds.getFeatures().length();
System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
// System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
assertEquals(ds.getFeatures(), ds.getLabels());
}

View File

@ -86,7 +86,7 @@ public class TestExport extends BaseSparkTest {
for (File file : files) {
if (!file.getPath().endsWith(".bin"))
continue;
System.out.println(file);
// System.out.println(file);
DataSet ds = new DataSet();
ds.load(file);
assertEquals(minibatchSize, ds.numExamples());
@ -144,7 +144,7 @@ public class TestExport extends BaseSparkTest {
for (File file : files) {
if (!file.getPath().endsWith(".bin"))
continue;
System.out.println(file);
// System.out.println(file);
MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
ds.load(file);
assertEquals(minibatchSize, ds.getFeatures(0).size(0));

View File

@ -92,9 +92,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
int[][] colorCountsByPartition = new int[3][2];
for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) {
System.out.println(val);
// System.out.println(val);
Integer partition = hbp.getPartition(val._1());
System.out.println(partition);
// System.out.println(partition);
if (val._2().equals("red"))
colorCountsByPartition[partition][0] += 1;
@ -102,9 +102,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
colorCountsByPartition[partition][1] += 1;
}
for (int i = 0; i < 3; i++) {
System.out.println(Arrays.toString(colorCountsByPartition[i]));
}
// for (int i = 0; i < 3; i++) {
// System.out.println(Arrays.toString(colorCountsByPartition[i]));
// }
for (int i = 0; i < 3; i++) {
// avg red per partition : 2.33
assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4);
@ -178,12 +178,12 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
colorCountsByPartition[partition][1] += 1;
}
for (int i = 0; i < numPartitions; i++) {
System.out.println(Arrays.toString(colorCountsByPartition[i]));
}
System.out.println("Ideal red # per partition: " + avgRed);
System.out.println("Ideal blue # per partition: " + avgBlue);
// for (int i = 0; i < numPartitions; i++) {
// System.out.println(Arrays.toString(colorCountsByPartition[i]));
// }
//
// System.out.println("Ideal red # per partition: " + avgRed);
// System.out.println("Ideal blue # per partition: " + avgBlue);
for (int i = 0; i < numPartitions; i++) {
// avg red per partition : 2.33

View File

@ -115,7 +115,7 @@ public class TestSparkComputationGraph extends BaseSparkTest {
TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(1)));
scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5)));
JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
scg.fitMultiDataSet(rdd);

View File

@ -31,8 +31,11 @@ import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
@ -45,8 +48,24 @@ import static org.junit.Assert.assertTrue;
@Slf4j
public class TestSparkDl4jMultiLayer extends BaseSparkTest {
@Test(timeout = 120000L)
@Override
public long getTimeoutMilliseconds() {
return 120000L;
}
@Override
public DataType getDataType() {
return DataType.FLOAT;
}
@Override
public DataType getDefaultFPDataType() {
return DataType.FLOAT;
}
@Test
public void testEvaluationSimple() throws Exception {
Nd4j.getRandom().setSeed(12345);
for( int evalWorkers : new int[]{1, 4, 8}) {
//Simple test to validate DL4J issue 4099 is fixed...
@ -75,18 +94,18 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
//----------------------------------
//Create network configuration and conduct network training
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.activation(Activation.LEAKYRELU)
.weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(0.02, 0.9))
.l2(1e-4)
.updater(new Adam(1e-3))
.l2(1e-5)
.list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
.layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
.build();
//Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options

View File

@ -333,15 +333,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd);
}
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: "
+ Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
// System.out.println("Initial (Spark) params: "
// + Arrays.toString(initialSparkParams.data().asFloat()));
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
assertEquals(initialParams, initialSparkParams);
assertNotEquals(initialParams, finalParams);
assertEquals(finalParams, finalSparkParams);
@ -405,15 +406,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd);
}
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: "
+ Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
// System.out.println("Initial (Spark) params: "
// + Arrays.toString(initialSparkParams.data().asFloat()));
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
@ -478,18 +480,19 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd);
}
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
// executioner.addToWatchdog(finalSparkParams, "finalSparkParams");
float[] fp = finalParams.data().asFloat();
float[] fps = finalSparkParams.data().asFloat();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: "
+ Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(fp));
System.out.println("Final (Spark) params: " + Arrays.toString(fps));
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
// System.out.println("Initial (Spark) params: "
// + Arrays.toString(initialSparkParams.data().asFloat()));
// System.out.println("Final (Local) params: " + Arrays.toString(fp));
// System.out.println("Final (Spark) params: " + Arrays.toString(fps));
assertEquals(initialParams, initialSparkParams);
assertNotEquals(initialParams, finalParams);
@ -551,14 +554,15 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd);
}
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
// System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);

View File

@ -37,7 +37,7 @@ public class TestJsonYaml {
String json = tm.toJson();
String yaml = tm.toYaml();
System.out.println(json);
// System.out.println(json);
TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);

View File

@ -389,7 +389,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
for (EventStats e : workerFitStats) {
ExampleCountEventStats eces = (ExampleCountEventStats) e;
System.out.println(eces.getTotalExampleCount());
// System.out.println(eces.getTotalExampleCount());
}
for (EventStats e : workerFitStats) {
@ -457,7 +457,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString());
// System.out.println(stats.statsAsString());
stats.statsAsString();
sparkNet.getTrainingMaster().deleteTempFiles(sc);
}
@ -483,7 +484,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
i++;
}
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
@ -527,7 +528,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
//Expect
System.out.println(stats.statsAsString());
// System.out.println(stats.statsAsString());
stats.statsAsString();
assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());
List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");
@ -566,8 +568,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
i++;
}
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
// System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
@ -610,7 +612,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString());
// System.out.println(stats.statsAsString());
stats.statsAsString();
//Same thing, buf for MultiDataSet objects:
config = new Configuration();
@ -631,7 +634,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter);
stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString());
// System.out.println(stats.statsAsString());
stats.statsAsString();
}
@ -730,13 +734,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.build();
for (int avgFreq : new int[] {1, 5, 10}) {
System.out.println("--- Avg freq " + avgFreq + " ---");
// System.out.println("--- Avg freq " + avgFreq + " ---");
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
.repartionData(Repartition.Always).build());
sparkNet.setListeners(new ScoreIterationListener(1));
sparkNet.setListeners(new ScoreIterationListener(5));
@ -778,13 +782,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.setOutputs("1").build();
for (int avgFreq : new int[] {1, 5, 10}) {
System.out.println("--- Avg freq " + avgFreq + " ---");
// System.out.println("--- Avg freq " + avgFreq + " ---");
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
.repartionData(Repartition.Always).build());
sparkNet.setListeners(new ScoreIterationListener(1));
sparkNet.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> rdd = sc.parallelize(list);

View File

@ -107,7 +107,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
expectedStatNames.addAll(c);
}
System.out.println(expectedStatNames);
// System.out.println(expectedStatNames);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
@ -119,7 +119,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
}
String statsAsString = stats.statsAsString();
System.out.println(statsAsString);
// System.out.println(statsAsString);
assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat

View File

@ -35,7 +35,7 @@ public class TestTimeSource {
long systemTime = System.currentTimeMillis();
long ntpTime = timeSource.currentTimeMillis();
long offset = ntpTime - systemTime;
System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
// System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
Thread.sleep(500);
}
}
@ -49,7 +49,7 @@ public class TestTimeSource {
long systemTime = System.currentTimeMillis();
long ntpTime = timeSource.currentTimeMillis();
long offset = ntpTime - systemTime;
System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
// System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
Thread.sleep(500);
}

View File

@ -87,7 +87,7 @@ public class TestListeners extends BaseSparkTest {
net.fit(rdd);
List<String> sessions = ss.listSessionIDs();
System.out.println("Sessions: " + sessions);
// System.out.println("Sessions: " + sessions);
assertEquals(1, sessions.size());
String sid = sessions.get(0);
@ -95,15 +95,15 @@ public class TestListeners extends BaseSparkTest {
List<String> typeIDs = ss.listTypeIDsForSession(sid);
List<String> workers = ss.listWorkerIDsForSession(sid);
System.out.println(sid + "\t" + typeIDs + "\t" + workers);
// System.out.println(sid + "\t" + typeIDs + "\t" + workers);
List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
System.out.println(lastUpdates);
// System.out.println(lastUpdates);
System.out.println("Static info:");
// System.out.println("Static info:");
for (String wid : workers) {
Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
System.out.println(sid + "\t" + wid);
// System.out.println(sid + "\t" + wid);
}
assertEquals(1, typeIDs.size());

View File

@ -63,7 +63,7 @@ public class TestRepartitioning extends BaseSparkTest {
assertEquals(10, rdd2.partitions().size());
for (int i = 0; i < 10; i++) {
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
System.out.println("Partition " + i + " size: " + partition.size());
// System.out.println("Partition " + i + " size: " + partition.size());
assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
}
}
@ -170,7 +170,7 @@ public class TestRepartitioning extends BaseSparkTest {
List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
System.out.println(partitionCounts);
// System.out.println(partitionCounts);
List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
new Tuple2<>(0,29),
@ -185,7 +185,7 @@ public class TestRepartitioning extends BaseSparkTest {
JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
System.out.println(partitionCountsAfter);
// System.out.println(partitionCountsAfter);
for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
assertEquals(2, (int)t2._2());
@ -219,8 +219,8 @@ public class TestRepartitioning extends BaseSparkTest {
}
}
System.out.println("min: " + min + "\t@\t" + minIdx);
System.out.println("max: " + max + "\t@\t" + maxIdx);
// System.out.println("min: " + min + "\t@\t" + minIdx);
// System.out.println("max: " + max + "\t@\t" + maxIdx);
assertEquals(1, min);
assertEquals(2, max);
@ -244,7 +244,7 @@ public class TestRepartitioning extends BaseSparkTest {
for (int i = 0; i < 10; i++) {
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
System.out.println("Partition " + i + " size: " + partition.size());
// System.out.println("Partition " + i + " size: " + partition.size());
assertTrue(partition.size() >= 90 && partition.size() <= 110);
}
}

View File

@ -123,7 +123,7 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
val expectedSlice = expectedArray.slice(0)
val actualSlice = expectedArray(0, ->)
Console.println(expectedSlice)
// Console.println(expectedSlice)
assert(actualSlice == expectedSlice)
}

View File

@ -28,7 +28,7 @@ class TrainingTest extends FlatSpec with Matchers {
val unused3 = unused1.div(unused2)
val loss1 = add.std("l1", true)
val loss2 = mmul.mean("l2")
Console.println(sd.summary)
// Console.println(sd.summary)
if (i == 0) {
sd.setLossVariables("l1", "l2")
sd.createGradFunction()

View File

@ -43,8 +43,8 @@ public class HistoryProcessorTest {
hp.add(a);
INDArray[] h = hp.getHistory();
assertEquals(4, h.length);
System.out.println(Arrays.toString(a.shape()));
System.out.println(Arrays.toString(h[0].shape()));
// System.out.println(Arrays.toString(a.shape()));
// System.out.println(Arrays.toString(h[0].shape()));
assertEquals( 1, h[0].shape()[0]);
assertEquals(a.shape()[0], h[0].shape()[1]);
assertEquals(a.shape()[1], h[0].shape()[2]);

View File

@ -100,8 +100,8 @@ public class ActorCriticTest {
double error2 = gradient2 - gradient.getDouble(1);
double relError1 = error1 / gradient.getDouble(0);
double relError2 = error2 / gradient.getDouble(1);
System.out.println(gradient.getDouble(0) + " " + gradient1 + " " + relError1);
System.out.println(gradient.getDouble(1) + " " + gradient2 + " " + relError2);
// System.out.println(gradient.getDouble(0) + " " + gradient1 + " " + relError1);
// System.out.println(gradient.getDouble(1) + " " + gradient2 + " " + relError2);
assertTrue(gradient.getDouble(0) < maxRelError || Math.abs(relError1) < maxRelError);
assertTrue(gradient.getDouble(1) < maxRelError || Math.abs(relError2) < maxRelError);
}

View File

@ -158,7 +158,7 @@ public class PolicyTest {
for (int i = 0; i < 100; i++) {
count[policy.nextAction(input)]++;
}
System.out.println(count[0] + " " + count[1] + " " + count[2] + " " + count[3]);
// System.out.println(count[0] + " " + count[1] + " " + count[2] + " " + count[3]);
assertTrue(count[0] < 20);
assertTrue(count[1] < 30);
assertTrue(count[2] < 40);