Test fixes + cleanup (#245)
* Test spam reduction Signed-off-by: Alex Black <blacka101@gmail.com> * Arbiter bad import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Small spark test tweak Signed-off-by: AlexDBlack <blacka101@gmail.com> * Arbiter test log spam reduction Signed-off-by: Alex Black <blacka101@gmail.com> * More test spam reduction Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
2698fbf541
commit
c8882cbfa5
|
@ -1,47 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.util;
|
|
||||||
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
|
|
||||||
import java.awt.*;
|
|
||||||
import java.net.URI;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Various utilities for webpages and dealing with browsers
|
|
||||||
*/
|
|
||||||
public class WebUtils {
|
|
||||||
|
|
||||||
public static void tryOpenBrowser(String path, Logger log) {
|
|
||||||
try {
|
|
||||||
WebUtils.openBrowser(new URI(path));
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("Could not open browser", e);
|
|
||||||
System.out.println("Browser could not be launched automatically.\nUI path: " + path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void openBrowser(URI uri) throws Exception {
|
|
||||||
if (Desktop.isDesktopSupported()) {
|
|
||||||
Desktop.getDesktop().browse(uri);
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException(
|
|
||||||
"Cannot open browser on this platform: Desktop.isDesktopSupported() == false");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -127,7 +127,7 @@ public class BraninFunction {
|
||||||
BraninConfig candidate = (BraninConfig) c.getValue();
|
BraninConfig candidate = (BraninConfig) c.getValue();
|
||||||
|
|
||||||
double score = scoreFunction.score(candidate, null, (Map) null);
|
double score = scoreFunction.score(candidate, null, (Map) null);
|
||||||
System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
|
// System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
|
||||||
|
|
||||||
Thread.sleep(20);
|
Thread.sleep(20);
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,7 @@ public class TestRandomSearch extends BaseDL4JTest {
|
||||||
runner.execute();
|
runner.execute();
|
||||||
|
|
||||||
|
|
||||||
System.out.println("----- Complete -----");
|
// System.out.println("----- Complete -----");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic;
|
package org.deeplearning4j.arbiter.optimize.genetic;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
|
||||||
|
|
||||||
public class TestRandomGenerator implements RandomGenerator {
|
public class TestRandomGenerator implements RandomGenerator {
|
||||||
private final int[] intRandomNumbers;
|
private final int[] intRandomNumbers;
|
||||||
|
@ -63,17 +63,17 @@ public class TestRandomGenerator implements RandomGenerator {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long nextLong() {
|
public long nextLong() {
|
||||||
throw new NotImplementedException();
|
throw new NotImplementedException("Not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean nextBoolean() {
|
public boolean nextBoolean() {
|
||||||
throw new NotImplementedException();
|
throw new NotImplementedException("Not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float nextFloat() {
|
public float nextFloat() {
|
||||||
throw new NotImplementedException();
|
throw new NotImplementedException("Not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -83,6 +83,6 @@ public class TestRandomGenerator implements RandomGenerator {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double nextGaussian() {
|
public double nextGaussian() {
|
||||||
throw new NotImplementedException();
|
throw new NotImplementedException("Not implemented");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
package org.deeplearning4j.arbiter.optimize.genetic.crossover;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator;
|
||||||
|
@ -26,7 +27,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
|
||||||
|
|
||||||
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
|
public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CrossoverResult crossover() {
|
public CrossoverResult crossover() {
|
||||||
throw new NotImplementedException();
|
throw new NotImplementedException("Not implemented");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
package org.deeplearning4j.arbiter.optimize.genetic.culling;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator;
|
||||||
|
@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.population.Populati
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ public class RatioCullOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void cullPopulation() {
|
public void cullPopulation() {
|
||||||
throw new NotImplementedException();
|
throw new NotImplementedException("Not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getCullRatio() {
|
public double getCullRatio() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
import org.apache.commons.math3.random.RandomGenerator;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||||
|
@ -33,7 +34,6 @@ import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ public class GeneticSelectionOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void cullPopulation() {
|
public void cullPopulation() {
|
||||||
throw new NotImplementedException();
|
throw new NotImplementedException("Not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
package org.deeplearning4j.arbiter.optimize.genetic.selection;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
||||||
|
@ -24,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.Selection
|
||||||
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
|
||||||
|
|
||||||
public class SelectionOperatorTests extends BaseDL4JTest {
|
public class SelectionOperatorTests extends BaseDL4JTest {
|
||||||
private class TestSelectionOperator extends SelectionOperator {
|
private class TestSelectionOperator extends SelectionOperator {
|
||||||
|
@ -39,7 +39,7 @@ public class SelectionOperatorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double[] buildNextGenes() {
|
public double[] buildNextGenes() {
|
||||||
throw new NotImplementedException();
|
throw new NotImplementedException("Not implemented");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -158,7 +158,7 @@ public class TestComputationGraphSpace extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
|
// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
|
||||||
assertTrue(reluCount > 0);
|
assertTrue(reluCount > 0);
|
||||||
assertTrue(tanhCount > 0);
|
assertTrue(tanhCount > 0);
|
||||||
|
|
||||||
|
|
|
@ -162,7 +162,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest {
|
||||||
List<ResultReference> results = runner.getResults();
|
List<ResultReference> results = runner.getResults();
|
||||||
assertTrue(results.size() > 0);
|
assertTrue(results.size() > 0);
|
||||||
|
|
||||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
// System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -165,7 +165,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest {
|
||||||
List<ResultReference> results = runner.getResults();
|
List<ResultReference> results = runner.getResults();
|
||||||
assertTrue(results.size() > 0);
|
assertTrue(results.size() > 0);
|
||||||
|
|
||||||
System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
// System.out.println("----- COMPLETE - " + results.size() + " results -----");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,7 @@ public class TestLayerSpace extends BaseDL4JTest {
|
||||||
double l2 = TestUtils.getL2(l);
|
double l2 = TestUtils.getL2(l);
|
||||||
IActivation activation = l.getActivationFn();
|
IActivation activation = l.getActivationFn();
|
||||||
|
|
||||||
System.out.println(lr + "\t" + l2 + "\t" + activation);
|
// System.out.println(lr + "\t" + l2 + "\t" + activation);
|
||||||
|
|
||||||
assertTrue(lr >= 0.3 && lr <= 0.4);
|
assertTrue(lr >= 0.3 && lr <= 0.4);
|
||||||
assertTrue(l2 >= 0.01 && l2 <= 0.1);
|
assertTrue(l2 >= 0.01 && l2 <= 0.1);
|
||||||
|
@ -190,7 +190,7 @@ public class TestLayerSpace extends BaseDL4JTest {
|
||||||
ActivationLayer al = als.getValue(d);
|
ActivationLayer al = als.getValue(d);
|
||||||
IActivation activation = al.getActivationFn();
|
IActivation activation = al.getActivationFn();
|
||||||
|
|
||||||
System.out.println(activation);
|
// System.out.println(activation);
|
||||||
|
|
||||||
assertTrue(containsActivationFunction(actFns, activation));
|
assertTrue(containsActivationFunction(actFns, activation));
|
||||||
}
|
}
|
||||||
|
@ -228,7 +228,7 @@ public class TestLayerSpace extends BaseDL4JTest {
|
||||||
IActivation activation = el.getActivationFn();
|
IActivation activation = el.getActivationFn();
|
||||||
long nOut = el.getNOut();
|
long nOut = el.getNOut();
|
||||||
|
|
||||||
System.out.println(activation + "\t" + nOut);
|
// System.out.println(activation + "\t" + nOut);
|
||||||
|
|
||||||
assertTrue(containsActivationFunction(actFns, activation));
|
assertTrue(containsActivationFunction(actFns, activation));
|
||||||
assertTrue(nOut >= 10 && nOut <= 20);
|
assertTrue(nOut >= 10 && nOut <= 20);
|
||||||
|
@ -295,7 +295,7 @@ public class TestLayerSpace extends BaseDL4JTest {
|
||||||
long nOut = el.getNOut();
|
long nOut = el.getNOut();
|
||||||
double forgetGate = el.getForgetGateBiasInit();
|
double forgetGate = el.getForgetGateBiasInit();
|
||||||
|
|
||||||
System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
|
// System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
|
||||||
|
|
||||||
assertTrue(containsActivationFunction(actFns, activation));
|
assertTrue(containsActivationFunction(actFns, activation));
|
||||||
assertTrue(nOut >= 10 && nOut <= 20);
|
assertTrue(nOut >= 10 && nOut <= 20);
|
||||||
|
|
|
@ -293,8 +293,8 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
|
||||||
assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly
|
assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
|
// System.out.println("Number of layers: " + Arrays.toString(nLayerCounts));
|
||||||
System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
|
// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -98,7 +98,8 @@ public class ArbiterCLIRunnerTest extends BaseDL4JTest {
|
||||||
assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson()));
|
assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson()));
|
||||||
|
|
||||||
FileUtils.writeStringToFile(new File(configPath),configuration.toJson());
|
FileUtils.writeStringToFile(new File(configPath),configuration.toJson());
|
||||||
System.out.println(configuration.toJson());
|
// System.out.println(configuration.toJson());
|
||||||
|
configuration.toJson();
|
||||||
|
|
||||||
log.info("Starting test");
|
log.info("Starting test");
|
||||||
cliRunner.runMain(
|
cliRunner.runMain(
|
||||||
|
|
|
@ -41,7 +41,7 @@ public class TestGraphLoading extends BaseDL4JTest {
|
||||||
|
|
||||||
IGraph<String, String> graph = GraphLoader
|
IGraph<String, String> graph = GraphLoader
|
||||||
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ",");
|
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 7, ",");
|
||||||
System.out.println(graph);
|
// System.out.println(graph);
|
||||||
|
|
||||||
assertEquals(graph.numVertices(), 7);
|
assertEquals(graph.numVertices(), 7);
|
||||||
int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}};
|
int[][] edges = {{1, 2}, {0, 2, 4}, {0, 1, 3, 4}, {2, 4, 5}, {1, 2, 3, 5, 6}, {3, 4, 6}, {4, 5}};
|
||||||
|
@ -66,7 +66,7 @@ public class TestGraphLoading extends BaseDL4JTest {
|
||||||
edgeLineProcessor, vertexFactory, 10, false);
|
edgeLineProcessor, vertexFactory, 10, false);
|
||||||
|
|
||||||
|
|
||||||
System.out.println(graph);
|
// System.out.println(graph);
|
||||||
|
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
List<Edge<String>> edges = graph.getEdgesOut(i);
|
List<Edge<String>> edges = graph.getEdgesOut(i);
|
||||||
|
@ -111,7 +111,7 @@ public class TestGraphLoading extends BaseDL4JTest {
|
||||||
Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
|
Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
|
||||||
edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);
|
edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);
|
||||||
|
|
||||||
System.out.println(graph);
|
// System.out.println(graph);
|
||||||
|
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
List<Edge<String>> edges = graph.getEdgesOut(i);
|
List<Edge<String>> edges = graph.getEdgesOut(i);
|
||||||
|
|
|
@ -71,7 +71,7 @@ public class TestGraphLoadingWeighted extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(graph);
|
// System.out.println(graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -220,7 +220,7 @@ public class TestGraph extends BaseDL4JTest {
|
||||||
sum += transitionProb[i][j];
|
sum += transitionProb[i][j];
|
||||||
for (int j = 0; j < transitionProb[i].length; j++)
|
for (int j = 0; j < transitionProb[i].length; j++)
|
||||||
transitionProb[i][j] /= sum;
|
transitionProb[i][j] /= sum;
|
||||||
System.out.println(Arrays.toString(transitionProb[i]));
|
// System.out.println(Arrays.toString(transitionProb[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
//Check that transition probs are essentially correct (within bounds of random variation)
|
//Check that transition probs are essentially correct (within bounds of random variation)
|
||||||
|
|
|
@ -145,8 +145,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
|
||||||
|
|
||||||
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
|
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
|
||||||
fail(msg);
|
fail(msg);
|
||||||
else
|
// else
|
||||||
System.out.println(msg);
|
// System.out.println(msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,10 +333,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
|
||||||
|
|
||||||
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
|
if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
|
||||||
fail(msg);
|
fail(msg);
|
||||||
else
|
// else
|
||||||
System.out.println(msg);
|
// System.out.println(msg);
|
||||||
}
|
}
|
||||||
System.out.println();
|
// System.out.println();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ public class TestDeepWalk extends BaseDL4JTest {
|
||||||
for (int i = 0; i < 7; i++) {
|
for (int i = 0; i < 7; i++) {
|
||||||
INDArray vector = deepWalk.getVertexVector(i);
|
INDArray vector = deepWalk.getVertexVector(i);
|
||||||
assertArrayEquals(new long[] {vectorSize}, vector.shape());
|
assertArrayEquals(new long[] {vectorSize}, vector.shape());
|
||||||
System.out.println(Arrays.toString(vector.dup().data().asFloat()));
|
// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);
|
GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);
|
||||||
|
@ -77,11 +77,11 @@ public class TestDeepWalk extends BaseDL4JTest {
|
||||||
for (int t = 0; t < 5; t++) {
|
for (int t = 0; t < 5; t++) {
|
||||||
iter.reset();
|
iter.reset();
|
||||||
deepWalk.fit(iter);
|
deepWalk.fit(iter);
|
||||||
System.out.println("--------------------");
|
// System.out.println("--------------------");
|
||||||
for (int i = 0; i < 7; i++) {
|
for (int i = 0; i < 7; i++) {
|
||||||
INDArray vector = deepWalk.getVertexVector(i);
|
INDArray vector = deepWalk.getVertexVector(i);
|
||||||
assertArrayEquals(new long[] {vectorSize}, vector.shape());
|
assertArrayEquals(new long[] {vectorSize}, vector.shape());
|
||||||
System.out.println(Arrays.toString(vector.dup().data().asFloat()));
|
// System.out.println(Arrays.toString(vector.dup().data().asFloat()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -160,7 +160,7 @@ public class TestDeepWalk extends BaseDL4JTest {
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
double sim = deepWalk.similarity(i, nearestTo);
|
double sim = deepWalk.similarity(i, nearestTo);
|
||||||
System.out.println(i + "\t" + nearestTo + "\t" + sim);
|
// System.out.println(i + "\t" + nearestTo + "\t" + sim);
|
||||||
assertTrue(sim <= minSimNearest);
|
assertTrue(sim <= minSimNearest);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -211,7 +211,7 @@ public class TestDeepWalk extends BaseDL4JTest {
|
||||||
Graph<String, String> graph = GraphLoader
|
Graph<String, String> graph = GraphLoader
|
||||||
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ",");
|
.loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), 13, ",");
|
||||||
|
|
||||||
System.out.println(graph);
|
// System.out.println(graph);
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -229,11 +229,13 @@ public class TestDeepWalk extends BaseDL4JTest {
|
||||||
|
|
||||||
//Calculate similarity(0,i)
|
//Calculate similarity(0,i)
|
||||||
for (int i = 0; i < nVertices; i++) {
|
for (int i = 0; i < nVertices; i++) {
|
||||||
System.out.println(deepWalk.similarity(0, i));
|
// System.out.println(deepWalk.similarity(0, i));
|
||||||
|
deepWalk.similarity(0, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < nVertices; i++)
|
for (int i = 0; i < nVertices; i++)
|
||||||
System.out.println(deepWalk.getVertexVector(i));
|
// System.out.println(deepWalk.getVertexVector(i));
|
||||||
|
deepWalk.getVertexVector(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
@Test(timeout = 60000L)
|
||||||
|
|
|
@ -38,9 +38,11 @@ public class TestGraphHuffman extends BaseDL4JTest {
|
||||||
|
|
||||||
gh.buildTree(vertexDegrees);
|
gh.buildTree(vertexDegrees);
|
||||||
|
|
||||||
for (int i = 0; i < 7; i++)
|
for (int i = 0; i < 7; i++) {
|
||||||
System.out.println(i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
|
String s = i + "\t" + gh.getCodeLength(i) + "\t" + gh.getCodeString(i) + "\t\t" + gh.getCode(i)
|
||||||
+ "\t\t" + Arrays.toString(gh.getPathInnerNodes(i)));
|
+ "\t\t" + Arrays.toString(gh.getPathInnerNodes(i));
|
||||||
|
// System.out.println(s);
|
||||||
|
}
|
||||||
|
|
||||||
int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5};
|
int[] expectedLengths = {3, 2, 2, 5, 4, 2, 5};
|
||||||
for (int i = 0; i < vertexDegrees.length; i++) {
|
for (int i = 0; i < vertexDegrees.length; i++) {
|
||||||
|
|
|
@ -79,8 +79,9 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest {
|
||||||
model.init();
|
model.init();
|
||||||
|
|
||||||
ParallelWrapper parameterServerParallelWrapper =
|
ParallelWrapper parameterServerParallelWrapper =
|
||||||
new ParallelWrapper.Builder(model).trainerFactory(new ParameterServerTrainerContext())
|
new ParallelWrapper.Builder(model)
|
||||||
.workers(Runtime.getRuntime().availableProcessors())
|
.workers(Math.min(4, Runtime.getRuntime().availableProcessors()))
|
||||||
|
.trainerFactory(new ParameterServerTrainerContext())
|
||||||
.reportScoreAfterAveraging(true).prefetchBuffer(3).build();
|
.reportScoreAfterAveraging(true).prefetchBuffer(3).build();
|
||||||
parameterServerParallelWrapper.fit(mnistTrain);
|
parameterServerParallelWrapper.fit(mnistTrain);
|
||||||
|
|
||||||
|
|
|
@ -104,7 +104,7 @@ public class SparkWord2VecTest extends BaseDL4JTest {
|
||||||
public void call(ExportContainer<VocabWord> v) throws Exception {
|
public void call(ExportContainer<VocabWord> v) throws Exception {
|
||||||
assertNotNull(v.getElement());
|
assertNotNull(v.getElement());
|
||||||
assertNotNull(v.getArray());
|
assertNotNull(v.getArray());
|
||||||
System.out.println(v.getElement() + " - " + v.getArray());
|
// System.out.println(v.getElement() + " - " + v.getArray());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,7 +66,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.build();
|
.build();
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
|
@ -119,7 +119,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MSE).build())
|
.lossFunction(LossFunctions.LossFunction.MSE).build())
|
||||||
.build();
|
.build();
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
|
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
|
||||||
|
@ -155,7 +155,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.build();
|
.build();
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.build();
|
.build();
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
|
|
||||||
|
@ -231,7 +231,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.build();
|
.build();
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
|
|
|
@ -69,7 +69,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
|
||||||
.setOutputs("0").build();
|
.setOutputs("0").build();
|
||||||
ComputationGraph net = new ComputationGraph(conf);
|
ComputationGraph net = new ComputationGraph(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
|
@ -120,7 +120,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
|
.lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
|
||||||
.setOutputs("0").build();
|
.setOutputs("0").build();
|
||||||
ComputationGraph net = new ComputationGraph(conf);
|
ComputationGraph net = new ComputationGraph(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
|
EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
|
||||||
|
@ -158,7 +158,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
|
||||||
.setOutputs("0").build();
|
.setOutputs("0").build();
|
||||||
ComputationGraph net = new ComputationGraph(conf);
|
ComputationGraph net = new ComputationGraph(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
|
|
||||||
|
@ -203,7 +203,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
|
||||||
.setOutputs("0").build();
|
.setOutputs("0").build();
|
||||||
ComputationGraph net = new ComputationGraph(conf);
|
ComputationGraph net = new ComputationGraph(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
|
|
||||||
|
@ -238,7 +238,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
|
||||||
.setOutputs("0").build();
|
.setOutputs("0").build();
|
||||||
ComputationGraph net = new ComputationGraph(conf);
|
ComputationGraph net = new ComputationGraph(conf);
|
||||||
net.setListeners(new ScoreIterationListener(1));
|
net.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
|
|
||||||
JavaRDD<DataSet> irisData = getIris();
|
JavaRDD<DataSet> irisData = getIris();
|
||||||
|
|
|
@ -59,7 +59,7 @@ public class TestShuffleExamples extends BaseSparkTest {
|
||||||
int totalExampleCount = 0;
|
int totalExampleCount = 0;
|
||||||
for (DataSet ds : shuffledList) {
|
for (DataSet ds : shuffledList) {
|
||||||
totalExampleCount += ds.getFeatures().length();
|
totalExampleCount += ds.getFeatures().length();
|
||||||
System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
|
// System.out.println(Arrays.toString(ds.getFeatures().data().asFloat()));
|
||||||
|
|
||||||
assertEquals(ds.getFeatures(), ds.getLabels());
|
assertEquals(ds.getFeatures(), ds.getLabels());
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,7 +86,7 @@ public class TestExport extends BaseSparkTest {
|
||||||
for (File file : files) {
|
for (File file : files) {
|
||||||
if (!file.getPath().endsWith(".bin"))
|
if (!file.getPath().endsWith(".bin"))
|
||||||
continue;
|
continue;
|
||||||
System.out.println(file);
|
// System.out.println(file);
|
||||||
DataSet ds = new DataSet();
|
DataSet ds = new DataSet();
|
||||||
ds.load(file);
|
ds.load(file);
|
||||||
assertEquals(minibatchSize, ds.numExamples());
|
assertEquals(minibatchSize, ds.numExamples());
|
||||||
|
@ -144,7 +144,7 @@ public class TestExport extends BaseSparkTest {
|
||||||
for (File file : files) {
|
for (File file : files) {
|
||||||
if (!file.getPath().endsWith(".bin"))
|
if (!file.getPath().endsWith(".bin"))
|
||||||
continue;
|
continue;
|
||||||
System.out.println(file);
|
// System.out.println(file);
|
||||||
MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
|
MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
|
||||||
ds.load(file);
|
ds.load(file);
|
||||||
assertEquals(minibatchSize, ds.getFeatures(0).size(0));
|
assertEquals(minibatchSize, ds.getFeatures(0).size(0));
|
||||||
|
|
|
@ -92,9 +92,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
|
||||||
|
|
||||||
int[][] colorCountsByPartition = new int[3][2];
|
int[][] colorCountsByPartition = new int[3][2];
|
||||||
for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) {
|
for (final Tuple2<Tuple2<Long, Integer>, String> val : testList) {
|
||||||
System.out.println(val);
|
// System.out.println(val);
|
||||||
Integer partition = hbp.getPartition(val._1());
|
Integer partition = hbp.getPartition(val._1());
|
||||||
System.out.println(partition);
|
// System.out.println(partition);
|
||||||
|
|
||||||
if (val._2().equals("red"))
|
if (val._2().equals("red"))
|
||||||
colorCountsByPartition[partition][0] += 1;
|
colorCountsByPartition[partition][0] += 1;
|
||||||
|
@ -102,9 +102,9 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
|
||||||
colorCountsByPartition[partition][1] += 1;
|
colorCountsByPartition[partition][1] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < 3; i++) {
|
// for (int i = 0; i < 3; i++) {
|
||||||
System.out.println(Arrays.toString(colorCountsByPartition[i]));
|
// System.out.println(Arrays.toString(colorCountsByPartition[i]));
|
||||||
}
|
// }
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
// avg red per partition : 2.33
|
// avg red per partition : 2.33
|
||||||
assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4);
|
assertTrue(colorCountsByPartition[i][0] >= 1 && colorCountsByPartition[i][0] < 4);
|
||||||
|
@ -178,12 +178,12 @@ public class HashingBalancedPartitionerTest extends BaseSparkTest {
|
||||||
colorCountsByPartition[partition][1] += 1;
|
colorCountsByPartition[partition][1] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < numPartitions; i++) {
|
// for (int i = 0; i < numPartitions; i++) {
|
||||||
System.out.println(Arrays.toString(colorCountsByPartition[i]));
|
// System.out.println(Arrays.toString(colorCountsByPartition[i]));
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
System.out.println("Ideal red # per partition: " + avgRed);
|
// System.out.println("Ideal red # per partition: " + avgRed);
|
||||||
System.out.println("Ideal blue # per partition: " + avgBlue);
|
// System.out.println("Ideal blue # per partition: " + avgBlue);
|
||||||
|
|
||||||
for (int i = 0; i < numPartitions; i++) {
|
for (int i = 0; i < numPartitions; i++) {
|
||||||
// avg red per partition : 2.33
|
// avg red per partition : 2.33
|
||||||
|
|
|
@ -115,7 +115,7 @@ public class TestSparkComputationGraph extends BaseSparkTest {
|
||||||
TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
|
TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
|
||||||
|
|
||||||
SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
|
SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
|
||||||
scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(1)));
|
scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5)));
|
||||||
|
|
||||||
JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
|
JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
|
||||||
scg.fitMultiDataSet(rdd);
|
scg.fitMultiDataSet(rdd);
|
||||||
|
|
|
@ -31,8 +31,11 @@ import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.evaluation.classification.Evaluation;
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.Nesterovs;
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
@ -45,8 +48,24 @@ import static org.junit.Assert.assertTrue;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestSparkDl4jMultiLayer extends BaseSparkTest {
|
public class TestSparkDl4jMultiLayer extends BaseSparkTest {
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 120000L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDefaultFPDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
public void testEvaluationSimple() throws Exception {
|
public void testEvaluationSimple() throws Exception {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
for( int evalWorkers : new int[]{1, 4, 8}) {
|
for( int evalWorkers : new int[]{1, 4, 8}) {
|
||||||
//Simple test to validate DL4J issue 4099 is fixed...
|
//Simple test to validate DL4J issue 4099 is fixed...
|
||||||
|
@ -75,18 +94,18 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
|
||||||
//----------------------------------
|
//----------------------------------
|
||||||
//Create network configuration and conduct network training
|
//Create network configuration and conduct network training
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.dataType(DataType.FLOAT)
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.activation(Activation.LEAKYRELU)
|
.activation(Activation.LEAKYRELU)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.updater(new Nesterovs(0.02, 0.9))
|
.updater(new Adam(1e-3))
|
||||||
.l2(1e-4)
|
.l2(1e-5)
|
||||||
.list()
|
.list()
|
||||||
.layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
|
.layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
|
||||||
.layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
|
.layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
|
||||||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||||||
.activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
|
.activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
|
||||||
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
//Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options
|
//Configuration for Spark training: see https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto for explanation of these configuration options
|
||||||
|
|
|
@ -333,15 +333,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
||||||
|
sparkNet.getSparkTrainingStats().statsAsString();
|
||||||
|
|
||||||
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
|
|
||||||
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
||||||
System.out.println("Initial (Spark) params: "
|
// System.out.println("Initial (Spark) params: "
|
||||||
+ Arrays.toString(initialSparkParams.data().asFloat()));
|
// + Arrays.toString(initialSparkParams.data().asFloat()));
|
||||||
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
||||||
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
||||||
assertEquals(initialParams, initialSparkParams);
|
assertEquals(initialParams, initialSparkParams);
|
||||||
assertNotEquals(initialParams, finalParams);
|
assertNotEquals(initialParams, finalParams);
|
||||||
assertEquals(finalParams, finalSparkParams);
|
assertEquals(finalParams, finalSparkParams);
|
||||||
|
@ -405,15 +406,16 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
||||||
|
sparkNet.getSparkTrainingStats().statsAsString();
|
||||||
|
|
||||||
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
|
|
||||||
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
||||||
System.out.println("Initial (Spark) params: "
|
// System.out.println("Initial (Spark) params: "
|
||||||
+ Arrays.toString(initialSparkParams.data().asFloat()));
|
// + Arrays.toString(initialSparkParams.data().asFloat()));
|
||||||
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
||||||
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
||||||
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
|
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
|
||||||
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
|
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
|
||||||
|
|
||||||
|
@ -478,18 +480,19 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
||||||
|
sparkNet.getSparkTrainingStats().statsAsString();
|
||||||
|
|
||||||
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
// executioner.addToWatchdog(finalSparkParams, "finalSparkParams");
|
// executioner.addToWatchdog(finalSparkParams, "finalSparkParams");
|
||||||
|
|
||||||
float[] fp = finalParams.data().asFloat();
|
float[] fp = finalParams.data().asFloat();
|
||||||
float[] fps = finalSparkParams.data().asFloat();
|
float[] fps = finalSparkParams.data().asFloat();
|
||||||
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
||||||
System.out.println("Initial (Spark) params: "
|
// System.out.println("Initial (Spark) params: "
|
||||||
+ Arrays.toString(initialSparkParams.data().asFloat()));
|
// + Arrays.toString(initialSparkParams.data().asFloat()));
|
||||||
System.out.println("Final (Local) params: " + Arrays.toString(fp));
|
// System.out.println("Final (Local) params: " + Arrays.toString(fp));
|
||||||
System.out.println("Final (Spark) params: " + Arrays.toString(fps));
|
// System.out.println("Final (Spark) params: " + Arrays.toString(fps));
|
||||||
|
|
||||||
assertEquals(initialParams, initialSparkParams);
|
assertEquals(initialParams, initialSparkParams);
|
||||||
assertNotEquals(initialParams, finalParams);
|
assertNotEquals(initialParams, finalParams);
|
||||||
|
@ -551,14 +554,15 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
||||||
sparkNet.fit(rdd);
|
sparkNet.fit(rdd);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
// System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
|
||||||
|
sparkNet.getSparkTrainingStats().statsAsString();
|
||||||
|
|
||||||
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
|
||||||
|
|
||||||
System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
// System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat()));
|
||||||
System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
|
// System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat()));
|
||||||
System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
// System.out.println("Final (Local) params: " + Arrays.toString(finalParams.data().asFloat()));
|
||||||
System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
// System.out.println("Final (Spark) params: " + Arrays.toString(finalSparkParams.data().asFloat()));
|
||||||
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
|
assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
|
||||||
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
|
assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ public class TestJsonYaml {
|
||||||
String json = tm.toJson();
|
String json = tm.toJson();
|
||||||
String yaml = tm.toYaml();
|
String yaml = tm.toYaml();
|
||||||
|
|
||||||
System.out.println(json);
|
// System.out.println(json);
|
||||||
|
|
||||||
TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
|
TrainingMaster fromJson = ParameterAveragingTrainingMaster.fromJson(json);
|
||||||
TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);
|
TrainingMaster fromYaml = ParameterAveragingTrainingMaster.fromYaml(yaml);
|
||||||
|
|
|
@ -389,7 +389,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
List<EventStats> workerFitStats = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
||||||
for (EventStats e : workerFitStats) {
|
for (EventStats e : workerFitStats) {
|
||||||
ExampleCountEventStats eces = (ExampleCountEventStats) e;
|
ExampleCountEventStats eces = (ExampleCountEventStats) e;
|
||||||
System.out.println(eces.getTotalExampleCount());
|
// System.out.println(eces.getTotalExampleCount());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (EventStats e : workerFitStats) {
|
for (EventStats e : workerFitStats) {
|
||||||
|
@ -457,7 +457,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
assertNotEquals(paramsBefore, paramsAfter);
|
assertNotEquals(paramsBefore, paramsAfter);
|
||||||
|
|
||||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||||
System.out.println(stats.statsAsString());
|
// System.out.println(stats.statsAsString());
|
||||||
|
stats.statsAsString();
|
||||||
|
|
||||||
sparkNet.getTrainingMaster().deleteTempFiles(sc);
|
sparkNet.getTrainingMaster().deleteTempFiles(sc);
|
||||||
}
|
}
|
||||||
|
@ -483,7 +484,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -527,7 +528,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||||
|
|
||||||
//Expect
|
//Expect
|
||||||
System.out.println(stats.statsAsString());
|
// System.out.println(stats.statsAsString());
|
||||||
|
stats.statsAsString();
|
||||||
assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());
|
assertEquals(numSplits, stats.getValue("ParameterAveragingMasterRepartitionTimesMs").size());
|
||||||
|
|
||||||
List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
List<EventStats> list = stats.getValue("ParameterAveragingWorkerFitTimesMs");
|
||||||
|
@ -566,8 +568,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
// System.out.println("Saved to: " + tempDirF.getAbsolutePath());
|
||||||
System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
|
// System.out.println("Saved to: " + tempDirF2.getAbsolutePath());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -610,7 +612,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
assertNotEquals(paramsBefore, paramsAfter);
|
assertNotEquals(paramsBefore, paramsAfter);
|
||||||
|
|
||||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||||
System.out.println(stats.statsAsString());
|
// System.out.println(stats.statsAsString());
|
||||||
|
stats.statsAsString();
|
||||||
|
|
||||||
//Same thing, buf for MultiDataSet objects:
|
//Same thing, buf for MultiDataSet objects:
|
||||||
config = new Configuration();
|
config = new Configuration();
|
||||||
|
@ -631,7 +634,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
assertNotEquals(paramsBefore, paramsAfter);
|
assertNotEquals(paramsBefore, paramsAfter);
|
||||||
|
|
||||||
stats = sparkNet.getSparkTrainingStats();
|
stats = sparkNet.getSparkTrainingStats();
|
||||||
System.out.println(stats.statsAsString());
|
// System.out.println(stats.statsAsString());
|
||||||
|
stats.statsAsString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -730,13 +734,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
for (int avgFreq : new int[] {1, 5, 10}) {
|
for (int avgFreq : new int[] {1, 5, 10}) {
|
||||||
System.out.println("--- Avg freq " + avgFreq + " ---");
|
// System.out.println("--- Avg freq " + avgFreq + " ---");
|
||||||
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
|
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf.clone(),
|
||||||
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
||||||
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
||||||
.repartionData(Repartition.Always).build());
|
.repartionData(Repartition.Always).build());
|
||||||
|
|
||||||
sparkNet.setListeners(new ScoreIterationListener(1));
|
sparkNet.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -778,13 +782,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
.setOutputs("1").build();
|
.setOutputs("1").build();
|
||||||
|
|
||||||
for (int avgFreq : new int[] {1, 5, 10}) {
|
for (int avgFreq : new int[] {1, 5, 10}) {
|
||||||
System.out.println("--- Avg freq " + avgFreq + " ---");
|
// System.out.println("--- Avg freq " + avgFreq + " ---");
|
||||||
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
|
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
|
||||||
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
|
||||||
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
.batchSizePerWorker(batchSizePerExecutor).averagingFrequency(avgFreq)
|
||||||
.repartionData(Repartition.Always).build());
|
.repartionData(Repartition.Always).build());
|
||||||
|
|
||||||
sparkNet.setListeners(new ScoreIterationListener(1));
|
sparkNet.setListeners(new ScoreIterationListener(5));
|
||||||
|
|
||||||
JavaRDD<DataSet> rdd = sc.parallelize(list);
|
JavaRDD<DataSet> rdd = sc.parallelize(list);
|
||||||
|
|
||||||
|
|
|
@ -107,7 +107,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
|
||||||
expectedStatNames.addAll(c);
|
expectedStatNames.addAll(c);
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println(expectedStatNames);
|
// System.out.println(expectedStatNames);
|
||||||
|
|
||||||
|
|
||||||
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
|
||||||
|
@ -119,7 +119,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
String statsAsString = stats.statsAsString();
|
String statsAsString = stats.statsAsString();
|
||||||
System.out.println(statsAsString);
|
// System.out.println(statsAsString);
|
||||||
assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat
|
assertEquals(actualKeySet.size(), statsAsString.split("\n").length); //One line per stat
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ public class TestTimeSource {
|
||||||
long systemTime = System.currentTimeMillis();
|
long systemTime = System.currentTimeMillis();
|
||||||
long ntpTime = timeSource.currentTimeMillis();
|
long ntpTime = timeSource.currentTimeMillis();
|
||||||
long offset = ntpTime - systemTime;
|
long offset = ntpTime - systemTime;
|
||||||
System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
|
// System.out.println("System: " + systemTime + "\tNTPTimeSource: " + ntpTime + "\tOffset: " + offset);
|
||||||
Thread.sleep(500);
|
Thread.sleep(500);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,7 +49,7 @@ public class TestTimeSource {
|
||||||
long systemTime = System.currentTimeMillis();
|
long systemTime = System.currentTimeMillis();
|
||||||
long ntpTime = timeSource.currentTimeMillis();
|
long ntpTime = timeSource.currentTimeMillis();
|
||||||
long offset = ntpTime - systemTime;
|
long offset = ntpTime - systemTime;
|
||||||
System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
|
// System.out.println("System: " + systemTime + "\tSystemClockTimeSource: " + ntpTime + "\tOffset: " + offset);
|
||||||
assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
|
assertEquals(systemTime, ntpTime, 2); //Should be exact, but we might randomly tick over between one ms and the next
|
||||||
Thread.sleep(500);
|
Thread.sleep(500);
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,7 +87,7 @@ public class TestListeners extends BaseSparkTest {
|
||||||
net.fit(rdd);
|
net.fit(rdd);
|
||||||
|
|
||||||
List<String> sessions = ss.listSessionIDs();
|
List<String> sessions = ss.listSessionIDs();
|
||||||
System.out.println("Sessions: " + sessions);
|
// System.out.println("Sessions: " + sessions);
|
||||||
assertEquals(1, sessions.size());
|
assertEquals(1, sessions.size());
|
||||||
|
|
||||||
String sid = sessions.get(0);
|
String sid = sessions.get(0);
|
||||||
|
@ -95,15 +95,15 @@ public class TestListeners extends BaseSparkTest {
|
||||||
List<String> typeIDs = ss.listTypeIDsForSession(sid);
|
List<String> typeIDs = ss.listTypeIDsForSession(sid);
|
||||||
List<String> workers = ss.listWorkerIDsForSession(sid);
|
List<String> workers = ss.listWorkerIDsForSession(sid);
|
||||||
|
|
||||||
System.out.println(sid + "\t" + typeIDs + "\t" + workers);
|
// System.out.println(sid + "\t" + typeIDs + "\t" + workers);
|
||||||
|
|
||||||
List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
|
List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
|
||||||
System.out.println(lastUpdates);
|
// System.out.println(lastUpdates);
|
||||||
|
|
||||||
System.out.println("Static info:");
|
// System.out.println("Static info:");
|
||||||
for (String wid : workers) {
|
for (String wid : workers) {
|
||||||
Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
|
Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
|
||||||
System.out.println(sid + "\t" + wid);
|
// System.out.println(sid + "\t" + wid);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(1, typeIDs.size());
|
assertEquals(1, typeIDs.size());
|
||||||
|
|
|
@ -63,7 +63,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
assertEquals(10, rdd2.partitions().size());
|
assertEquals(10, rdd2.partitions().size());
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
||||||
System.out.println("Partition " + i + " size: " + partition.size());
|
// System.out.println("Partition " + i + " size: " + partition.size());
|
||||||
assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
|
assertEquals(100, partition.size()); //Should be exactly 100, for the util method (but NOT spark .repartition)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -170,7 +170,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
List<Tuple2<Integer, Integer>> partitionCounts = initial.values().mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
||||||
|
|
||||||
System.out.println(partitionCounts);
|
// System.out.println(partitionCounts);
|
||||||
|
|
||||||
List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
|
List<Tuple2<Integer,Integer>> initialExpected = Arrays.asList(
|
||||||
new Tuple2<>(0,29),
|
new Tuple2<>(0,29),
|
||||||
|
@ -185,7 +185,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
|
JavaRDD<Integer> afterRepartition = SparkUtils.repartitionBalanceIfRequired(initial.values(), Repartition.Always, 2, 112);
|
||||||
List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
List<Tuple2<Integer, Integer>> partitionCountsAfter = afterRepartition.mapPartitionsWithIndex(new CountPartitionsFunction<Integer>(), true).collect();
|
||||||
System.out.println(partitionCountsAfter);
|
// System.out.println(partitionCountsAfter);
|
||||||
|
|
||||||
for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
|
for(Tuple2<Integer,Integer> t2 : partitionCountsAfter){
|
||||||
assertEquals(2, (int)t2._2());
|
assertEquals(2, (int)t2._2());
|
||||||
|
@ -219,8 +219,8 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("min: " + min + "\t@\t" + minIdx);
|
// System.out.println("min: " + min + "\t@\t" + minIdx);
|
||||||
System.out.println("max: " + max + "\t@\t" + maxIdx);
|
// System.out.println("max: " + max + "\t@\t" + maxIdx);
|
||||||
|
|
||||||
assertEquals(1, min);
|
assertEquals(1, min);
|
||||||
assertEquals(2, max);
|
assertEquals(2, max);
|
||||||
|
@ -244,7 +244,7 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
List<String> partition = rdd2.collectPartitions(new int[] {i})[0];
|
||||||
System.out.println("Partition " + i + " size: " + partition.size());
|
// System.out.println("Partition " + i + " size: " + partition.size());
|
||||||
assertTrue(partition.size() >= 90 && partition.size() <= 110);
|
assertTrue(partition.size() >= 90 && partition.size() <= 110);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -123,7 +123,7 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
||||||
val expectedSlice = expectedArray.slice(0)
|
val expectedSlice = expectedArray.slice(0)
|
||||||
val actualSlice = expectedArray(0, ->)
|
val actualSlice = expectedArray(0, ->)
|
||||||
|
|
||||||
Console.println(expectedSlice)
|
// Console.println(expectedSlice)
|
||||||
|
|
||||||
assert(actualSlice == expectedSlice)
|
assert(actualSlice == expectedSlice)
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ class TrainingTest extends FlatSpec with Matchers {
|
||||||
val unused3 = unused1.div(unused2)
|
val unused3 = unused1.div(unused2)
|
||||||
val loss1 = add.std("l1", true)
|
val loss1 = add.std("l1", true)
|
||||||
val loss2 = mmul.mean("l2")
|
val loss2 = mmul.mean("l2")
|
||||||
Console.println(sd.summary)
|
// Console.println(sd.summary)
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
sd.setLossVariables("l1", "l2")
|
sd.setLossVariables("l1", "l2")
|
||||||
sd.createGradFunction()
|
sd.createGradFunction()
|
||||||
|
|
|
@ -43,8 +43,8 @@ public class HistoryProcessorTest {
|
||||||
hp.add(a);
|
hp.add(a);
|
||||||
INDArray[] h = hp.getHistory();
|
INDArray[] h = hp.getHistory();
|
||||||
assertEquals(4, h.length);
|
assertEquals(4, h.length);
|
||||||
System.out.println(Arrays.toString(a.shape()));
|
// System.out.println(Arrays.toString(a.shape()));
|
||||||
System.out.println(Arrays.toString(h[0].shape()));
|
// System.out.println(Arrays.toString(h[0].shape()));
|
||||||
assertEquals( 1, h[0].shape()[0]);
|
assertEquals( 1, h[0].shape()[0]);
|
||||||
assertEquals(a.shape()[0], h[0].shape()[1]);
|
assertEquals(a.shape()[0], h[0].shape()[1]);
|
||||||
assertEquals(a.shape()[1], h[0].shape()[2]);
|
assertEquals(a.shape()[1], h[0].shape()[2]);
|
||||||
|
|
|
@ -100,8 +100,8 @@ public class ActorCriticTest {
|
||||||
double error2 = gradient2 - gradient.getDouble(1);
|
double error2 = gradient2 - gradient.getDouble(1);
|
||||||
double relError1 = error1 / gradient.getDouble(0);
|
double relError1 = error1 / gradient.getDouble(0);
|
||||||
double relError2 = error2 / gradient.getDouble(1);
|
double relError2 = error2 / gradient.getDouble(1);
|
||||||
System.out.println(gradient.getDouble(0) + " " + gradient1 + " " + relError1);
|
// System.out.println(gradient.getDouble(0) + " " + gradient1 + " " + relError1);
|
||||||
System.out.println(gradient.getDouble(1) + " " + gradient2 + " " + relError2);
|
// System.out.println(gradient.getDouble(1) + " " + gradient2 + " " + relError2);
|
||||||
assertTrue(gradient.getDouble(0) < maxRelError || Math.abs(relError1) < maxRelError);
|
assertTrue(gradient.getDouble(0) < maxRelError || Math.abs(relError1) < maxRelError);
|
||||||
assertTrue(gradient.getDouble(1) < maxRelError || Math.abs(relError2) < maxRelError);
|
assertTrue(gradient.getDouble(1) < maxRelError || Math.abs(relError2) < maxRelError);
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,7 +158,7 @@ public class PolicyTest {
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
count[policy.nextAction(input)]++;
|
count[policy.nextAction(input)]++;
|
||||||
}
|
}
|
||||||
System.out.println(count[0] + " " + count[1] + " " + count[2] + " " + count[3]);
|
// System.out.println(count[0] + " " + count[1] + " " + count[2] + " " + count[3]);
|
||||||
assertTrue(count[0] < 20);
|
assertTrue(count[0] < 20);
|
||||||
assertTrue(count[1] < 30);
|
assertTrue(count[1] < 30);
|
||||||
assertTrue(count[2] < 40);
|
assertTrue(count[2] < 40);
|
||||||
|
|
Loading…
Reference in New Issue