// cavis/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.spark;
import com.fasterxml.jackson.core.Version;
import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.SparkSession;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.datavec.spark.transform.misc.StringToWritablesFunction;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.datavec.DataVecDataSetFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
import org.deeplearning4j.ui.api.UIServer;
import org.junit.jupiter.api.*;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import java.io.File;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
/**
* Spark integration test: loads a city list from HDFS, runs a DataVec
* TransformProcess over it, and trains and evaluates a small DL4J network
* using a SharedTrainingMaster.
*
* @author raver119@gmail.com
*/
@Slf4j
@Tag("integration")
public class BrianTest2 /*extends BaseDL4JTest*/ {
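// Hadoop's shell bindings need a valid hadoop.home.dir; on Windows this must
// point at a directory containing winutils.exe (the c:\java\winutils path is
// specific to this test environment).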
static {
String OS = System.getProperty("os.name").toLowerCase();
if (OS.contains("win")) {
System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
} else {
System.setProperty("hadoop.home.dir", "/");
}
}
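// Timeout that BaseDL4JTest would apply; kept in case the commented-out
// "extends BaseDL4JTest" above is ever restored.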
public long getTimeoutMilliseconds() {
return 400000L;
}
private JavaSparkContext sc;
/*
@BeforeAll
public void loadData() {
sc.addFile("https://www.openml.org/data/get_csv/1595261/phpMawTba");
org.apache.hadoop.fs.FileSystem hdfs = FileSystem.get( sc.hadoopConfiguration());
try {
String file = SparkFiles.get("phpMawTba");
Path target = new Path("/user/brian/" + "mydata.csv");
//Apache Commons
FileUtils.copyFile(new File(file), hdfs.create(target));
} catch (IOException e) {
e.printStackTrace();
}
}
*/
@BeforeEach
public void setUp() throws Exception {
log.info("Running @BeforeEach scope");
System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString());
Version version = com.fasterxml.jackson.databind.cfg.PackageVersion.VERSION;
System.out.println("Jackson version found: " + version);
System.out.println(System.getProperty("java.vm.name")+"\n"+System.getProperty("java.runtime.version"));
SparkConf sparkConf = new SparkConf()
.setMaster("spark://10.5.5.200:7077")
.setAppName("Brian3")
.set("spark.driver.bindAddress", "10.5.5.145")
.set("spark.network.timeout", "240000")
.set("spark.driver.host", "10.5.5.145")
.set("spark.deploy.mode", "client")
.set("spark.executor.memory", "4g")
.set("spark.cores.max", "2")
.set("spark.worker.cleanup.enabled", "false")
.set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
.set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
.set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000");
SparkSession spark = SparkSession.builder()
.master("spark://10.5.5.200:7077")
.appName("BrianTest2")
.config(sparkConf)
.getOrCreate();
this.sc = JavaSparkContext.fromSparkContext(spark.sparkContext());
/*
Whatever is on the driver's classpath is also shipped to the Spark executors.
*/
final String clpath = System.getProperty("java.class.path");
log.info("java.class.path=\r\n{}\r\n", clpath);
final String separator = System.getProperty("path.separator");
final String[] a = clpath.split(separator);
for(String s : a) {
File f = new File(s);
if(f.exists() && f.isFile() && s.endsWith(".jar")) {
log.info("Adding jar to SparkContext '{}'.", f.getName());
this.sc.addJar(s);
}
}
}
@AfterEach
public void tearDown() throws Exception {
if(sc!=null) this.sc.stop();
UIServer.stopInstance();
}
@Test
public void testStringsTokenization1() throws Exception {
final JavaRDD<String> rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz");
//shrink for Test
//List<String> list = Arrays.asList(new String[]{"asdsad", "asdasdasd", "asdasdasd", "3easdasd"});
//JavaRDD<String> rdd = sc.parallelize(list);
// rdd = rdd.sample(true, 1.0, 1);
log.info("Datenmenge: " + rdd.count());
log.info("Sample: " + rdd.top(3));
Assertions.assertEquals(146889, rdd.count());
}
@Test
public void testSchemaCreation() throws Exception {
log.info(System.getProperty("java.vm.name")+"\n"+System.getProperty("java.runtime.version"));
final JavaRDD<String> rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz");
rdd.cache();
JavaRDD<String> cities = rdd.map( (Function<String, String>) line -> {
return line.split(",")[1];
}).cache();
JavaRDD<String> stateCodeList = rdd.map( (Function<String, String>) line -> {
return line.split(",")[2];
}).cache();
JavaRDD<String> countryCodeList = rdd.map( (Function<String, String>) line -> {
return line.split(",")[3];
}).cache();
CSVRecordReader recordReader = new CSVRecordReader(0, ',');
JavaRDD<List<Writable>> convertedRDD = rdd.map((Function<String, List<Writable>>) s -> {
return new StringToWritablesFunction( recordReader).call(s);
});
//Source Schema
Schema inputSchema = new Schema.Builder()
.addColumnLong("city_id")
.addColumnsString("city_name", "state_code", "country_code")
.addColumnsString("country_full")
.addColumnsDouble("lat", "lon")
.build();
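// Presumably mirrors the column layout of cities_full.csv: id, city name,
// state code, country code, country name, latitude, longitude.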
//Running Transformation
/*
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.removeColumns("country_full", "lat", "lon")
.addConstantIntegerColumn("dummy_spalte", 1)
.stringToCategorical("state_code", stateCodeList.distinct().collect())
.stringToCategorical("country_code", countryCodeList.distinct().collect())
.stringToCategorical("city_name", cities.distinct().collect())
.filter(new FilterInvalidValues())
.categoricalToOneHot("city_name")
.categoricalToOneHot("state_code")
.categoricalToOneHot("country_code")
.build();
*/
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.removeAllColumnsExceptFor("country_code", "lat", "lon")
.stringToCategorical("country_code", Arrays.asList("GR", "FR", "DE", "CH"))
.filter(new FilterInvalidValues())
.categoricalToOneHot("country_code")
.build();
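// Resulting schema: the four one-hot country_code columns (GR, FR, DE, CH)
// followed by lat and lon; rows with any other country code become invalid
// and are removed by FilterInvalidValues.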
//log.info("Final Schema: " +tp.getFinalSchema().toString());
//Execute Transformation Process
//convertedRDD.repartition(1);
//convertedRDD.cache();
JavaRDD<List<Writable>> processedData = SparkTransformExecutor.execute(convertedRDD, tp);
//processedData.repartition(1);
//processedData.cache();
//log.info("Datenmenge nach processing: " + processedData.count());
//Vectorisieren
int labelIndex = 0; //in welcher Spalte ist das Label
int numLabels = 4; //Anzahl der Klassen 0-236 = 237 Werte
DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels, false);
JavaRDD<DataSet> rddDataSet = processedData.map(datavecFunction);
log.info("rddDataset: " + rddDataSet.toDebugString());
Random rand = new Random();
rddDataSet = rddDataSet.sortBy((Function<DataSet, Double>) s -> rand.nextDouble(), true, 8); //sortBy returns a new RDD, so reassign to actually shuffle
//log.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect());
//Skipped: this would save the DataSet objects batch by batch to HDFS
/*
int miniBatchSize = 1; //Minibatch size of the saved DataSet objects
final String exportPath = "hdfs://10.5.5.200:9000/user/brian/data";
JavaRDD<String> paths = rddDataSet.mapPartitionsWithIndex(
new BatchAndExportDataSetsFunction(miniBatchSize, exportPath),
true)
;
paths.collect();
*/
// Configure distributed training required for gradient sharing implementation
VoidConfiguration conf = VoidConfiguration.builder()
.unicastPort(40123) //Port that workers will use to communicate. Use any free port
//.networkMask("10.0.0.0/16") //Network mask for communication. Examples 10.0.0.0/24, or 192.168.0.0/16 etc
.controllerAddress("10.5.5.145")
.build();
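// controllerAddress must point at the driver (10.5.5.145 above) and be
// reachable from every worker on the configured unicast port.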
//Create the TrainingMaster instance
TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf, 1000)
.batchSizePerWorker(20000) //Batch size for training
.updatesThreshold(1e-3) //Update threshold for quantization/compression. See technical explanation page
.workersPerNode(1) // equal to number of GPUs. For CPUs: use 1; use > 1 for large core count
.exportDirectory("/user/brian/")
.build();
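// SharedTrainingMaster is DL4J's gradient-sharing implementation: workers
// exchange quantized updates above updatesThreshold instead of averaging full
// parameter vectors.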
//Alternative TrainingMaster configurations (kept for reference):
/*
TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4)
.rddTrainingApproach(RDDTrainingApproach.Direct) //when "export", tries to save everything first
.collectTrainingStats(false).build();
*/
/*
TrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, minibatch)
.thresholdAlgorithm(new AdaptiveThresholdAlgorithm(this.gradientThreshold))
.residualPostProcessor(new ResidualClippingPostProcessor(5, 5))
.build();
*/
//Define Network
MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder()
.seed(123)
.updater(new Nesterovs(0.1, 0.9))
.list()
.layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).l2(0.001).build())
.layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
//.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4).weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build())
.build();
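// Small feed-forward net: 5 input features -> two 20-unit ReLU layers -> 4
// sigmoid outputs trained with XENT (binary cross-entropy). For single-label
// classification over the four country codes, softmax with MCXENT would be
// the more conventional output setup.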
//Define SparkNet
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration, trainingMaster);
JavaRDD<DataSet>[] split = rddDataSet.randomSplit(new double[] {0.9, 0.1}, 123);
//JavaRDD<DataSet> trainingData = split[0];
JavaRDD<DataSet> trainingData = rddDataSet;
JavaRDD<DataSet> testData = split[1];
//Run Training on subset
for(int i =0; i<4; i++) {
sparkNet.fit(trainingData);
}
//Evaluate
MultiLayerNetwork finalNet = sparkNet.getNetwork();
//Save
Configuration hconf = sc.hadoopConfiguration();
hconf.set("hadoop.tmp.dir", "/user/brian/tmp");
FileSystem fs = FileSystem.get(hconf);
Path p = new Path("hdfs://10.5.5.200:9000/user/brian/model");
//fs.mkdirs(p);
//ModelSerializer.writeModel(finalNet, fs.create(p), true );
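// Evaluation runs locally on the driver: toLocalIterator() streams the test
// DataSets back partition by partition and feeds them through the trained net.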
Evaluation eval = new Evaluation(4); //4 output classes: GR, FR, DE, CH
Iterator<DataSet> iter = testData.toLocalIterator();
log.info("testData has " + testData.count() + " DataSets");
while(iter.hasNext()){
DataSet next = iter.next();
//log.info("getFeatures " + next.getFeatures() );
INDArray output = finalNet.output(next.getFeatures()); //get the network's prediction
//log.info("output "+ output.toStringFull());
eval.eval(next.getLabels(), output); //check the prediction against the true class
//log.info("Predict " + finalNet.predict(next));
}
log.info("Evaluation stats: " + eval.stats());
}
}