Test fixes + cleanup (#245)

* Test spam reduction

Signed-off-by: Alex Black <blacka101@gmail.com>

* Arbiter bad import fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small spark test tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Arbiter test log spam reduction

Signed-off-by: Alex Black <blacka101@gmail.com>

* More test spam reduction

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-02-18 10:29:06 +11:00 committed by GitHub
parent 2698fbf541
commit c8882cbfa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 168 additions and 182 deletions

View File

@ -1,47 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.util;
import org.slf4j.Logger;
import java.awt.*;
import java.net.URI;
/**
* Various utilities for webpages and dealing with browsers
*/
public class WebUtils {
public static void tryOpenBrowser(String path, Logger log) {
try {
WebUtils.openBrowser(new URI(path));
} catch (Exception e) {
log.error("Could not open browser", e);
System.out.println("Browser could not be launched automatically.\nUI path: " + path);
}
}
public static void openBrowser(URI uri) throws Exception {
if (Desktop.isDesktopSupported()) {
Desktop.getDesktop().browse(uri);
} else {
throw new UnsupportedOperationException(
"Cannot open browser on this platform: Desktop.isDesktopSupported() == false");
}
}
}

View File

@ -127,7 +127,7 @@ public class BraninFunction {
BraninConfig candidate = (BraninConfig) c.getValue(); BraninConfig candidate = (BraninConfig) c.getValue();
double score = scoreFunction.score(candidate, null, (Map) null); double score = scoreFunction.score(candidate, null, (Map) null);
System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score); // System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
Thread.sleep(20); Thread.sleep(20);

View File

@ -54,7 +54,7 @@ public class TestRandomSearch extends BaseDL4JTest {
runner.execute(); runner.execute();
System.out.println("----- Complete -----"); // System.out.println("----- Complete -----");
} }

View File

