parent 6ce620709a
commit ba269a26ab
@@ -17,7 +17,6 @@
package org.deeplearning4j.spark;

import org.apache.spark.serializer.SerializerInstance;
import org.deeplearning4j.eval.*;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -28,6 +27,9 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.junit.Test;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;

@@ -19,7 +19,6 @@ package org.deeplearning4j.spark.impl.multilayer;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -30,6 +29,7 @@ import org.deeplearning4j.spark.BaseSparkTest;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

@@ -29,15 +29,13 @@ import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
@@ -56,6 +54,9 @@ import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@@ -63,6 +64,7 @@ import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp;
@@ -70,7 +72,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import scala.Tuple2;

import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
@@ -121,11 +122,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));

MultiLayerNetwork network2 = master.fitLabeledPoint(data);
Evaluation evaluation = new Evaluation();
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
System.out.println(evaluation.stats());

}

@@ -137,20 +133,15 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
.getAbsolutePath())
.toJavaRDD().map(new TestFn());

DataSet d = new IrisDataSetIterator(150, 150).next();
MultiLayerConfiguration conf =
new NeuralNetConfiguration.Builder().seed(123)
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
.miniBatch(true).maxNumLineSearchIterations(10)
.list().layer(0,
new DenseLayer.Builder().nIn(4).nOut(100)
.weightInit(WeightInit.XAVIER)
.activation(Activation.RELU)
.build())
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER).build())
.updater(new Adam(1e-6))
.weightInit(WeightInit.XAVIER)
.list()
.layer(new BatchNormalization.Builder().nIn(4).nOut(4).build())
.layer(new DenseLayer.Builder().nIn(4).nOut(32).activation(Activation.RELU).build())
.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(32).nOut(3)
.activation(Activation.SOFTMAX).build())
.build();

@@ -161,10 +152,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(),
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));

MultiLayerNetwork network2 = master.fitLabeledPoint(data);
Evaluation evaluation = new Evaluation();
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
System.out.println(evaluation.stats());
master.fitLabeledPoint(data);
}

@Test(timeout = 120000L)

@@ -465,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
tempDirF.deleteOnExit();

int dataSetObjSize = 1;
int batchSizePerExecutor = 25;
int numSplits = 10;
int batchSizePerExecutor = 16;
int numSplits = 5;
int averagingFrequency = 3;
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
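Note: the import changes above appear to track the move of the evaluation classes from org.deeplearning4j.eval to org.nd4j.evaluation. As a minimal sketch of the evaluation pattern the removed test code used, rewritten against the ND4J package (the EvalExample class and printEvalStats helper are hypothetical, not part of this commit):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.DataSet;

public class EvalExample {
    // Evaluate a trained network on a single DataSet, mirroring the pattern in the
    // removed test code but using the ND4J Evaluation class.
    static void printEvalStats(MultiLayerNetwork network, DataSet d) {
        Evaluation evaluation = new Evaluation();
        evaluation.eval(d.getLabels(), network.output(d.getFeatures()));
        System.out.println(evaluation.stats());
    }
}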
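Note: the last hunk shrinks batchSizePerExecutor from 25 to 16 and numSplits from 10 to 5. A rough sketch of how those quantities map onto a ParameterAveragingTrainingMaster, assuming the standard builder methods (Builder(int), batchSizePerWorker, averagingFrequency, workerPrefetchNumBatches, saveUpdater); the builder-style configuration is illustrative and not taken from this diff:

import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;

public class TrainingMasterSketch {
    // Assemble a TrainingMaster from the quantities used in the test above.
    static ParameterAveragingTrainingMaster buildTrainingMaster() {
        int dataSetObjSize = 1;        // examples per DataSet object in the RDD
        int batchSizePerExecutor = 16; // minibatch size each worker trains on
        int averagingFrequency = 3;    // average parameters every 3 minibatches per worker

        return new ParameterAveragingTrainingMaster.Builder(dataSetObjSize)
                .batchSizePerWorker(batchSizePerExecutor)
                .averagingFrequency(averagingFrequency)
                .workerPrefetchNumBatches(0)
                .saveUpdater(true)
                .build();
    }
}

One averaging round then consumes roughly numExecutors() * batchSizePerExecutor * averagingFrequency examples, which is why the test sizes its dataset as numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency.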