parent
6ce620709a
commit
ba269a26ab
|
@ -17,7 +17,6 @@
|
||||||
package org.deeplearning4j.spark;
|
package org.deeplearning4j.spark;
|
||||||
|
|
||||||
import org.apache.spark.serializer.SerializerInstance;
|
import org.apache.spark.serializer.SerializerInstance;
|
||||||
import org.deeplearning4j.eval.*;
|
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -28,6 +27,9 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
|
import org.nd4j.evaluation.classification.*;
|
||||||
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
|
|
@ -19,7 +19,6 @@ package org.deeplearning4j.spark.impl.multilayer;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.deeplearning4j.eval.Evaluation;
|
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -30,6 +29,7 @@ import org.deeplearning4j.spark.BaseSparkTest;
|
||||||
import org.deeplearning4j.spark.api.TrainingMaster;
|
import org.deeplearning4j.spark.api.TrainingMaster;
|
||||||
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
|
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
|
|
@ -29,15 +29,13 @@ import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
import org.apache.spark.mllib.util.MLUtils;
|
import org.apache.spark.mllib.util.MLUtils;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.deeplearning4j.eval.Evaluation;
|
|
||||||
import org.deeplearning4j.eval.ROC;
|
|
||||||
import org.deeplearning4j.eval.ROCMultiClass;
|
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
import org.deeplearning4j.nn.conf.layers.BaseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
|
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
|
||||||
|
@ -56,6 +54,9 @@ import org.junit.Ignore;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
|
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
@ -63,6 +64,7 @@ import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.config.Nesterovs;
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||||
import org.nd4j.linalg.learning.config.RmsProp;
|
import org.nd4j.linalg.learning.config.RmsProp;
|
||||||
|
@ -70,7 +72,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
import scala.Tuple2;
|
import scala.Tuple2;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
@ -121,11 +122,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
|
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
|
||||||
|
|
||||||
MultiLayerNetwork network2 = master.fitLabeledPoint(data);
|
MultiLayerNetwork network2 = master.fitLabeledPoint(data);
|
||||||
Evaluation evaluation = new Evaluation();
|
|
||||||
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
|
|
||||||
System.out.println(evaluation.stats());
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,20 +133,15 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
.getAbsolutePath())
|
.getAbsolutePath())
|
||||||
.toJavaRDD().map(new TestFn());
|
.toJavaRDD().map(new TestFn());
|
||||||
|
|
||||||
DataSet d = new IrisDataSetIterator(150, 150).next();
|
|
||||||
MultiLayerConfiguration conf =
|
MultiLayerConfiguration conf =
|
||||||
new NeuralNetConfiguration.Builder().seed(123)
|
new NeuralNetConfiguration.Builder().seed(123)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
|
.updater(new Adam(1e-6))
|
||||||
.miniBatch(true).maxNumLineSearchIterations(10)
|
|
||||||
.list().layer(0,
|
|
||||||
new DenseLayer.Builder().nIn(4).nOut(100)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.activation(Activation.RELU)
|
.list()
|
||||||
.build())
|
.layer(new BatchNormalization.Builder().nIn(4).nOut(4).build())
|
||||||
.layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
|
.layer(new DenseLayer.Builder().nIn(4).nOut(32).activation(Activation.RELU).build())
|
||||||
LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
|
.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(32).nOut(3)
|
||||||
.activation(Activation.SOFTMAX)
|
.activation(Activation.SOFTMAX).build())
|
||||||
.weightInit(WeightInit.XAVIER).build())
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
||||||
|
@ -161,10 +152,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(),
|
SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(),
|
||||||
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
|
new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
|
||||||
|
|
||||||
MultiLayerNetwork network2 = master.fitLabeledPoint(data);
|
master.fitLabeledPoint(data);
|
||||||
Evaluation evaluation = new Evaluation();
|
|
||||||
evaluation.eval(d.getLabels(), network2.output(d.getFeatures()));
|
|
||||||
System.out.println(evaluation.stats());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test(timeout = 120000L)
|
||||||
|
@ -465,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
tempDirF.deleteOnExit();
|
tempDirF.deleteOnExit();
|
||||||
|
|
||||||
int dataSetObjSize = 1;
|
int dataSetObjSize = 1;
|
||||||
int batchSizePerExecutor = 25;
|
int batchSizePerExecutor = 16;
|
||||||
int numSplits = 10;
|
int numSplits = 5;
|
||||||
int averagingFrequency = 3;
|
int averagingFrequency = 3;
|
||||||
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
|
int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency;
|
||||||
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
|
DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false);
|
||||||
|
|
Loading…
Reference in New Issue