@ -16,8 +16,8 @@
package org.deeplearning4j.arbiter.optimize.genetic; package org.deeplearning4j.arbiter.optimize.genetic;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.RandomGenerator;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class TestRandomGenerator implements RandomGenerator { public class TestRandomGenerator implements RandomGenerator {
private final int[] intRandomNumbers; private final int[] intRandomNumbers;
@ -63,17 +63,17 @@ public class TestRandomGenerator implements RandomGenerator {
@Override @Override
public long nextLong() { public long nextLong() {
throw new NotImplementedException(); throw new NotImplementedException("Not implemented");
} }
@Override @Override
public boolean nextBoolean() { public boolean nextBoolean() {
throw new NotImplementedException(); throw new NotImplementedException("Not implemented");
} }
@Override @Override
public float nextFloat() { public float nextFloat() {
throw new NotImplementedException(); throw new NotImplementedException("Not implemented");
} }
@Override @Override
@ -83,6 +83,6 @@ public class TestRandomGenerator implements RandomGenerator {
@Override @Override
public double nextGaussian() { public double nextGaussian() {
throw new NotImplementedException(); throw new NotImplementedException("Not implemented");
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.crossover; package org.deeplearning4j.arbiter.optimize.genetic.crossover;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
@ -26,7 +27,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest { public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
@ -42,7 +42,7 @@ public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
@Override @Override
public CrossoverResult crossover() { public CrossoverResult crossover() {
throw new NotImplementedException(); throw new NotImplementedException("Not implemented");
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.culling; package org.deeplearning4j.arbiter.optimize.genetic.culling;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.population.Populati
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import java.util.List; import java.util.List;
@ -46,7 +46,7 @@ public class RatioCullOperatorTests extends BaseDL4JTest {
@Override @Override
public void cullPopulation() { public void cullPopulation() {
throw new NotImplementedException(); throw new NotImplementedException("Not implemented");
} }
public double getCullRatio() { public double getCullRatio() {

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection; package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
@ -33,7 +34,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
@ -55,7 +55,7 @@ public class GeneticSelectionOperatorTests extends BaseDL4JTest {
@Override @Override
public void cullPopulation() { public void cullPopulation() {
throw new NotImplementedException(); throw new NotImplementedException("Not implemented");
} }
@Override @Override

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize.genetic.selection; package org.deeplearning4j.arbiter.optimize.genetic.selection;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.Selection
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
public class SelectionOperatorTests extends BaseDL4JTest { public class SelectionOperatorTests extends BaseDL4JTest {
private class TestSelectionOperator extends SelectionOperator { private class TestSelectionOperator extends SelectionOperator {
@ -39,7 +39,7 @@ public class SelectionOperatorTests extends BaseDL4JTest {
@Override @Override
public double[] buildNextGenes() { public double[] buildNextGenes() {
throw new NotImplementedException(); throw new NotImplementedException("Not implemented");
} }
} }

View File

@ -158,7 +158,7 @@ public class TestComputationGraphSpace extends BaseDL4JTest {
} }
} }
System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount); // System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
assertTrue(reluCount > 0); assertTrue(reluCount > 0);
assertTrue(tanhCount > 0); assertTrue(tanhCount > 0);

View File

@ -162,7 +162,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
List<ResultReference> results = runner.getResults(); List<ResultReference> results = runner.getResults();
assertTrue(results.size() > 0); assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----"); // System.out.println("----- COMPLETE - " + results.size() + " results -----");
} }
} }

View File

@ -165,7 +165,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
List<ResultReference> results = runner.getResults(); List<ResultReference> results = runner.getResults();
assertTrue(results.size() > 0); assertTrue(results.size() > 0);
System.out.println("----- COMPLETE - " + results.size() + " results -----"); // System.out.println("----- COMPLETE - " + results.size() + " results -----");
} }
} }

View File

@ -101,7 +101,7 @@ public class TestLayerSpace extends BaseDL4JTest {
double l2 = TestUtils.getL2(l); double l2 = TestUtils.getL2(l);
IActivation activation = l.getActivationFn(); IActivation activation = l.getActivationFn();
System.out.println(lr + "\t" + l2 + "\t" + activation); // System.out.println(lr + "\t" + l2 + "\t" + activation);
assertTrue(lr >= 0.3 && lr <= 0.4); assertTrue(lr >= 0.3 && lr <= 0.4);
assertTrue(l2 >= 0.01 && l2 <= 0.1); assertTrue(l2 >= 0.01 && l2 <= 0.1);
@ -190,7 +190,7 @@ public class TestLayerSpace extends BaseDL4JTest {
ActivationLayer al = als.getValue(d); ActivationLayer al = als.getValue(d);
IActivation activation = al.getActivationFn(); IActivation activation = al.getActivationFn();
System.out.println(activation); // System.out.println(activation);
assertTrue(containsActivationFunction(actFns, activation)); assertTrue(containsActivationFunction(actFns, activation));
} }
@ -228,7 +228,7 @@ public class TestLayerSpace extends BaseDL4JTest {
IActivation activation = el.getActivationFn(); IActivation activation = el.getActivationFn();
long nOut = el.getNOut(); long nOut = el.getNOut();
System.out.println(activation + "\t" + nOut); // System.out.println(activation + "\t" + nOut);
assertTrue(containsActivationFunction(actFns, activation)); assertTrue(containsActivationFunction(actFns, activation));
assertTrue(nOut >= 10 && nOut <= 20); assertTrue(nOut >= 10 && nOut <= 20);
@ -295,7 +295,7 @@ public class TestLayerSpace extends BaseDL4JTest {
long nOut = el.getNOut(); long nOut = el.getNOut();
double forgetGate = el.getForgetGateBiasInit(); double forgetGate = el.getForgetGateBiasInit();
System.out.println(activation + "\t" + nOut + "\t" + forgetGate); // System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
assertTrue(containsActivationFunction(actFns, activation)); assertTrue(containsActivationFunction(actFns, activation));
assertTrue(nOut >= 10 && nOut <= 20); assertTrue(nOut >= 10 && nOut <= 20);

View File

@ -293,8 +293,8 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly
} }
System.out.println("Number of layers: " + Arrays.toString(nLayerCounts)); // System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount); // System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
} }

View File

@ -98,7 +98,8 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson())); assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson()));
FileUtils.writeStringToFile(new File(configPath),configuration.toJson()); FileUtils.writeStringToFile(new File(configPath),configuration.toJson());
System.out.println(configuration.toJson()); // System.out.println(configuration.toJson());
configuration.toJson();
log.info("Starting test"); log.info("Starting test");
cliRunner.runMain( cliRunner.runMain(

View File

@ -41,7 +41,7 @@ public class TestGraphLoading extends BaseDL4JTest {
IGraph<String, String> graph = GraphLoader IGraph<String, String> graph = GraphLoader
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ","); .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ",");
System.out.println(graph); // System.out.println(graph);
assertEquals(graph.numVertices(), 7); assertEquals(graph.numVertices(), 7);
int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}}; int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}};
@ -66,7 +66,7 @@ public class TestGraphLoading extends BaseDL4JTest {
edgeLineProcessor, vertexFactory, 10, false); edgeLineProcessor, vertexFactory, 10, false);
System.out.println(graph); // System.out.println(graph);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
List<Edge<String>> edges = graph.getEdgesOut(i); List<Edge<String>> edges = graph.getEdgesOut(i);
@ -111,7 +111,7 @@ public class TestGraphLoading extends BaseDL4JTest {
Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(), Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false); edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);
System.out.println(graph); // System.out.println(graph);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
List<Edge<String>> edges = graph.getEdgesOut(i); List<Edge<String>> edges = graph.getEdgesOut(i);

