Another round of small fixes (#241)

* Small base spark test fix; ROC toString for empty ROC

Signed-off-by: Alex Black <blacka101@gmail.com>

* More fixes

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-09-05 17:01:47 +10:00 committed by GitHub
parent 87d873929f
commit 52d2795193
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 68 additions and 4 deletions

View File

@ -63,7 +63,7 @@
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.11</artifactId>
<version>${jackson.version}</version>
<version>2.6.7</version>
</dependency>
</dependencies>

View File

@ -23,6 +23,9 @@ import org.junit.After;
import org.junit.Before;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.Map;
/**
* Created by agibsonccc on 1/23/15.
@ -37,7 +40,9 @@ public abstract class BaseSparkTest implements Serializable {
@After
public void after() {
sc.close();
if(sc != null) {
sc.close();
}
sc = null;
}
@ -48,6 +53,30 @@ public abstract class BaseSparkTest implements Serializable {
public JavaSparkContext getContext() {
if (sc != null)
return sc;
//Ensure SPARK_USER environment variable is set for Spark tests
String u = System.getenv("SPARK_USER");
Map<String, String> env = System.getenv();
if(u == null || u.isEmpty()) {
try {
Class[] classes = Collections.class.getDeclaredClasses();
for (Class cl : classes) {
if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
Field field = cl.getDeclaredField("m");
field.setAccessible(true);
Object obj = field.get(env);
Map<String, String> map = (Map<String, String>) obj;
String user = System.getProperty("user.name");
if (user == null || user.isEmpty())
user = "user";
map.put("SPARK_USER", user);
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// set to test mode
SparkConf sparkConf = new SparkConf().setMaster("local[4]").set("spark.driver.host", "localhost")
.setAppName("sparktest")

View File

@ -19,6 +19,10 @@ package org.deeplearning4j.spark;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.Map;
/**
* Created by Alex on 04/07/2017.
*/
@ -30,6 +34,31 @@ public class BaseSparkKryoTest extends BaseSparkTest {
return sc;
}
//Ensure SPARK_USER environment variable is set for Spark Kryo tests
String u = System.getenv("SPARK_USER");
if(u == null || u.isEmpty()){
try {
Class[] classes = Collections.class.getDeclaredClasses();
Map<String, String> env = System.getenv();
for (Class cl : classes) {
if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
Field field = cl.getDeclaredField("m");
field.setAccessible(true);
Object obj = field.get(env);
Map<String, String> map = (Map<String, String>) obj;
String user = System.getProperty("user.name");
if(user == null || user.isEmpty())
user = "user";
map.put("SPARK_USER", user);
}
}
} catch (Exception e){
throw new RuntimeException(e);
}
}
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");

View File

@ -74,7 +74,9 @@ public abstract class BaseSparkTest implements Serializable {
@After
public void after() {
sc.close();
if(sc != null) {
sc.close();
}
sc = null;
}

View File

@ -75,7 +75,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
@EqualsAndHashCode(callSuper = true,
exclude = {"auc", "auprc", "probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve", "axis"})
@Data
@ToString(exclude = {"probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve"})
@JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"})
@JsonSerialize(using = ROCSerializer.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
@ -824,6 +823,11 @@ public class ROC extends BaseEvaluation<ROC> {
return sb.toString();
}
@Override
public String toString(){
return stats();
}
public double scoreForMetric(Metric metric){
switch (metric){
case AUROC: