From 21e7f1c8b829695bd1073a1258f3f0c61fd73f2e Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 11 Oct 2022 09:29:26 +0200 Subject: [PATCH] More test fixes Signed-off-by: brian --- .../api/io/filters/BalancedPathFilter.java | 8 ++++---- .../common/resources/DL4JResources.java | 5 +++++ .../datasets/MnistFetcherTest.java | 20 ++++--------------- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java index 3a58cc3a7..348b4e0fd 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java @@ -87,14 +87,14 @@ public class BalancedPathFilter extends RandomPathFilter { protected boolean acceptLabel(String name) { if (labels == null || labels.length == 0) { - return true; + return false; } for (String label : labels) { if (name.equals(label)) { - return true; + return false; } } - return false; + return true; } @Override @@ -107,7 +107,7 @@ public class BalancedPathFilter extends RandomPathFilter { URI path = paths[i]; Writable label = labelGenerator.getLabelForPath(path); if (!acceptLabel(label.toString())) { - continue; + continue; //we skip label in case it is null, empty or already in the collection } List pathList = labelPaths.get(label); if (pathList == null) { diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/resources/DL4JResources.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/resources/DL4JResources.java index e8f6ecda0..c73955595 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/resources/DL4JResources.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/resources/DL4JResources.java @@ -27,6 +27,7 @@ import org.nd4j.common.base.Preconditions; import java.io.File; import java.net.MalformedURLException; import java.net.URL; +import java.nio.file.Path; public class DL4JResources { @@ -128,6 +129,10 @@ public class DL4JResources { baseDirectory = f; } + public static void setBaseDirectory(@NonNull Path p) { + setBaseDirectory(p.toFile()); + } + /** * @return The base storage directory for DL4J resources */ diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java index c63e4ac7d..1bc48d33e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java @@ -24,10 +24,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.base.MnistFetcher; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.*; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; @@ -37,6 +34,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Conditions; import java.io.File; +import java.nio.file.Path; import java.util.HashSet; import java.util.Set; @@ -45,24 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @org.junit.jupiter.api.Timeout(300) -@TestInstance(TestInstance.Lifecycle.PER_CLASS) public class MnistFetcherTest extends BaseDL4JTest { @TempDir - public File testDir; - - @BeforeAll - public void setup() throws Exception { - DL4JResources.setBaseDirectory(testDir); - } - - @AfterAll - public void after() { - DL4JResources.resetBaseDirectoryLocation(); - } + public Path testDir; @Test public void testMnist() throws Exception { + DL4JResources.setBaseDirectory(testDir); DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1); int count = 0; while(iter.hasNext()){