View File

@ -71,7 +71,7 @@ public class TestGraphLoadingWeighted extends BaseDL4JTest {
} }
} }
System.out.println(graph); // System.out.println(graph);
} }

View File

@ -220,7 +220,7 @@ public class TestGraph extends BaseDL4JTest {
sum += transitionProb[i][j]; sum += transitionProb[i][j];
for (int j = 0; j < transitionProb[i].length; j++) for (int j = 0; j < transitionProb[i].length; j++)
transitionProb[i][j] /= sum; transitionProb[i][j] /= sum;
System.out.println(Arrays.toString(transitionProb[i])); // System.out.println(Arrays.toString(transitionProb[i]));
} }
//Check that transition probs are essentially correct (within bounds of random variation) //Check that transition probs are essentially correct (within bounds of random variation)

View File

@ -145,8 +145,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR) if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
fail(msg); fail(msg);
else // else
System.out.println(msg); // System.out.println(msg);
} }
} }
@ -333,10 +333,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR) if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
fail(msg); fail(msg);
else // else
System.out.println(msg); // System.out.println(msg);
} }
System.out.println(); // System.out.println();
} }
} }

View File

@ -67,7 +67,7 @@ public class TestDeepWalk extends BaseDL4JTest {
for (int i = 0; i < 7; i++) { for (int i = 0; i < 7; i++) {
INDArray vector = deepWalk.getVertexVector(i); INDArray vector = deepWalk.getVertexVector(i);
assertArrayEquals(new long[] {vectorSize}, vector.shape()); assertArrayEquals(new long[] {vectorSize}, vector.shape());
System.out.println(Arrays.toString(vector.dup().data().asFloat())); // System.out.println(Arrays.toString(vector.dup().data().asFloat()));
} }
GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8); GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);
@ -77,11 +77,11 @@ public class TestDeepWalk extends BaseDL4JTest {
for (int t = 0; t < 5; t++) { for (int t = 0; t < 5; t++) {
iter.reset(); iter.reset();
deepWalk.fit(iter); deepWalk.fit(iter);
System.out.println("--------------------"); // System.out.println("--------------------");
for (int i = 0; i < 7; i++) { for (int i = 0; i < 7; i++) {
INDArray vector = deepWalk.getVertexVector(i); INDArray vector = deepWalk.getVertexVector(i);
assertArrayEquals(new long[] {vectorSize}, vector.shape()); assertArrayEquals(new long[] {vectorSize}, vector.shape());
System.out.println(Arrays.toString(vector.dup().data().asFloat())); // System.out.println(Arrays.toString(vector.dup().data().asFloat()));
} }
} }
} }
@ -160,7 +160,7 @@ public class TestDeepWalk extends BaseDL4JTest {
continue; continue;
double sim = deepWalk.similarity(i, nearestTo); double sim = deepWalk.similarity(i, nearestTo);
System.out.println(i + "\t" + nearestTo + "\t" + sim); // System.out.println(i + "\t" + nearestTo + "\t" + sim);
assertTrue(sim <= minSimNearest); assertTrue(sim <= minSimNearest);
} }
} }
@ -211,7 +211,7 @@ public class TestDeepWalk extends BaseDL4JTest {
Graph<String, String> graph = GraphLoader Graph<String, String> graph = GraphLoader
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ","); .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ",");
System.out.println(graph); // System.out.println(graph);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -229,11 +229,13 @@ public class TestDeepWalk extends BaseDL4JTest {
//Calculate similarity(0,i) //Calculate similarity(0,i)
for (int i = 0; i < nVertices; i++) { for (int i = 0; i < nVertices; i++) {
System.out.println(deepWalk.similarity(0, i)); // System.out.println(deepWalk.similarity(0, i));
deepWalk.similarity(0, i);
} }
for (int i = 0; i < nVertices; i++) for (int i = 0; i < nVertices; i++)
System.out.println(deepWalk.getVertexVector(i)); // System.out.println(deepWalk.getVertexVector(i));
deepWalk.getVertexVector(i);
} }
@Test(timeout = 60000L) @Test(timeout = 60000L)

