Test fixes + cleanup (#245)
* Test spam reduction
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Arbiter bad import fixes
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Small spark test tweak
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Arbiter test log spam reduction
  Signed-off-by: Alex Black <blacka101@gmail.com>
* More test spam reduction
  Signed-off-by: Alex Black <blacka101@gmail.com>
parent 2698fbf541
commit c8882cbfa5
@@ -1,47 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.arbiter.util;
-
-import org.slf4j.Logger;
-
-import java.awt.*;
-import java.net.URI;
-
-/**
- * Various utilities for webpages and dealing with browsers
- */
-public class WebUtils {
-
-    public static void tryOpenBrowser(String path, Logger log) {
-        try {
-            WebUtils.openBrowser(new URI(path));
-        } catch (Exception e) {
-            log.error("Could not open browser", e);
-            System.out.println("Browser could not be launched automatically.\nUI path: " + path);
-        }
-    }
-
-    public static void openBrowser(URI uri) throws Exception {
-        if (Desktop.isDesktopSupported()) {
-            Desktop.getDesktop().browse(uri);
-        } else {
-            throw new UnsupportedOperationException(
-                            "Cannot open browser on this platform: Desktop.isDesktopSupported() == false");
-        }
-    }
-
-}
@@ -127,7 +127,7 @@ public class BraninFunction {
             BraninConfig candidate = (BraninConfig) c.getValue();

             double score = scoreFunction.score(candidate, null, (Map) null);
-            System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
+//            System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);

             Thread.sleep(20);
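The hunks above and below all follow the same spam-reduction pattern: System.out.println calls in tests are commented out rather than deleted, so they can be restored locally when debugging. A minimal alternative sketch, not what this commit does, is to route such diagnostics through slf4j at DEBUG level so verbosity is controlled by logging configuration instead of code edits (the helper class name here is illustrative):

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// Hypothetical helper: test diagnostics go through slf4j rather than stdout,
// so they can be re-enabled via logback/log4j config without touching code.
public class TestDiagnostics {
    private static final Logger log = LoggerFactory.getLogger(TestDiagnostics.class);

    public static void debug(String msg) {
        log.debug(msg); // emitted only when DEBUG is enabled for this logger
    }
}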
@@ -54,7 +54,7 @@ public class TestRandomSearch extends BaseDL4JTest {
         runner.execute();


-        System.out.println("----- Complete -----");
+//        System.out.println("----- Complete -----");
     }

@@ -16,8 +16,8 @@

 package org.deeplearning4j.arbiter.optimize.genetic;

+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.math3.random.RandomGenerator;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 public class TestRandomGenerator implements RandomGenerator {
     private final int[] intRandomNumbers;
@@ -63,17 +63,17 @@ public class TestRandomGenerator implements RandomGenerator {

     @Override
     public long nextLong() {
-        throw new NotImplementedException();
+        throw new NotImplementedException("Not implemented");
     }

     @Override
     public boolean nextBoolean() {
-        throw new NotImplementedException();
+        throw new NotImplementedException("Not implemented");
     }

     @Override
     public float nextFloat() {
-        throw new NotImplementedException();
+        throw new NotImplementedException("Not implemented");
     }

     @Override
@@ -83,6 +83,6 @@ public class TestRandomGenerator implements RandomGenerator {

     @Override
     public double nextGaussian() {
-        throw new NotImplementedException();
+        throw new NotImplementedException("Not implemented");
     }
 }
@@ -16,6 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.genetic.crossover;

+import org.apache.commons.lang3.NotImplementedException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
@@ -26,7 +27,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
 import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {

@@ -42,7 +42,7 @@ public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {

         @Override
         public CrossoverResult crossover() {
-            throw new NotImplementedException();
+            throw new NotImplementedException("Not implemented");
         }
     }

@@ -16,6 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.genetic.culling;

+import org.apache.commons.lang3.NotImplementedException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
@@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
 import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 import java.util.List;

@@ -46,7 +46,7 @@ public class RatioCullOperatorTests extends BaseDL4JTest {

         @Override
         public void cullPopulation() {
-            throw new NotImplementedException();
+            throw new NotImplementedException("Not implemented");
         }

         public double getCullRatio() {
@@ -16,6 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.genetic.selection;

+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.math3.random.RandomGenerator;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
@@ -33,7 +34,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 import static org.junit.Assert.assertArrayEquals;

@@ -55,7 +55,7 @@ public class GeneticSelectionOperatorTests extends BaseDL4JTest {

         @Override
         public void cullPopulation() {
-            throw new NotImplementedException();
+            throw new NotImplementedException("Not implemented");
         }

         @Override
@@ -16,6 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.genetic.selection;

+import org.apache.commons.lang3.NotImplementedException;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
 import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
@@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
 import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
 import org.junit.Assert;
 import org.junit.Test;
-import sun.reflect.generics.reflectiveObjects.NotImplementedException;

 public class SelectionOperatorTests extends BaseDL4JTest {
     private class TestSelectionOperator extends SelectionOperator {

@@ -39,7 +39,7 @@ public class SelectionOperatorTests extends BaseDL4JTest {

         @Override
         public double[] buildNextGenes() {
-            throw new NotImplementedException();
+            throw new NotImplementedException("Not implemented");
         }
     }

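The preceding Arbiter hunks all make the same "bad import" fix: the test stubs threw sun.reflect.generics.reflectiveObjects.NotImplementedException, a JDK-internal class that is not public API and is unavailable on newer JDKs, and now throw org.apache.commons.lang3.NotImplementedException, whose constructor takes a message. A self-contained sketch of the resulting stub pattern (class name illustrative; the interface is commons-math3's RandomGenerator, as in TestRandomGenerator):

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.math3.random.RandomGenerator;

// Sketch: unsupported methods fail loudly via the public commons-lang3
// exception instead of a sun.* internal class.
public class StubRandomGenerator implements RandomGenerator {
    public void setSeed(int seed) {}
    public void setSeed(int[] seed) {}
    public void setSeed(long seed) {}
    public void nextBytes(byte[] bytes) {}
    public int nextInt() { return 0; }
    public int nextInt(int n) { return 0; }
    public long nextLong() { throw new NotImplementedException("Not implemented"); }
    public boolean nextBoolean() { throw new NotImplementedException("Not implemented"); }
    public float nextFloat() { throw new NotImplementedException("Not implemented"); }
    public double nextDouble() { throw new NotImplementedException("Not implemented"); }
    public double nextGaussian() { throw new NotImplementedException("Not implemented"); }
}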
@@ -158,7 +158,7 @@ public class TestComputationGraphSpace extends BaseDL4JTest {
             }
         }

-        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
+//        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
         assertTrue(reluCount > 0);
         assertTrue(tanhCount > 0);

@@ -162,7 +162,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
             List<ResultReference> results = runner.getResults();
             assertTrue(results.size() > 0);

-            System.out.println("----- COMPLETE - " + results.size() + " results -----");
+//            System.out.println("----- COMPLETE - " + results.size() + " results -----");
         }
     }

@@ -165,7 +165,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
             List<ResultReference> results = runner.getResults();
             assertTrue(results.size() > 0);

-            System.out.println("----- COMPLETE - " + results.size() + " results -----");
+//            System.out.println("----- COMPLETE - " + results.size() + " results -----");
         }
     }

@@ -101,7 +101,7 @@ public class TestLayerSpace extends BaseDL4JTest {
             double l2 = TestUtils.getL2(l);
             IActivation activation = l.getActivationFn();

-            System.out.println(lr + "\t" + l2 + "\t" + activation);
+//            System.out.println(lr + "\t" + l2 + "\t" + activation);

             assertTrue(lr >= 0.3 && lr <= 0.4);
             assertTrue(l2 >= 0.01 && l2 <= 0.1);
@@ -190,7 +190,7 @@ public class TestLayerSpace extends BaseDL4JTest {
             ActivationLayer al = als.getValue(d);
             IActivation activation = al.getActivationFn();

-            System.out.println(activation);
+//            System.out.println(activation);

             assertTrue(containsActivationFunction(actFns, activation));
         }
@@ -228,7 +228,7 @@ public class TestLayerSpace extends BaseDL4JTest {
             IActivation activation = el.getActivationFn();
             long nOut = el.getNOut();

-            System.out.println(activation + "\t" + nOut);
+//            System.out.println(activation + "\t" + nOut);

             assertTrue(containsActivationFunction(actFns, activation));
             assertTrue(nOut >= 10 && nOut <= 20);
@@ -295,7 +295,7 @@ public class TestLayerSpace extends BaseDL4JTest {
             long nOut = el.getNOut();
             double forgetGate = el.getForgetGateBiasInit();

-            System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
+//            System.out.println(activation + "\t" + nOut + "\t" + forgetGate);

             assertTrue(containsActivationFunction(actFns, activation));
             assertTrue(nOut >= 10 && nOut <= 20);
@@ -293,8 +293,8 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
             assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly
         }

-        System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
-        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
+//        System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
+//        System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);

     }

@@ -98,7 +98,8 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
         assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson()));

         FileUtils.writeStringToFile(new File(configPath),configuration.toJson());
-        System.out.println(configuration.toJson());
+//        System.out.println(configuration.toJson());
+        configuration.toJson();

         log.info("Starting test");
         cliRunner.runMain(
@@ -41,7 +41,7 @@ public class TestGraphLoading extends BaseDL4JTest {

         IGraph<String, String> graph = GraphLoader
                         .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ",");
-        System.out.println(graph);
+//        System.out.println(graph);

         assertEquals(graph.numVertices(), 7);
         int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}};
@@ -66,7 +66,7 @@ public class TestGraphLoading extends BaseDL4JTest {
                         edgeLineProcessor, vertexFactory, 10, false);


-        System.out.println(graph);
+//        System.out.println(graph);

         for (int i = 0; i < 10; i++) {
             List<Edge<String>> edges = graph.getEdgesOut(i);
@@ -111,7 +111,7 @@ public class TestGraphLoading extends BaseDL4JTest {
         Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
                         edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);

-        System.out.println(graph);
+//        System.out.println(graph);

         for (int i = 0; i < 10; i++) {
             List<Edge<String>> edges = graph.getEdgesOut(i);

@@ -71,7 +71,7 @@ public class TestGraphLoadingWeighted extends BaseDL4JTest {
             }
         }

-        System.out.println(graph);
+//        System.out.println(graph);
     }

@@ -220,7 +220,7 @@ public class TestGraph extends BaseDL4JTest {
                 sum += transitionProb[i][j];
             for (int j = 0; j < transitionProb[i].length; j++)
                 transitionProb[i][j] /= sum;
-            System.out.println(Arrays.toString(transitionProb[i]));
+//            System.out.println(Arrays.toString(transitionProb[i]));
         }

         //Check that transition probs are essentially correct (within bounds of random variation)
@@ -145,8 +145,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {

                 if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
                     fail(msg);
-                else
-                    System.out.println(msg);
+//                else
+//                    System.out.println(msg);
             }
         }

@@ -333,10 +333,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {

                 if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
                     fail(msg);
-                else
-                    System.out.println(msg);
+//                else
+//                    System.out.println(msg);
             }
-            System.out.println();
+//            System.out.println();
         }

     }
@@ -67,7 +67,7 @@ public class TestDeepWalk extends BaseDL4JTest {
         for (int i = 0; i < 7; i++) {
             INDArray vector = deepWalk.getVertexVector(i);
             assertArrayEquals(new long[] {vectorSize}, vector.shape());
-            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+//            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
         }

         GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);

@@ -77,11 +77,11 @@ public class TestDeepWalk extends BaseDL4JTest {
         for (int t = 0; t < 5; t++) {
             iter.reset();
             deepWalk.fit(iter);
-            System.out.println("--------------------");
+//            System.out.println("--------------------");
             for (int i = 0; i < 7; i++) {
                 INDArray vector = deepWalk.getVertexVector(i);
                 assertArrayEquals(new long[] {vectorSize}, vector.shape());
-                System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+//                System.out.println(Arrays.toString(vector.dup().data().asFloat()));
             }
         }
     }
@@ -160,7 +160,7 @@ public class TestDeepWalk extends BaseDL4JTest {
                 continue;

             double sim = deepWalk.similarity(i, nearestTo);
-            System.out.println(i + "\t" + nearestTo + "\t" + sim);
+//            System.out.println(i + "\t" + nearestTo + "\t" + sim);
             assertTrue(sim <= minSimNearest);
         }
     }
@@ -211,7 +211,7 @@ public class TestDeepWalk extends BaseDL4JTest {
         Graph<String, String> graph = GraphLoader
                         .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ",");

-        System.out.println(graph);
+//        System.out.println(graph);

         Nd4j.getRandom().setSeed(12345);

@@ -229,11 +229,13 @@ public class TestDeepWalk extends BaseDL4JTest {

         //Calculate similarity(0,i)
         for (int i = 0; i < nVertices; i++) {
-            System.out.println(deepWalk.similarity(0, i));
+//            System.out.println(deepWalk.similarity(0, i));
+            deepWalk.similarity(0, i);
         }

         for (int i = 0; i < nVertices; i++)
-            System.out.println(deepWalk.getVertexVector(i));
+//            System.out.println(deepWalk.getVertexVector(i));
+            deepWalk.getVertexVector(i);
     }

     @Test(timeout = 60000L)
@@ -38,9 +38,11 @@ public class TestGraphHuffman extends BaseDL4JTest {

         gh.buildTree(vertexDegrees);

-        for (int i = 0; i < 7; i++)
-            System.out.println(i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
-                            + "\t\t" + Arrays.toString(gh.getPathInnerNodes(i)));
+        for (int i = 0; i < 7; i++) {
+            String s = i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
+                            + "\t\t" + Arrays.toString(gh.getPathInnerNodes(i));
+//            System.out.println(s);
+        }

         int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5};
         for (int i = 0; i < vertexDegrees.length; i++) {
@@ -79,8 +79,9 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest {
         model.init();

         ParallelWrapper parameterServerParallelWrapper =
-                        new ParallelWrapper.Builder(model).trainerFactory(new ParameterServerTrainerContext())
-                                        .workers(Runtime.getRuntime().availableProcessors())
+                        new ParallelWrapper.Builder(model)
+                                        .workers(Math.min(4, Runtime.getRuntime().availableProcessors()))
+                                        .trainerFactory(new ParameterServerTrainerContext())
                                         .reportScoreAfterAveraging(true).prefetchBuffer(3).build();
         parameterServerParallelWrapper.fit(mnistTrain);

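The ParallelWrapper hunk above also bounds the number of trainer threads: one worker per core oversubscribes CPU and memory on large CI machines, so the builder now asks for at most four. The idiom, as a tiny sketch (class and method names illustrative):

// Bound test parallelism: use the available cores, but never more than a fixed cap.
public class WorkerCount {
    public static int capped(int max) {
        return Math.min(max, Runtime.getRuntime().availableProcessors());
    }
}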
@@ -104,7 +104,7 @@ public class SparkWord2VecTest extends BaseDL4JTest {
             public void call(ExportContainer<VocabWord> v) throws Exception {
                 assertNotNull(v.getElement());
                 assertNotNull(v.getArray());
-                System.out.println(v.getElement() + " - " + v.getArray());
+//                System.out.println(v.getElement() + " - " + v.getArray());
             }
         }
     }

@@ -66,7 +66,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));


         JavaRDD<DataSet> irisData = getIris();

@@ -119,7 +119,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MSE).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();
         EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();

@@ -155,7 +155,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();

@@ -198,7 +198,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();

@@ -231,7 +231,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                         .build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));


         JavaRDD<DataSet> irisData = getIris();

@@ -69,7 +69,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));


         JavaRDD<DataSet> irisData = getIris();

@@ -120,7 +120,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();
         EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();

@@ -158,7 +158,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();

@@ -203,7 +203,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));

         JavaRDD<DataSet> irisData = getIris();

@@ -238,7 +238,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
                         .setOutputs("0").build();
         ComputationGraph net = new ComputationGraph(conf);
-        net.setListeners(new ScoreIterationListener(1));
+        net.setListeners(new ScoreIterationListener(5));


         JavaRDD<DataSet> irisData = getIris();

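The early-stopping hunks above repeat one change: the ScoreIterationListener frequency goes from 1 to 5. The int constructor argument is the print frequency, so the listener now logs the score on every 5th iteration rather than every iteration, cutting the log volume roughly five-fold. A minimal usage sketch with the real DL4J types:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

public class ListenerSetup {
    // Log the training score every 5th iteration instead of every iteration.
    public static void attach(MultiLayerNetwork net) {
        net.setListeners(new ScoreIterationListener(5));
    }
}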
@@ -59,7 +59,7 @@ public class TestShuffleExamples extends BaseSparkTest {
         int totalExampleCount = 0;
         for (DataSet ds : shuffledList) {
             totalExampleCount += ds.getFeatures().length();
-            System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
+//            System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));

             assertEquals(ds.getFeatures(), ds.getLabels());
         }

@@ -86,7 +86,7 @@ public class TestExport extends BaseSparkTest {
         for (File file : files) {
             if (!file.getPath().endsWith(".bin"))
                 continue;
-            System.out.println(file);
+//            System.out.println(file);
             DataSet ds = new DataSet();
             ds.load(file);
             assertEquals(minibatchSize, ds.numExamples());
@@ -144,7 +144,7 @@ public class TestExport extends BaseSparkTest {
         for (File file : files) {
             if (!file.getPath().endsWith(".bin"))
                 continue;
-            System.out.println(file);
+//            System.out.println(file);
             MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
             ds.load(file);
             assertEquals(minibatchSize, ds.getFeatures(0).size(0));
@@ -92,9 +92,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {

         int[][] colorCountsByPartition = new int[3][2];
         for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) {
-            System.out.println(val);
+//            System.out.println(val);
             Integer partition = hbp.getPartition(val._1());
-            System.out.println(partition);
+//            System.out.println(partition);

             if (val._2().equals("red"))
                 colorCountsByPartition[partition][0] += 1;

@@ -102,9 +102,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
                 colorCountsByPartition[partition][1] += 1;
         }

-        for (int i = 0; i < 3; i++) {
-            System.out.println(Arrays.toString(colorCountsByPartition[i]));
-        }
+//        for (int i = 0; i < 3; i++) {
+//            System.out.println(Arrays.toString(colorCountsByPartition[i]));
+//        }
         for (int i = 0; i < 3; i++) {
             // avg red per partition : 2.33
             assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4);

@@ -178,12 +178,12 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
                 colorCountsByPartition[partition][1] += 1;
         }

-        for (int i = 0; i < numPartitions; i++) {
-            System.out.println(Arrays.toString(colorCountsByPartition[i]));
-        }
-
-        System.out.println("Ideal red # per partition: " + avgRed);
-        System.out.println("Ideal blue # per partition: " + avgBlue);
+//        for (int i = 0; i < numPartitions; i++) {
+//            System.out.println(Arrays.toString(colorCountsByPartition[i]));
+//        }
+//
+//        System.out.println("Ideal red # per partition: " + avgRed);
+//        System.out.println("Ideal blue # per partition: " + avgBlue);

         for (int i = 0; i < numPartitions; i++) {
             // avg red per partition : 2.33

@@ -115,7 +115,7 @@ public class TestSparkComputationGraph extends BaseSparkTest {
         TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);

         SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
-        scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(1)));
+        scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5)));

         JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
         scg.fitMultiDataSet(rdd);

@@ -31,8 +31,11 @@ import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
 import org.junit.Test;
 import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.learning.config.Nesterovs;
 import org.nd4j.linalg.lossfunctions.LossFunctions;

@@ -45,8 +48,24 @@ import static org.junit.Assert.assertTrue;
 @Slf4j
 public class TestSparkDl4jMultiLayer extends BaseSparkTest {

-    @Test(timeout = 120000L)
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 120000L;
+    }
+
+    @Override
+    public DataType getDataType() {
+        return DataType.FLOAT;
+    }
+
+    @Override
+    public DataType getDefaultFPDataType() {
+        return DataType.FLOAT;
+    }
+
+    @Test
     public void testEvaluationSimple() throws Exception {
         Nd4j.getRandom().setSeed(12345);

         for( int evalWorkers : new int[]{1, 4, 8}) {
             //Simple test to validate DL4J issue 4099 is fixed...
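The hunk above replaces the per-method @Test(timeout = 120000L) with an override of getTimeoutMilliseconds(), so the timeout is applied uniformly by the test base class. The base-class mechanics are not shown in this diff; a plausible sketch of such a hook using a JUnit 4 Timeout rule (an assumption about how BaseDL4JTest wires it, not confirmed by this commit) is:

import org.junit.Rule;
import org.junit.rules.Timeout;

// Assumed sketch: one Rule enforces the timeout for every test method,
// and subclasses override getTimeoutMilliseconds() to adjust it.
public abstract class BaseTestWithTimeout {
    @Rule
    public Timeout timeout = Timeout.millis(getTimeoutMilliseconds());

    public long getTimeoutMilliseconds() {
        return 90000L; // hypothetical default; TestSparkDl4jMultiLayer returns 120000L
    }
}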
@@ -75,18 +94,18 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
             //----------------------------------
             //Create network configuration and conduct network training
             MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                            .dataType(DataType.FLOAT)
                             .seed(12345)
                             .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                             .activation(Activation.LEAKYRELU)
                             .weightInit(WeightInit.XAVIER)
-                            .updater(new Nesterovs(0.02, 0.9))
-                            .l2(1e-4)
+                            .updater(new Adam(1e-3))
+                            .l2(1e-5)
                             .list()
                             .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
                             .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
                             .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                             .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())

                             .build();

             //Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options
@@ -333,15 +333,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
                 sparkNet.fit(rdd);
             }

-            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+//            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+            sparkNet.getSparkTrainingStats().statsAsString();

             INDArray finalSparkParams = sparkNet.getNetwork().params().dup();

-            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
-            System.out.println("Initial (Spark) params: "
-                            + Arrays.toString(initialSparkParams.data().asFloat()));
-            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
-            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
+//            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
+//            System.out.println("Initial (Spark) params: "
+//                            + Arrays.toString(initialSparkParams.data().asFloat()));
+//            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
+//            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
             assertEquals(initialParams, initialSparkParams);
             assertNotEquals(initialParams, finalParams);
             assertEquals(finalParams, finalSparkParams);
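Note the pattern in this and the following three hunks: the println is commented out, but a bare sparkNet.getSparkTrainingStats().statsAsString() call is kept. Building the stats string still executes, so a broken stats implementation would still fail the test with an exception; only the console output is dropped.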
@@ -405,15 +406,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
                 sparkNet.fit(rdd);
             }

-            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+//            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+            sparkNet.getSparkTrainingStats().statsAsString();

             INDArray finalSparkParams = sparkNet.getNetwork().params().dup();

-            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
-            System.out.println("Initial (Spark) params: "
-                            + Arrays.toString(initialSparkParams.data().asFloat()));
-            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
-            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
+//            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
+//            System.out.println("Initial (Spark) params: "
+//                            + Arrays.toString(initialSparkParams.data().asFloat()));
+//            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
+//            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
             assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
             assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);

@@ -478,18 +480,19 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
                 sparkNet.fit(rdd);
             }

-            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+//            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+            sparkNet.getSparkTrainingStats().statsAsString();

             INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
//            executioner.addToWatchdog(finalSparkParams, "finalSparkParams");

             float[] fp = finalParams.data().asFloat();
             float[] fps = finalSparkParams.data().asFloat();
-            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
-            System.out.println("Initial (Spark) params: "
-                            + Arrays.toString(initialSparkParams.data().asFloat()));
-            System.out.println("Final (Local) params: " + Arrays.toString(fp));
-            System.out.println("Final (Spark) params: " + Arrays.toString(fps));
+//            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
+//            System.out.println("Initial (Spark) params: "
+//                            + Arrays.toString(initialSparkParams.data().asFloat()));
+//            System.out.println("Final (Local) params: " + Arrays.toString(fp));
+//            System.out.println("Final (Spark) params: " + Arrays.toString(fps));

             assertEquals(initialParams, initialSparkParams);
             assertNotEquals(initialParams, finalParams);
@@ -551,14 +554,15 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
                 sparkNet.fit(rdd);
             }

-            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+//            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
+            sparkNet.getSparkTrainingStats().statsAsString();

             INDArray finalSparkParams = sparkNet.getNetwork().params().dup();

-            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
-            System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
-            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
-            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
+//            System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
+//            System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
+//            System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
+//            System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
             assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
             assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);

@@ -37,7 +37,7 @@ public class TestJsonYaml {
             String json = tm.toJson();
             String yaml = tm.toYaml();

-            System.out.println(json);
+//            System.out.println(json);

             TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
             TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);

@@ -389,7 +389,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
         for (EventStats e : workerFitStats) {
             ExampleCountEventStats eces = (ExampleCountEventStats) e;
-            System.out.println(eces.getTotalExampleCount());
+//            System.out.println(eces.getTotalExampleCount());
         }

         for (EventStats e : workerFitStats) {

@@ -457,7 +457,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         assertNotEquals(paramsBefore, paramsAfter);

         SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
-        System.out.println(stats.statsAsString());
+//        System.out.println(stats.statsAsString());
+        stats.statsAsString();

         sparkNet.getTrainingMaster().deleteTempFiles(sc);
     }

@@ -483,7 +484,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
             i++;
         }

-        System.out.println("Saved to: " + tempDirF.getAbsolutePath());
+//        System.out.println("Saved to: " + tempDirF.getAbsolutePath());



@@ -527,7 +528,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         SparkTrainingStats stats = sparkNet.getSparkTrainingStats();

         //Expect
-        System.out.println(stats.statsAsString());
+//        System.out.println(stats.statsAsString());
+        stats.statsAsString();
         assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());

         List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");

@@ -566,8 +568,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
             i++;
         }

-        System.out.println("Saved to: " + tempDirF.getAbsolutePath());
-        System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
+//        System.out.println("Saved to: " + tempDirF.getAbsolutePath());
+//        System.out.println("Saved to: " + tempDirF2.getAbsolutePath());



@@ -610,7 +612,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         assertNotEquals(paramsBefore, paramsAfter);

         SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
-        System.out.println(stats.statsAsString());
+//        System.out.println(stats.statsAsString());
+        stats.statsAsString();

         //Same thing, buf for MultiDataSet objects:
         config = new Configuration();

@@ -631,7 +634,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         assertNotEquals(paramsBefore, paramsAfter);

         stats = sparkNet.getSparkTrainingStats();
-        System.out.println(stats.statsAsString());
+//        System.out.println(stats.statsAsString());
+        stats.statsAsString();
     }


@@ -730,13 +734,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
                         .build();

         for (int avgFreq : new int[] {1, 5, 10}) {
-            System.out.println("--- Avg freq " + avgFreq + " ---");
+//            System.out.println("--- Avg freq " + avgFreq + " ---");
             SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
                             new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
                                             .batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
                                             .repartionData(Repartition.Always).build());

-            sparkNet.setListeners(new ScoreIterationListener(1));
+            sparkNet.setListeners(new ScoreIterationListener(5));


@@ -778,13 +782,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
                         .setOutputs("1").build();

         for (int avgFreq : new int[] {1, 5, 10}) {
-            System.out.println("--- Avg freq " + avgFreq + " ---");
+//            System.out.println("--- Avg freq " + avgFreq + " ---");
             SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
                             new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
                                             .batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
                                             .repartionData(Repartition.Always).build());

-            sparkNet.setListeners(new ScoreIterationListener(1));
+            sparkNet.setListeners(new ScoreIterationListener(5));

             JavaRDD<DataSet> rdd = sc.parallelize(list);

@@ -107,7 +107,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
                 expectedStatNames.addAll(c);
             }

-            System.out.println(expectedStatNames);
+//            System.out.println(expectedStatNames);


             SparkTrainingStats stats = sparkNet.getSparkTrainingStats();

@@ -119,7 +119,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
             }

             String statsAsString = stats.statsAsString();
-            System.out.println(statsAsString);
+//            System.out.println(statsAsString);
             assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat


@@ -35,7 +35,7 @@ public class TestTimeSource {
             long systemTime = System.currentTimeMillis();
             long ntpTime = timeSource.currentTimeMillis();
             long offset = ntpTime - systemTime;
-            System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
+//            System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
             Thread.sleep(500);
         }
     }

@@ -49,7 +49,7 @@ public class TestTimeSource {
             long systemTime = System.currentTimeMillis();
             long ntpTime = timeSource.currentTimeMillis();
             long offset = ntpTime - systemTime;
-            System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
+//            System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
             assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
             Thread.sleep(500);
         }

@@ -87,7 +87,7 @@ public class TestListeners extends BaseSparkTest {
         net.fit(rdd);

         List<String> sessions = ss.listSessionIDs();
-        System.out.println("Sessions: " + sessions);
+//        System.out.println("Sessions: " + sessions);
         assertEquals(1, sessions.size());

         String sid = sessions.get(0);

@@ -95,15 +95,15 @@ public class TestListeners extends BaseSparkTest {
         List<String> typeIDs = ss.listTypeIDsForSession(sid);
         List<String> workers = ss.listWorkerIDsForSession(sid);

-        System.out.println(sid + "\t" + typeIDs + "\t" + workers);
+//        System.out.println(sid + "\t" + typeIDs + "\t" + workers);

         List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
-        System.out.println(lastUpdates);
+//        System.out.println(lastUpdates);

-        System.out.println("Static info:");
+//        System.out.println("Static info:");
         for (String wid : workers) {
             Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
-            System.out.println(sid + "\t" + wid);
+//            System.out.println(sid + "\t" + wid);
         }

         assertEquals(1, typeIDs.size());

@@ -63,7 +63,7 @@ public class TestRepartitioning extends BaseSparkTest {
         assertEquals(10, rdd2.partitions().size());
         for (int i = 0; i < 10; i++) {
             List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
-            System.out.println("Partition " + i + " size: " + partition.size());
+//            System.out.println("Partition " + i + " size: " + partition.size());
             assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
         }
     }

@@ -170,7 +170,7 @@ public class TestRepartitioning extends BaseSparkTest {

         List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();

-        System.out.println(partitionCounts);
+//        System.out.println(partitionCounts);

         List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
                 new Tuple2<>(0,29),

@@ -185,7 +185,7 @@ public class TestRepartitioning extends BaseSparkTest {

         JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
         List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
-        System.out.println(partitionCountsAfter);
+//        System.out.println(partitionCountsAfter);

         for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
             assertEquals(2, (int)t2._2());

@@ -219,8 +219,8 @@ public class TestRepartitioning extends BaseSparkTest {
             }
         }

-        System.out.println("min: " + min + "\t@\t" + minIdx);
-        System.out.println("max: " + max + "\t@\t" + maxIdx);
+//        System.out.println("min: " + min + "\t@\t" + minIdx);
+//        System.out.println("max: " + max + "\t@\t" + maxIdx);

         assertEquals(1, min);
         assertEquals(2, max);

@@ -244,7 +244,7 @@ public class TestRepartitioning extends BaseSparkTest {

         for (int i = 0; i < 10; i++) {
             List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
-            System.out.println("Partition " + i + " size: " + partition.size());
+//            System.out.println("Partition " + i + " size: " + partition.size());
             assertTrue(partition.size() >= 90 && partition.size() <= 110);
         }
     }

@@ -123,7 +123,7 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
     val expectedSlice = expectedArray.slice(0)
     val actualSlice = expectedArray(0, ->)

-    Console.println(expectedSlice)
+//    Console.println(expectedSlice)

     assert(actualSlice == expectedSlice)
   }

@@ -28,7 +28,7 @@ class TrainingTest extends FlatSpec with Matchers {
       val unused3 = unused1.div(unused2)
       val loss1 = add.std("l1", true)
       val loss2 = mmul.mean("l2")
-      Console.println(sd.summary)
+//      Console.println(sd.summary)
       if (i == 0) {
         sd.setLossVariables("l1", "l2")
         sd.createGradFunction()

@@ -43,8 +43,8 @@ public class HistoryProcessorTest {
         hp.add(a);
         INDArray[] h = hp.getHistory();
         assertEquals(4, h.length);
-        System.out.println(Arrays.toString(a.shape()));
-        System.out.println(Arrays.toString(h[0].shape()));
+//        System.out.println(Arrays.toString(a.shape()));
+//        System.out.println(Arrays.toString(h[0].shape()));
         assertEquals( 1, h[0].shape()[0]);
         assertEquals(a.shape()[0], h[0].shape()[1]);
         assertEquals(a.shape()[1], h[0].shape()[2]);

@@ -100,8 +100,8 @@ public class ActorCriticTest {
         double error2 = gradient2 - gradient.getDouble(1);
         double relError1 = error1 / gradient.getDouble(0);
         double relError2 = error2 / gradient.getDouble(1);
-        System.out.println(gradient.getDouble(0) + " " + gradient1 + " " + relError1);
-        System.out.println(gradient.getDouble(1) + " " + gradient2 + " " + relError2);
+//        System.out.println(gradient.getDouble(0) + " " + gradient1 + " " + relError1);
+//        System.out.println(gradient.getDouble(1) + " " + gradient2 + " " + relError2);
         assertTrue(gradient.getDouble(0) < maxRelError || Math.abs(relError1) < maxRelError);
         assertTrue(gradient.getDouble(1) < maxRelError || Math.abs(relError2) < maxRelError);
     }

@@ -158,7 +158,7 @@ public class PolicyTest {
         for (int i = 0; i < 100; i++) {
             count[policy.nextAction(input)]++;
         }
-        System.out.println(count[0] + " " + count[1] + " " + count[2] + " " + count[3]);
+//        System.out.println(count[0] + " " + count[1] + " " + count[2] + " " + count[3]);
         assertTrue(count[0] < 20);
         assertTrue(count[1] < 30);
         assertTrue(count[2] < 40);