Small fixes (#223)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-03 10:48:59 +10:00 committed by GitHub
parent 6ce620709a
commit ba269a26ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 29 deletions

View File

@ -17,7 +17,6 @@
package org.deeplearning4j.spark;
import org.apache.spark.serializer.SerializerInstance;
import org.deeplearning4j.eval.*;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -28,6 +27,9 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.junit.Test;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;

View File

@ -19,7 +19,6 @@ package org.deeplearning4j.spark.impl.multilayer;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -30,6 +29,7 @@ import org.deeplearning4j.spark.BaseSparkTest;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

View File

@ -29,15 +29,13 @@ import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
@ -56,6 +54,9 @@ import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -63,6 +64,7 @@ import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp;
@ -70,7 +72,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import scala.Tuple2;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
@ -121,11 +122,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
MultiLayerNetwork network2 = master.fitLabeledPoint(data);
Evaluation evaluation = new Evaluation();
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
System.out.println(evaluation.stats());
}
@ -137,20 +133,15 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.getAbsolutePath())
.toJavaRDD().map(new TestFn());
DataSet d = new IrisDataSetIterator(150, 150).next();
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().seed(123)
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
.miniBatch(true).maxNumLineSearchIterations(10)
.list().layer(0,
new DenseLayer.Builder().nIn(4).nOut(100)
.weightInit(WeightInit.XAVIER)
.activation(Activation.RELU)
.build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER).build())
.updater(new Adam(1e-6))
.weightInit(WeightInit.XAVIER)
.list()
.layer(new BatchNormalization.Builder().nIn(4).nOut(4).build())
.layer(new DenseLayer.Builder().nIn(4).nOut(32).activation(Activation.RELU).build())
.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(32).nOut(3)
.activation(Activation.SOFTMAX).build())
.build();
@ -161,10 +152,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(),
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
MultiLayerNetwork network2 = master.fitLabeledPoint(data);
Evaluation evaluation = new Evaluation();
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
System.out.println(evaluation.stats());
master.fitLabeledPoint(data);
}
@Test(timeout = 120000L)
@ -465,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
tempDirF.deleteOnExit();
int dataSetObjSize = 1;
int batchSizePerExecutor = 25;
int numSplits = 10;
int batchSizePerExecutor = 16;
int numSplits = 5;
int averagingFrequency = 3;
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);