View File

@ -38,9 +38,11 @@ public class TestGraphHuffman extends BaseDL4JTest {
gh.buildTree(vertexDegrees); gh.buildTree(vertexDegrees);
for (int i = 0; i < 7; i++) for (int i = 0; i < 7; i++) {
System.out.println(i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i) String s = i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
+ "\t\t" + Arrays.toString(gh.getPathInnerNodes(i))); + "\t\t" + Arrays.toString(gh.getPathInnerNodes(i));
// System.out.println(s);
}
int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5}; int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5};
for (int i = 0; i < vertexDegrees.length; i++) { for (int i = 0; i < vertexDegrees.length; i++) {

View File

@ -79,8 +79,9 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest {
model.init(); model.init();
ParallelWrapper parameterServerParallelWrapper = ParallelWrapper parameterServerParallelWrapper =
new ParallelWrapper.Builder(model).trainerFactory(new ParameterServerTrainerContext()) new ParallelWrapper.Builder(model)
.workers(Runtime.getRuntime().availableProcessors()) .workers(Math.min(4, Runtime.getRuntime().availableProcessors()))
.trainerFactory(new ParameterServerTrainerContext())
.reportScoreAfterAveraging(true).prefetchBuffer(3).build(); .reportScoreAfterAveraging(true).prefetchBuffer(3).build();
parameterServerParallelWrapper.fit(mnistTrain); parameterServerParallelWrapper.fit(mnistTrain);

View File

@ -104,7 +104,7 @@ public class SparkWord2VecTest extends BaseDL4JTest {
public void call(ExportContainer<VocabWord> v) throws Exception { public void call(ExportContainer<VocabWord> v) throws Exception {
assertNotNull(v.getElement()); assertNotNull(v.getElement());
assertNotNull(v.getArray()); assertNotNull(v.getArray());
System.out.println(v.getElement() + " - " + v.getArray()); // System.out.println(v.getElement() + " - " + v.getArray());
} }
} }
} }

View File

@ -66,7 +66,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();
@ -119,7 +119,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MSE).build()) .lossFunction(LossFunctions.LossFunction.MSE).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>(); EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
@ -155,7 +155,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();
@ -198,7 +198,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();
@ -231,7 +231,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();

View File

@ -69,7 +69,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build(); .setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf); ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();
@ -120,7 +120,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MSE).build(), "in") .lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
.setOutputs("0").build(); .setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf); ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();
EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>(); EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
@ -158,7 +158,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build(); .setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf); ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();
@ -203,7 +203,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build(); .setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf); ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();
@ -238,7 +238,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
.setOutputs("0").build(); .setOutputs("0").build();
ComputationGraph net = new ComputationGraph(conf); ComputationGraph net = new ComputationGraph(conf);
net.setListeners(new ScoreIterationListener(1)); net.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> irisData = getIris(); JavaRDD<DataSet> irisData = getIris();

View File

