Another round of small fixes (#241)
* Small base spark test fix; ROC toString for empty ROC

  Signed-off-by: Alex Black <blacka101@gmail.com>

* More fixes

  Signed-off-by: Alex Black <blacka101@gmail.com>

parent 87d873929f
commit 52d2795193
@@ -63,7 +63,7 @@
         <dependency>
             <groupId>com.fasterxml.jackson.module</groupId>
             <artifactId>jackson-module-scala_2.11</artifactId>
-            <version>${jackson.version}</version>
+            <version>2.6.7</version>
         </dependency>
     </dependencies>

@@ -23,6 +23,9 @@ import org.junit.After;
 import org.junit.Before;

 import java.io.Serializable;
+import java.lang.reflect.Field;
+import java.util.Collections;
+import java.util.Map;

 /**
  * Created by agibsonccc on 1/23/15.
@@ -37,7 +40,9 @@ public abstract class BaseSparkTest implements Serializable {

     @After
     public void after() {
-        sc.close();
+        if(sc != null) {
+            sc.close();
+        }
         sc = null;
     }

@@ -48,6 +53,30 @@ public abstract class BaseSparkTest implements Serializable {
     public JavaSparkContext getContext() {
         if (sc != null)
             return sc;
+
+        //Ensure SPARK_USER environment variable is set for Spark tests
+        String u = System.getenv("SPARK_USER");
+        Map<String, String> env = System.getenv();
+        if(u == null || u.isEmpty()) {
+            try {
+                Class[] classes = Collections.class.getDeclaredClasses();
+                for (Class cl : classes) {
+                    if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
+                        Field field = cl.getDeclaredField("m");
+                        field.setAccessible(true);
+                        Object obj = field.get(env);
+                        Map<String, String> map = (Map<String, String>) obj;
+                        String user = System.getProperty("user.name");
+                        if (user == null || user.isEmpty())
+                            user = "user";
+                        map.put("SPARK_USER", user);
+                    }
+                }
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+
         // set to test mode
         SparkConf sparkConf = new SparkConf().setMaster("local[4]").set("spark.driver.host", "localhost")
                         .setAppName("sparktest")
@@ -19,6 +19,10 @@ package org.deeplearning4j.spark;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;

+import java.lang.reflect.Field;
+import java.util.Collections;
+import java.util.Map;
+
 /**
  * Created by Alex on 04/07/2017.
  */
@@ -30,6 +34,31 @@ public class BaseSparkKryoTest extends BaseSparkTest {
             return sc;
         }

+        //Ensure SPARK_USER environment variable is set for Spark Kryo tests
+        String u = System.getenv("SPARK_USER");
+        if(u == null || u.isEmpty()){
+            try {
+                Class[] classes = Collections.class.getDeclaredClasses();
+                Map<String, String> env = System.getenv();
+                for (Class cl : classes) {
+                    if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
+                        Field field = cl.getDeclaredField("m");
+                        field.setAccessible(true);
+                        Object obj = field.get(env);
+                        Map<String, String> map = (Map<String, String>) obj;
+                        String user = System.getProperty("user.name");
+                        if(user == null || user.isEmpty())
+                            user = "user";
+                        map.put("SPARK_USER", user);
+                    }
+                }
+            } catch (Exception e){
+                throw new RuntimeException(e);
+            }
+        }
+
+
         SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");

         sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
@@ -74,7 +74,9 @@ public abstract class BaseSparkTest implements Serializable {

     @After
     public void after() {
-        sc.close();
+        if(sc != null) {
+            sc.close();
+        }
         sc = null;
     }

@@ -75,7 +75,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
 @EqualsAndHashCode(callSuper = true,
                 exclude = {"auc", "auprc", "probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve", "axis"})
 @Data
-@ToString(exclude = {"probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve"})
 @JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"})
 @JsonSerialize(using = ROCSerializer.class)
 @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
@@ -824,6 +823,11 @@ public class ROC extends BaseEvaluation<ROC> {
         return sb.toString();
     }

+    @Override
+    public String toString(){
+        return stats();
+    }
+
     public double scoreForMetric(Metric metric){
         switch (metric){
             case AUROC:
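
The ROC change above gives the class an explicit toString() that delegates to stats(), which is what the commit message refers to as "ROC toString for empty ROC". A minimal usage sketch follows; it assumes the class lives at org.deeplearning4j.eval.ROC (the package and file paths are not shown in this diff) and the class name EmptyRocToStringSketch is hypothetical:

    import org.deeplearning4j.eval.ROC;

    public class EmptyRocToStringSketch {
        public static void main(String[] args) {
            // An ROC with no eval(...) calls yet - the "empty ROC" case from the commit message
            ROC roc = new ROC();
            // With the explicit override, toString() returns stats(), so this prints the
            // stats text instead of a Lombok-generated dump of the internal fields
            System.out.println(roc);
        }
    }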