Small fixes (#223)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-03 10:48:59 +10:00 committed by GitHub
parent 6ce620709a
commit ba269a26ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 29 deletions

View File

@ -17,7 +17,6 @@
package org.deeplearning4j.spark; package org.deeplearning4j.spark;
import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.serializer.SerializerInstance;
import org.deeplearning4j.eval.*;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -28,6 +27,9 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.junit.Test; import org.junit.Test;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;

View File

@ -19,7 +19,6 @@ package org.deeplearning4j.spark.impl.multilayer;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -30,6 +29,7 @@ import org.deeplearning4j.spark.BaseSparkTest;
import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.Test; import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

View File

@ -29,15 +29,13 @@ import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.mllib.util.MLUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution; import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
@ -56,6 +54,9 @@ import org.junit.Ignore;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -63,6 +64,7 @@ import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.learning.config.RmsProp;
@ -70,7 +72,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import scala.Tuple2; import scala.Tuple2;
import java.io.File; import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.*; import java.util.*;
@ -121,11 +122,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0)); new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
MultiLayerNetwork network2 = master.fitLabeledPoint(data); MultiLayerNetwork network2 = master.fitLabeledPoint(data);
Evaluation evaluation = new Evaluation();
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
System.out.println(evaluation.stats());
} }
@ -137,20 +133,15 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.getAbsolutePath()) .getAbsolutePath())
.toJavaRDD().map(new TestFn()); .toJavaRDD().map(new TestFn());
DataSet d = new IrisDataSetIterator(150, 150).next();
MultiLayerConfiguration conf = MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().seed(123) new NeuralNetConfiguration.Builder().seed(123)
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) .updater(new Adam(1e-6))
.miniBatch(true).maxNumLineSearchIterations(10)
.list().layer(0,
new DenseLayer.Builder().nIn(4).nOut(100)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.activation(Activation.RELU) .list()
.build()) .layer(new BatchNormalization.Builder().nIn(4).nOut(4).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( .layer(new DenseLayer.Builder().nIn(4).nOut(32).activation(Activation.RELU).build())
LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3) .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(32).nOut(3)
.activation(Activation.SOFTMAX) .activation(Activation.SOFTMAX).build())
.weightInit(WeightInit.XAVIER).build())
.build(); .build();
@ -161,10 +152,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(), SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(),
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0)); new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
MultiLayerNetwork network2 = master.fitLabeledPoint(data); master.fitLabeledPoint(data);
Evaluation evaluation = new Evaluation();
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
System.out.println(evaluation.stats());
} }
@Test(timeout = 120000L) @Test(timeout = 120000L)
@ -465,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
tempDirF.deleteOnExit(); tempDirF.deleteOnExit();
int dataSetObjSize = 1; int dataSetObjSize = 1;
int batchSizePerExecutor = 25; int batchSizePerExecutor = 16;
int numSplits = 10; int numSplits = 5;
int averagingFrequency = 3; int averagingFrequency = 3;
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency; int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false); DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);