From 52d279519393b3b2cdc92075e1cd0faa900e7dd3 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 17:01:47 +1000 Subject: [PATCH] Another round of small fixes (#241) * Small base spark test fix; ROC toString for empty ROC Signed-off-by: Alex Black * More fixes Signed-off-by: Alex Black --- .../spark/dl4j-spark-nlp/pom.xml | 2 +- .../spark/text/BaseSparkTest.java | 31 ++++++++++++++++++- .../spark/BaseSparkKryoTest.java | 29 +++++++++++++++++ .../deeplearning4j/spark/BaseSparkTest.java | 4 ++- .../nd4j/evaluation/classification/ROC.java | 6 +++- 5 files changed, 68 insertions(+), 4 deletions(-) diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml index 16c4ac298..a4746be70 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml @@ -63,7 +63,7 @@ com.fasterxml.jackson.module jackson-module-scala_2.11 - ${jackson.version} + 2.6.7 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java index 738daa647..152ef4db5 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java @@ -23,6 +23,9 @@ import org.junit.After; import org.junit.Before; import java.io.Serializable; +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.Map; /** * Created by agibsonccc on 1/23/15. @@ -37,7 +40,9 @@ public abstract class BaseSparkTest implements Serializable { @After public void after() { - sc.close(); + if(sc != null) { + sc.close(); + } sc = null; } @@ -48,6 +53,30 @@ public abstract class BaseSparkTest implements Serializable { public JavaSparkContext getContext() { if (sc != null) return sc; + + //Ensure SPARK_USER environment variable is set for Spark tests + String u = System.getenv("SPARK_USER"); + Map env = System.getenv(); + if(u == null || u.isEmpty()) { + try { + Class[] classes = Collections.class.getDeclaredClasses(); + for (Class cl : classes) { + if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) { + Field field = cl.getDeclaredField("m"); + field.setAccessible(true); + Object obj = field.get(env); + Map map = (Map) obj; + String user = System.getProperty("user.name"); + if (user == null || user.isEmpty()) + user = "user"; + map.put("SPARK_USER", user); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + // set to test mode SparkConf sparkConf = new SparkConf().setMaster("local[4]").set("spark.driver.host", "localhost") .setAppName("sparktest") diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java index bb3a7180e..1c794ebf6 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java @@ -19,6 +19,10 @@ package org.deeplearning4j.spark; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.Map; + /** * Created by Alex on 04/07/2017. */ @@ -30,6 +34,31 @@ public class BaseSparkKryoTest extends BaseSparkTest { return sc; } + //Ensure SPARK_USER environment variable is set for Spark Kryo tests + String u = System.getenv("SPARK_USER"); + if(u == null || u.isEmpty()){ + try { + Class[] classes = Collections.class.getDeclaredClasses(); + Map env = System.getenv(); + for (Class cl : classes) { + if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) { + Field field = cl.getDeclaredField("m"); + field.setAccessible(true); + Object obj = field.get(env); + Map map = (Map) obj; + String user = System.getProperty("user.name"); + if(user == null || user.isEmpty()) + user = "user"; + map.put("SPARK_USER", user); + } + } + } catch (Exception e){ + throw new RuntimeException(e); + } + } + + + SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest"); sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index d2a6e08e1..781e3dad2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -74,7 +74,9 @@ public abstract class BaseSparkTest implements Serializable { @After public void after() { - sc.close(); + if(sc != null) { + sc.close(); + } sc = null; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java index 63a5a012a..c9f5cabdf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java @@ -75,7 +75,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval; @EqualsAndHashCode(callSuper = true, exclude = {"auc", "auprc", "probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve", "axis"}) @Data -@ToString(exclude = {"probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve"}) @JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"}) @JsonSerialize(using = ROCSerializer.class) @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) @@ -824,6 +823,11 @@ public class ROC extends BaseEvaluation { return sb.toString(); } + @Override + public String toString(){ + return stats(); + } + public double scoreForMetric(Metric metric){ switch (metric){ case AUROC: