// Source-capture metadata (not part of the original file):
// 2022-09-20 15:40:53 +02:00 — 129 lines, 4.7 KiB, Java

/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import java.io.File;
import java.io.Serializable;
@Slf4j
public abstract class BaseSparkTest implements Serializable {
protected static JavaSparkContext sc;
@BeforeEach
public void before() {
sc = getContext();
}
@AfterEach
public synchronized void after() {
sc.close();
//Wait until it's stopped, to avoid race conditions during tests
for (int i = 0; i < 100; i++) {
if (!sc.sc().stopped().get()) {
try {
Thread.sleep(100L);
} catch (InterruptedException e) {
log.error("",e);
}
} else {
break;
}
}
if (!sc.sc().stopped().get()) {
throw new RuntimeException("Spark context is not stopped after 10s");
}
sc = null;
}
public synchronized JavaSparkContext getContext() {
if (sc != null)
return sc;
/*
SparkConf sparkConf = new SparkConf().setMaster("local[*]").set("spark.driver.host", "localhost")
.set("spark.driverEnv.SPARK_LOCAL_IP", "127.0.0.1")
.set("spark.executorEnv.SPARK_LOCAL_IP", "127.0.0.1").setAppName("sparktest");
*/
SparkConf sparkConf = new SparkConf()
.setMaster("spark://10.5.5.200:7077")
.setAppName("Brian4")
.set("spark.driver.bindAddress", "10.5.5.145")
.set("spark.network.timeout", "240000")
.set("spark.driver.host", "10.5.5.145")
.set("spark.deploy.mode", "client")
.set("spark.executor.memory", "4g")
//.set("spark.cores.max", "8")
.set("spark.worker.cleanup.enabled", "true")
.set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
.set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml")
//.set("spark.driver.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar")
//.set("spark.executor.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar")
.set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000");
//.set("spark.jars.packages", "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.4");
//.set("spark.driver.cores", "2")
//.set("spark.driver.memory", "8g")
//.set("spark.driver.host", "10.5.5.145")
//.setExecutorEnv("spark.executor.cores", "2")
//.setExecutorEnv("spark.executor.memory", "2g")
//.set("spark.submit.deployMode", "client")
if (useKryo()) {
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
}
SparkSession spark = SparkSession.builder()
.config(sparkConf)
.getOrCreate();
sc = JavaSparkContext.fromSparkContext(spark.sparkContext());
/*
Whatever is in classpath (driver), is added to the Spark Executors
*/
final String clpath = System.getProperty("java.class.path");
final String separator = System.getProperty("path.separator");
final String[] a = clpath.split(separator);
for(String s : a) {
File f = new File(s);
if(f.exists() && f.isFile() && s.endsWith(".jar")) {
sc.addJar(s);
}
}
return sc;
}
public boolean useKryo(){
return false;
}
}