@ -59,7 +59,7 @@ public class TestShuffleExamples extends BaseSparkTest {
int totalExampleCount = 0; int totalExampleCount = 0;
for (DataSet ds : shuffledList) { for (DataSet ds : shuffledList) {
totalExampleCount += ds.getFeatures().length(); totalExampleCount += ds.getFeatures().length();
System.out.println(Arrays.toString(ds.getFeatures().data().asFloat())); // System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
assertEquals(ds.getFeatures(), ds.getLabels()); assertEquals(ds.getFeatures(), ds.getLabels());
} }

View File

@ -86,7 +86,7 @@ public class TestExport extends BaseSparkTest {
for (File file : files) { for (File file : files) {
if (!file.getPath().endsWith(".bin")) if (!file.getPath().endsWith(".bin"))
continue; continue;
System.out.println(file); // System.out.println(file);
DataSet ds = new DataSet(); DataSet ds = new DataSet();
ds.load(file); ds.load(file);
assertEquals(minibatchSize, ds.numExamples()); assertEquals(minibatchSize, ds.numExamples());
@ -144,7 +144,7 @@ public class TestExport extends BaseSparkTest {
for (File file : files) { for (File file : files) {
if (!file.getPath().endsWith(".bin")) if (!file.getPath().endsWith(".bin"))
continue; continue;
System.out.println(file); // System.out.println(file);
MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet(); MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
ds.load(file); ds.load(file);
assertEquals(minibatchSize, ds.getFeatures(0).size(0)); assertEquals(minibatchSize, ds.getFeatures(0).size(0));

View File

@ -92,9 +92,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
int[][] colorCountsByPartition = new int[3][2]; int[][] colorCountsByPartition = new int[3][2];
for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) { for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) {
System.out.println(val); // System.out.println(val);
Integer partition = hbp.getPartition(val._1()); Integer partition = hbp.getPartition(val._1());
System.out.println(partition); // System.out.println(partition);
if (val._2().equals("red")) if (val._2().equals("red"))
colorCountsByPartition[partition][0] += 1; colorCountsByPartition[partition][0] += 1;
@ -102,9 +102,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
colorCountsByPartition[partition][1] += 1; colorCountsByPartition[partition][1] += 1;
} }
for (int i = 0; i < 3; i++) { // for (int i = 0; i < 3; i++) {
System.out.println(Arrays.toString(colorCountsByPartition[i])); // System.out.println(Arrays.toString(colorCountsByPartition[i]));
} // }
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
// avg red per partition : 2.33 // avg red per partition : 2.33
assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4); assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4);
@ -178,12 +178,12 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
colorCountsByPartition[partition][1] += 1; colorCountsByPartition[partition][1] += 1;
} }
for (int i = 0; i < numPartitions; i++) { // for (int i = 0; i < numPartitions; i++) {
System.out.println(Arrays.toString(colorCountsByPartition[i])); // System.out.println(Arrays.toString(colorCountsByPartition[i]));
} // }
//
System.out.println("Ideal red # per partition: " + avgRed); // System.out.println("Ideal red # per partition: " + avgRed);
System.out.println("Ideal blue # per partition: " + avgBlue); // System.out.println("Ideal blue # per partition: " + avgBlue);
for (int i = 0; i < numPartitions; i++) { for (int i = 0; i < numPartitions; i++) {
// avg red per partition : 2.33 // avg red per partition : 2.33

View File

@ -115,7 +115,7 @@ public class TestSparkComputationGraph extends BaseSparkTest {
TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0); TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm); SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(1))); scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5)));
JavaRDD<MultiDataSet> rdd = sc.parallelize(list); JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
scg.fitMultiDataSet(rdd); scg.fitMultiDataSet(rdd);

View File

@ -31,8 +31,11 @@ import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.Test; import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
@ -45,8 +48,24 @@ import static org.junit.Assert.assertTrue;
@Slf4j @Slf4j
public class TestSparkDl4jMultiLayer extends BaseSparkTest { public class TestSparkDl4jMultiLayer extends BaseSparkTest {
@Test(timeout = 120000L) @Override
public long getTimeoutMilliseconds() {
return 120000L;
}
@Override
public DataType getDataType() {
return DataType.FLOAT;
}
@Override
public DataType getDefaultFPDataType() {
return DataType.FLOAT;
}
@Test
public void testEvaluationSimple() throws Exception { public void testEvaluationSimple() throws Exception {
Nd4j.getRandom().setSeed(12345);
for( int evalWorkers : new int[]{1, 4, 8}) { for( int evalWorkers : new int[]{1, 4, 8}) {
//Simple test to validate DL4J issue 4099 is fixed... //Simple test to validate DL4J issue 4099 is fixed...
@ -75,18 +94,18 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
//---------------------------------- //----------------------------------
//Create network configuration and conduct network training //Create network configuration and conduct network training
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.FLOAT)
.seed(12345) .seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.activation(Activation.LEAKYRELU) .activation(Activation.LEAKYRELU)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(0.02, 0.9)) .updater(new Adam(1e-3))
.l2(1e-4) .l2(1e-5)
.list() .list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build()) .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
.layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build()) .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX).nIn(100).nOut(10).build()) .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
.build(); .build();
//Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options //Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options

View File

