Another round of small fixes (#241)
* Small base spark test fix; ROC toString for empty ROC Signed-off-by: Alex Black <blacka101@gmail.com> * More fixes Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
87d873929f
commit
52d2795193
|
@ -63,7 +63,7 @@
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.fasterxml.jackson.module</groupId>
|
<groupId>com.fasterxml.jackson.module</groupId>
|
||||||
<artifactId>jackson-module-scala_2.11</artifactId>
|
<artifactId>jackson-module-scala_2.11</artifactId>
|
||||||
<version>${jackson.version}</version>
|
<version>2.6.7</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,9 @@ import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by agibsonccc on 1/23/15.
|
* Created by agibsonccc on 1/23/15.
|
||||||
|
@ -37,7 +40,9 @@ public abstract class BaseSparkTest implements Serializable {
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void after() {
|
public void after() {
|
||||||
sc.close();
|
if(sc != null) {
|
||||||
|
sc.close();
|
||||||
|
}
|
||||||
sc = null;
|
sc = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,6 +53,30 @@ public abstract class BaseSparkTest implements Serializable {
|
||||||
public JavaSparkContext getContext() {
|
public JavaSparkContext getContext() {
|
||||||
if (sc != null)
|
if (sc != null)
|
||||||
return sc;
|
return sc;
|
||||||
|
|
||||||
|
//Ensure SPARK_USER environment variable is set for Spark tests
|
||||||
|
String u = System.getenv("SPARK_USER");
|
||||||
|
Map<String, String> env = System.getenv();
|
||||||
|
if(u == null || u.isEmpty()) {
|
||||||
|
try {
|
||||||
|
Class[] classes = Collections.class.getDeclaredClasses();
|
||||||
|
for (Class cl : classes) {
|
||||||
|
if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
|
||||||
|
Field field = cl.getDeclaredField("m");
|
||||||
|
field.setAccessible(true);
|
||||||
|
Object obj = field.get(env);
|
||||||
|
Map<String, String> map = (Map<String, String>) obj;
|
||||||
|
String user = System.getProperty("user.name");
|
||||||
|
if (user == null || user.isEmpty())
|
||||||
|
user = "user";
|
||||||
|
map.put("SPARK_USER", user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// set to test mode
|
// set to test mode
|
||||||
SparkConf sparkConf = new SparkConf().setMaster("local[4]").set("spark.driver.host", "localhost")
|
SparkConf sparkConf = new SparkConf().setMaster("local[4]").set("spark.driver.host", "localhost")
|
||||||
.setAppName("sparktest")
|
.setAppName("sparktest")
|
||||||
|
|
|
@ -19,6 +19,10 @@ package org.deeplearning4j.spark;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 04/07/2017.
|
* Created by Alex on 04/07/2017.
|
||||||
*/
|
*/
|
||||||
|
@ -30,6 +34,31 @@ public class BaseSparkKryoTest extends BaseSparkTest {
|
||||||
return sc;
|
return sc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Ensure SPARK_USER environment variable is set for Spark Kryo tests
|
||||||
|
String u = System.getenv("SPARK_USER");
|
||||||
|
if(u == null || u.isEmpty()){
|
||||||
|
try {
|
||||||
|
Class[] classes = Collections.class.getDeclaredClasses();
|
||||||
|
Map<String, String> env = System.getenv();
|
||||||
|
for (Class cl : classes) {
|
||||||
|
if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
|
||||||
|
Field field = cl.getDeclaredField("m");
|
||||||
|
field.setAccessible(true);
|
||||||
|
Object obj = field.get(env);
|
||||||
|
Map<String, String> map = (Map<String, String>) obj;
|
||||||
|
String user = System.getProperty("user.name");
|
||||||
|
if(user == null || user.isEmpty())
|
||||||
|
user = "user";
|
||||||
|
map.put("SPARK_USER", user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");
|
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");
|
||||||
|
|
||||||
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
|
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
|
||||||
|
|
|
@ -74,7 +74,9 @@ public abstract class BaseSparkTest implements Serializable {
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void after() {
|
public void after() {
|
||||||
sc.close();
|
if(sc != null) {
|
||||||
|
sc.close();
|
||||||
|
}
|
||||||
sc = null;
|
sc = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,7 +75,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
||||||
@EqualsAndHashCode(callSuper = true,
|
@EqualsAndHashCode(callSuper = true,
|
||||||
exclude = {"auc", "auprc", "probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve", "axis"})
|
exclude = {"auc", "auprc", "probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve", "axis"})
|
||||||
@Data
|
@Data
|
||||||
@ToString(exclude = {"probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve"})
|
|
||||||
@JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"})
|
@JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"})
|
||||||
@JsonSerialize(using = ROCSerializer.class)
|
@JsonSerialize(using = ROCSerializer.class)
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
|
||||||
|
@ -824,6 +823,11 @@ public class ROC extends BaseEvaluation<ROC> {
|
||||||
return sb.toString();
|
return sb.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return stats();
|
||||||
|
}
|
||||||
|
|
||||||
public double scoreForMetric(Metric metric){
|
public double scoreForMetric(Metric metric){
|
||||||
switch (metric){
|
switch (metric){
|
||||||
case AUROC:
|
case AUROC:
|
||||||
|
|
Loading…
Reference in New Issue