/*
 *
 * ******************************************************************************
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 *
 */
package net.brutex.spark;

import com.fasterxml.jackson.core.Version;
import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.SparkSession;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.Writable;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.datavec.spark.transform.misc.StringToWritablesFunction;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.datavec.DataVecDataSetFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
import org.deeplearning4j.ui.api.UIServer;
import org.junit.jupiter.api.*;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;

import java.io.File;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
/**
 * Integration test for Spark-based DataVec preprocessing and distributed DL4J training
 * on a cities CSV data set.
 *
 * @author raver119@gmail.com
 */
@Slf4j
@Tag("integration")
public class BrianTest2 /*extends BaseDL4JTest*/ {
    static {
        String OS = System.getProperty("os.name").toLowerCase();

        if (OS.contains("win")) {
            System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
        } else {
            System.setProperty("hadoop.home.dir", "/");
        }
    }

    public long getTimeoutMilliseconds() {
        return 400000L;
    }

    private JavaSparkContext sc;

    /*
    @BeforeAll
    public void loadData() {

    /*
        sc.addFile("https://www.openml.org/data/get_csv/1595261/phpMawTba");
        org.apache.hadoop.fs.FileSystem hdfs = FileSystem.get( sc.hadoopConfiguration());
        try {
            String file = SparkFiles.get("phpMawTba");
            Path target = new Path("/user/brian/" + "mydata.csv");
            //Apache Commons
            FileUtils.copyFile(new File(file), hdfs.create(target));
        } catch (IOException e) {
            e.printStackTrace();
        }

    }
    */

    @BeforeEach
    public void setUp() throws Exception {
        log.info("Running @BeforeEach scope");
        System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
        Version version = com.fasterxml.jackson.databind.cfg.PackageVersion.VERSION;
        System.out.println("Jackson version found: " + version);
        System.out.println(System.getProperty("java.vm.name") + "\n" + System.getProperty("java.runtime.version"));

        SparkConf sparkConf = new SparkConf()
                .setMaster("spark://10.5.5.200:7077")
                .setAppName("Brian3")
                .set("spark.driver.bindAddress", "10.5.5.145")
                .set("spark.network.timeout", "240000")
                .set("spark.driver.host", "10.5.5.145")
                .set("spark.deploy.mode", "client")
                .set("spark.executor.memory", "4g")
                .set("spark.cores.max", "2")
                .set("spark.worker.cleanup.enabled", "false")
                .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
                .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
                .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000");

        SparkSession spark = SparkSession.builder()
                .master("spark://10.5.5.200:7077")
                .appName("BrianTest2")
                .config(sparkConf)
                .getOrCreate();

        this.sc = JavaSparkContext.fromSparkContext(spark.sparkContext());

        /*
         Whatever is on the driver's classpath is shipped to the Spark executors as well.
         */
        final String clpath = System.getProperty("java.class.path");
        log.info("java.class.path=\r\n{}\r\n", clpath);
        final String separator = System.getProperty("path.separator");
        final String[] a = clpath.split(separator);
        for (String s : a) {
            File f = new File(s);
            if (f.exists() && f.isFile() && s.endsWith(".jar")) {
                log.info("Adding jar to SparkContext '{}'.", f.getName());
                this.sc.addJar(s);
            }
        }
    }

    @AfterEach
    public void tearDown() throws Exception {
        if (sc != null) this.sc.stop();
        UIServer.stopInstance();
    }

    @Test
    public void testStringsTokenization1() throws Exception {

        final JavaRDD<String> rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz");
        //Shrink for test
        //List<String> list = Arrays.asList(new String[]{"asdsad", "asdasdasd", "asdasdasd", "3easdasd"});
        //JavaRDD<String> rdd = sc.parallelize(list);

        // rdd = rdd.sample(true, 1.0, 1);
        log.info("Record count: " + rdd.count());
        log.info("Sample: " + rdd.top(3));

        Assertions.assertEquals(146889, rdd.count());
    }

    @Test
    public void testSchemaCreation() throws Exception {
        log.info(System.getProperty("java.vm.name") + "\n" + System.getProperty("java.runtime.version"));
        final JavaRDD<String> rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz");
        rdd.cache();

        JavaRDD<String> cities = rdd.map((Function<String, String>) line -> {
            return line.split(",")[1];
        }).cache();

        JavaRDD<String> stateCodeList = rdd.map((Function<String, String>) line -> {
            return line.split(",")[2];
        }).cache();

        JavaRDD<String> countryCodeList = rdd.map((Function<String, String>) line -> {
            return line.split(",")[3];
        }).cache();
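
        // Parse each raw CSV line into a DataVec record (List<Writable>) so that the
        // TransformProcess below can operate on typed columns.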
        CSVRecordReader recordReader = new CSVRecordReader(0, ',');
        JavaRDD<List<Writable>> convertedRDD = rdd.map((Function<String, List<Writable>>) s -> {
            return new StringToWritablesFunction(recordReader).call(s);
        });

        //Source schema of the CSV file
        Schema inputSchema = new Schema.Builder()
                .addColumnLong("city_id")
                .addColumnsString("city_name", "state_code", "country_code")
                .addColumnsString("country_full")
                .addColumnsDouble("lat", "lon")
                .build();

        //Running Transformation
        /*
        TransformProcess tp = new TransformProcess.Builder(inputSchema)
                .removeColumns("country_full", "lat", "lon")
                .addConstantIntegerColumn("dummy_spalte", 1)
                .stringToCategorical("state_code", stateCodeList.distinct().collect())
                .stringToCategorical("country_code", countryCodeList.distinct().collect())
                .stringToCategorical("city_name", cities.distinct().collect())
                .filter(new FilterInvalidValues())
                .categoricalToOneHot("city_name")
                .categoricalToOneHot("state_code")
                .categoricalToOneHot("country_code")
                .build();
        */
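
        // Active transform: keep only country_code, lat and lon, map country_code to the
        // four categories below, drop invalid rows, and one-hot encode the country_code column.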
        TransformProcess tp = new TransformProcess.Builder(inputSchema)
                .removeAllColumnsExceptFor("country_code", "lat", "lon")
                .stringToCategorical("country_code", Arrays.asList("GR", "FR", "DE", "CH"))
                .filter(new FilterInvalidValues())
                .categoricalToOneHot("country_code")
                .build();

//log.info("Final Schema: " +tp.getFinalSchema().toString());
|
|
//Execute Transformation Process
|
|
//convertedRDD.repartition(1);
|
|
//convertedRDD.cache();
|
|
JavaRDD<List<Writable>> processedData = SparkTransformExecutor.execute(convertedRDD, tp);
|
|
//processedData.repartition(1);
|
|
//processedData.cache();
|
|
//log.info("Datenmenge nach processing: " + processedData.count());
|
|
|
|
|
|
//Vectorisieren
|
|
int labelIndex = 0; //in welcher Spalte ist das Label
|
|
int numLabels = 4; //Anzahl der Klassen 0-236 = 237 Werte
|
|
|
|
DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels, false);
|
|
JavaRDD<DataSet> rddDataSet = processedData.map(datavecFunction);
|
|
log.info("rddDataset: " + rddDataSet.toDebugString());
|
|
Random rand = new Random();
|
|
rddDataSet.sortBy( (Function<DataSet, Double>) s -> {return rand.nextDouble(); }, true, 8);
|
|
|
|
//og.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect());
|
|
|
|
        /* Skip, this will save each record one by one to hdfs
        */
        //Now save this hard work
        /*
        int miniBatchSize = 1;      //Minibatch size of the saved DataSet objects
        final String exportPath = "hdfs://10.5.5.200:9000/user/brian/data";
        JavaRDD<String> paths = rddDataSet.mapPartitionsWithIndex(
                new BatchAndExportDataSetsFunction(miniBatchSize, exportPath),
                true);
        paths.collect();
        */

        // Configure distributed training required for gradient sharing implementation
        VoidConfiguration conf = VoidConfiguration.builder()
                .unicastPort(40123)             //Port that workers will use to communicate. Use any free port
                //.networkMask("10.0.0.0/16")   //Network mask for communication. Examples 10.0.0.0/24, or 192.168.0.0/16 etc
                .controllerAddress("10.5.5.145")
                .build();
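
        // SharedTrainingMaster implements DL4J's gradient-sharing approach: workers exchange
        // quantized updates above the configured threshold instead of averaging full parameter sets.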
        //Create the TrainingMaster instance
        TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf, 1000)
                .batchSizePerWorker(20000)      //Batch size for training
                .updatesThreshold(1e-3)         //Update threshold for quantization/compression. See technical explanation page
                .workersPerNode(1)              //Equal to number of GPUs. For CPUs: use 1; use > 1 for large core count
                .exportDirectory("/user/brian/")
                .build();

        //Alternative TrainingMaster configurations (not used)
        /*
        TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4)
                .rddTrainingApproach(RDDTrainingApproach.Direct)    //when "export", tries to save everything first
                .collectTrainingStats(false).build();
        */
        /*
        TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, minibatch)
                .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(this.gradientThreshold))
                .residualPostProcessor(new ResidualClippingPostProcessor(5, 5))
                .build();
        */

        //Define Network
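
        // Simple feed-forward classifier: two ReLU dense layers (nIn = 5 feature columns) feeding
        // a 4-unit sigmoid output trained with binary cross-entropy (XENT), one unit per one-hot country class.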
        MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder()
                .seed(123)
                .updater(new Nesterovs(0.1, 0.9))
                .list()
                .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).l2(0.001).build())
                .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
                //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4).weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build())
                .build();

        //Define SparkNet
        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration, trainingMaster);
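
        // 90/10 random split for train/test; each call to fit() below performs one pass
        // (epoch) over the training RDD using the SharedTrainingMaster configured above.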
JavaRDD<DataSet>[] split = rddDataSet.randomSplit(new double[] {0.9, 0.1}, 123);
|
|
//JavaRDD<DataSet> trainingData = split[0];
|
|
JavaRDD<DataSet> trainingData = rddDataSet;
|
|
JavaRDD<DataSet> testData = split[1];
|
|
|
|
//Run Training on subset
|
|
for(int i =0; i<4; i++) {
|
|
sparkNet.fit(trainingData);
|
|
}
|
|
|
|
//Evaluieren
|
|
MultiLayerNetwork finalNet = sparkNet.getNetwork();
|
|
|
|
        //Save (model persistence to HDFS is prepared but currently disabled)
        Configuration hconf = sc.hadoopConfiguration();
        hconf.set("hadoop.tmp.dir", "/user/brian/tmp");
        FileSystem fs = FileSystem.get(hconf);
        Path p = new Path("hdfs://10.5.5.200:9000/user/brian/model");
        //fs.mkdirs(p);
        //ModelSerializer.writeModel(finalNet, fs.create(p), true );
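
        // Evaluation: bring the test partition back to the driver via toLocalIterator()
        // and score each DataSet against the locally materialized network.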
        Evaluation eval = new Evaluation(4);    //number of output classes
        Iterator<DataSet> iter = testData.toLocalIterator();
        log.info("testData has " + testData.count() + " DataSets");
        while (iter.hasNext()) {
            DataSet next = iter.next();
            //log.info("getFeatures " + next.getFeatures());
            INDArray output = finalNet.output(next.getFeatures());     //get the network's prediction
            //log.info("output " + output.toStringFull());
            eval.eval(next.getLabels(), output);                       //check the prediction against the true class
            //log.info("Predict " + finalNet.predict(next));
        }
        log.info("Evaluation stats: " + eval.stats());

    }

}