@ -333,15 +333,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd); sparkNet.fit(rdd);
} }
System.out.println(sparkNet.getSparkTrainingStats().statsAsString()); // System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat())); // System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: " // System.out.println("Initial (Spark) params: "
+ Arrays.toString(initialSparkParams.data().asFloat())); // + Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat())); // System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat())); // System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
assertEquals(initialParams, initialSparkParams); assertEquals(initialParams, initialSparkParams);
assertNotEquals(initialParams, finalParams); assertNotEquals(initialParams, finalParams);
assertEquals(finalParams, finalSparkParams); assertEquals(finalParams, finalSparkParams);
@ -405,15 +406,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd); sparkNet.fit(rdd);
} }
System.out.println(sparkNet.getSparkTrainingStats().statsAsString()); // System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat())); // System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: " // System.out.println("Initial (Spark) params: "
+ Arrays.toString(initialSparkParams.data().asFloat())); // + Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat())); // System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat())); // System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f); assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f); assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
@ -478,18 +480,19 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd); sparkNet.fit(rdd);
} }
System.out.println(sparkNet.getSparkTrainingStats().statsAsString()); // System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
// executioner.addToWatchdog(finalSparkParams, "finalSparkParams"); // executioner.addToWatchdog(finalSparkParams, "finalSparkParams");
float[] fp = finalParams.data().asFloat(); float[] fp = finalParams.data().asFloat();
float[] fps = finalSparkParams.data().asFloat(); float[] fps = finalSparkParams.data().asFloat();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat())); // System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: " // System.out.println("Initial (Spark) params: "
+ Arrays.toString(initialSparkParams.data().asFloat())); // + Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(fp)); // System.out.println("Final (Local) params: " + Arrays.toString(fp));
System.out.println("Final (Spark) params: " + Arrays.toString(fps)); // System.out.println("Final (Spark) params: " + Arrays.toString(fps));
assertEquals(initialParams, initialSparkParams); assertEquals(initialParams, initialSparkParams);
assertNotEquals(initialParams, finalParams); assertNotEquals(initialParams, finalParams);
@ -551,14 +554,15 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
sparkNet.fit(rdd); sparkNet.fit(rdd);
} }
System.out.println(sparkNet.getSparkTrainingStats().statsAsString()); // System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
sparkNet.getSparkTrainingStats().statsAsString();
INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat())); // System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat())); // System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat())); // System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat())); // System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f); assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f); assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);

View File

@ -37,7 +37,7 @@ public class TestJsonYaml {
String json = tm.toJson(); String json = tm.toJson();
String yaml = tm.toYaml(); String yaml = tm.toYaml();
System.out.println(json); // System.out.println(json);
TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json); TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml); TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);

View File

@ -389,7 +389,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs"); List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
for (EventStats e : workerFitStats) { for (EventStats e : workerFitStats) {
ExampleCountEventStats eces = (ExampleCountEventStats) e; ExampleCountEventStats eces = (ExampleCountEventStats) e;
System.out.println(eces.getTotalExampleCount()); // System.out.println(eces.getTotalExampleCount());
} }
for (EventStats e : workerFitStats) { for (EventStats e : workerFitStats) {
@ -457,7 +457,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter); assertNotEquals(paramsBefore, paramsAfter);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString()); // System.out.println(stats.statsAsString());
stats.statsAsString();
sparkNet.getTrainingMaster().deleteTempFiles(sc); sparkNet.getTrainingMaster().deleteTempFiles(sc);
} }
@ -483,7 +484,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
i++; i++;
} }
System.out.println("Saved to: " + tempDirF.getAbsolutePath()); // System.out.println("Saved to: " + tempDirF.getAbsolutePath());
@ -527,7 +528,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
//Expect //Expect
System.out.println(stats.statsAsString()); // System.out.println(stats.statsAsString());
stats.statsAsString();
assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size()); assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());
List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs"); List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");
@ -566,8 +568,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
i++; i++;
} }
System.out.println("Saved to: " + tempDirF.getAbsolutePath()); // System.out.println("Saved to: " + tempDirF.getAbsolutePath());
System.out.println("Saved to: " + tempDirF2.getAbsolutePath()); // System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
@ -610,7 +612,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter); assertNotEquals(paramsBefore, paramsAfter);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString()); // System.out.println(stats.statsAsString());
stats.statsAsString();
//Same thing, buf for MultiDataSet objects: //Same thing, buf for MultiDataSet objects:
config = new Configuration(); config = new Configuration();
@ -631,7 +634,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
assertNotEquals(paramsBefore, paramsAfter); assertNotEquals(paramsBefore, paramsAfter);
stats = sparkNet.getSparkTrainingStats(); stats = sparkNet.getSparkTrainingStats();
System.out.println(stats.statsAsString()); // System.out.println(stats.statsAsString());
stats.statsAsString();
} }
@ -730,13 +734,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.build(); .build();
for (int avgFreq : new int[] {1, 5, 10}) { for (int avgFreq : new int[] {1, 5, 10}) {
System.out.println("--- Avg freq " + avgFreq + " ---"); // System.out.println("--- Avg freq " + avgFreq + " ---");
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(), SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize) new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq) .batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
.repartionData(Repartition.Always).build()); .repartionData(Repartition.Always).build());
sparkNet.setListeners(new ScoreIterationListener(1)); sparkNet.setListeners(new ScoreIterationListener(5));
@ -778,13 +782,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.setOutputs("1").build(); .setOutputs("1").build();
for (int avgFreq : new int[] {1, 5, 10}) { for (int avgFreq : new int[] {1, 5, 10}) {
System.out.println("--- Avg freq " + avgFreq + " ---"); // System.out.println("--- Avg freq " + avgFreq + " ---");
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(), SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize) new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq) .batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
.repartionData(Repartition.Always).build()); .repartionData(Repartition.Always).build());
sparkNet.setListeners(new ScoreIterationListener(1)); sparkNet.setListeners(new ScoreIterationListener(5));
JavaRDD<DataSet> rdd = sc.parallelize(list); JavaRDD<DataSet> rdd = sc.parallelize(list);

View File

@ -107,7 +107,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
expectedStatNames.addAll(c); expectedStatNames.addAll(c);
} }
System.out.println(expectedStatNames); // System.out.println(expectedStatNames);
SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
@ -119,7 +119,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
} }
String statsAsString = stats.statsAsString(); String statsAsString = stats.statsAsString();
System.out.println(statsAsString); // System.out.println(statsAsString);
assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat

View File

@ -35,7 +35,7 @@ public class TestTimeSource {
long systemTime = System.currentTimeMillis(); long systemTime = System.currentTimeMillis();
long ntpTime = timeSource.currentTimeMillis(); long ntpTime = timeSource.currentTimeMillis();
long offset = ntpTime - systemTime; long offset = ntpTime - systemTime;
System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset); // System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
Thread.sleep(500); Thread.sleep(500);
} }
} }
@ -49,7 +49,7 @@ public class TestTimeSource {
long systemTime = System.currentTimeMillis(); long systemTime = System.currentTimeMillis();
long ntpTime = timeSource.currentTimeMillis(); long ntpTime = timeSource.currentTimeMillis();
long offset = ntpTime - systemTime; long offset = ntpTime - systemTime;
System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset); // System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
Thread.sleep(500); Thread.sleep(500);
} }

View File

@ -87,7 +87,7 @@ public class TestListeners extends BaseSparkTest {
net.fit(rdd); net.fit(rdd);
List<String> sessions = ss.listSessionIDs(); List<String> sessions = ss.listSessionIDs();
System.out.println("Sessions: " + sessions); // System.out.println("Sessions: " + sessions);
assertEquals(1, sessions.size()); assertEquals(1, sessions.size());
String sid = sessions.get(0); String sid = sessions.get(0);
@ -95,15 +95,15 @@ public class TestListeners extends BaseSparkTest {
List<String> typeIDs = ss.listTypeIDsForSession(sid); List<String> typeIDs = ss.listTypeIDsForSession(sid);
List<String> workers = ss.listWorkerIDsForSession(sid); List<String> workers = ss.listWorkerIDsForSession(sid);
System.out.println(sid + "\t" + typeIDs + "\t" + workers); // System.out.println(sid + "\t" + typeIDs + "\t" + workers);
List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID); List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
System.out.println(lastUpdates); // System.out.println(lastUpdates);
System.out.println("Static info:"); // System.out.println("Static info:");
for (String wid : workers) { for (String wid : workers) {
Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid); Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
System.out.println(sid + "\t" + wid); // System.out.println(sid + "\t" + wid);
} }
assertEquals(1, typeIDs.size()); assertEquals(1, typeIDs.size());

View File

@ -63,7 +63,7 @@ public class TestRepartitioning extends BaseSparkTest {
assertEquals(10, rdd2.partitions().size()); assertEquals(10, rdd2.partitions().size());
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
List<String> partition = rdd2.collectPartitions(new int[] {i})[0]; List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
System.out.println("Partition " + i + " size: " + partition.size()); // System.out.println("Partition " + i + " size: " + partition.size());
assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition) assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
} }
} }
@ -170,7 +170,7 @@ public class TestRepartitioning extends BaseSparkTest {
List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect(); List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
System.out.println(partitionCounts); // System.out.println(partitionCounts);
List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList( List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
new Tuple2<>(0,29), new Tuple2<>(0,29),
@ -185,7 +185,7 @@ public class TestRepartitioning extends BaseSparkTest {
JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112); JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect(); List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
System.out.println(partitionCountsAfter); // System.out.println(partitionCountsAfter);
for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){ for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
assertEquals(2, (int)t2._2()); assertEquals(2, (int)t2._2());
@ -219,8 +219,8 @@ public class TestRepartitioning extends BaseSparkTest {
} }
} }
System.out.println("min: " + min + "\t@\t" + minIdx); // System.out.println("min: " + min + "\t@\t" + minIdx);
System.out.println("max: " + max + "\t@\t" + maxIdx); // System.out.println("max: " + max + "\t@\t" + maxIdx);
assertEquals(1, min); assertEquals(1, min);
assertEquals(2, max); assertEquals(2, max);
@ -244,7 +244,7 @@ public class TestRepartitioning extends BaseSparkTest {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
List<String> partition = rdd2.collectPartitions(new int[] {i})[0]; List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
System.out.println("Partition " + i + " size: " + partition.size()); // System.out.println("Partition " + i + " size: " + partition.size());
assertTrue(partition.size() >= 90 && partition.size() <= 110); assertTrue(partition.size() >= 90 && partition.size() <= 110);
} }
} }

View File

@ -123,7 +123,7 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
val expectedSlice = expectedArray.slice(0) val expectedSlice = expectedArray.slice(0)
val actualSlice = expectedArray(0, ->) val actualSlice = expectedArray(0, ->)
Console.println(expectedSlice) // Console.println(expectedSlice)
assert(actualSlice == expectedSlice) assert(actualSlice == expectedSlice)
} }

View File

@ -28,7 +28,7 @@ class TrainingTest extends FlatSpec with Matchers {
val unused3 = unused1.div(unused2) val unused3 = unused1.div(unused2)
val loss1 = add.std("l1", true) val loss1 = add.std("l1", true)
val loss2 = mmul.mean("l2") val loss2 = mmul.mean("l2")
Console.println(sd.summary) // Console.println(sd.summary)
if (i == 0) { if (i == 0) {
sd.setLossVariables("l1", "l2") sd.setLossVariables("l1", "l2")
sd.createGradFunction() sd.createGradFunction()

View File

@ -43,8 +43,8 @@ public class HistoryProcessorTest {
hp.add(a); hp.add(a);
INDArray[] h = hp.getHistory(); INDArray[] h = hp.getHistory();
assertEquals(4, h.length); assertEquals(4, h.length);
System.out.println(Arrays.toString(a.shape())); // System.out.println(Arrays.toString(a.shape()));
System.out.println(Arrays.toString(h[0].shape())); // System.out.println(Arrays.toString(h[0].shape()));
assertEquals( 1, h[0].shape()[0]); assertEquals( 1, h[0].shape()[0]);
assertEquals(a.shape()[0], h[0].shape()[1]); assertEquals(a.shape()[0], h[0].shape()[1]);
assertEquals(a.shape()[1], h[0].shape()[2]); assertEquals(a.shape()[1], h[0].shape()[2]);

View File

@ -100,8 +100,8 @@ public class ActorCriticTest {
double error2 = gradient2 - gradient.getDouble(1); double error2 = gradient2 - gradient.getDouble(1);
double relError1 = error1 / gradient.getDouble(0); double relError1 = error1 / gradient.getDouble(0);
double relError2 = error2 / gradient.getDouble(1); double relError2 = error2 / gradient.getDouble(1);
System.out.println(gradient.getDouble(0) + " " + gradient1 + " " + relError1); // System.out.println(gradient.getDouble(0) + " " + gradient1 + " " + relError1);
System.out.println(gradient.getDouble(1) + " " + gradient2 + " " + relError2); // System.out.println(gradient.getDouble(1) + " " + gradient2 + " " + relError2);
assertTrue(gradient.getDouble(0) < maxRelError || Math.abs(relError1) < maxRelError); assertTrue(gradient.getDouble(0) < maxRelError || Math.abs(relError1) < maxRelError);
assertTrue(gradient.getDouble(1) < maxRelError || Math.abs(relError2) < maxRelError); assertTrue(gradient.getDouble(1) < maxRelError || Math.abs(relError2) < maxRelError);
} }

View File

@ -158,7 +158,7 @@ public class PolicyTest {
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
count[policy.nextAction(input)]++; count[policy.nextAction(input)]++;
} }
System.out.println(count[0] + " " + count[1] + " " + count[2] + " " + count[3]); // System.out.println(count[0] + " " + count[1] + " " + count[2] + " " + count[3]);
assertTrue(count[0] < 20); assertTrue(count[0] < 20);
assertTrue(count[1] < 30); assertTrue(count[1] < 30);
assertTrue(count[2] < 40); assertTrue(count[2] < 40);