More test fixes

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2022-10-11 09:29:26 +02:00
parent 6960418295
commit 21e7f1c8b8
3 changed files with 13 additions and 20 deletions

View File

@ -87,15 +87,15 @@ public class BalancedPathFilter extends RandomPathFilter {
protected boolean acceptLabel(String name) { protected boolean acceptLabel(String name) {
if (labels == null || labels.length == 0) { if (labels == null || labels.length == 0) {
return true; return false;
} }
for (String label : labels) { for (String label : labels) {
if (name.equals(label)) { if (name.equals(label)) {
return true;
}
}
return false; return false;
} }
}
return true;
}
@Override @Override
public URI[] filter(URI[] paths) { public URI[] filter(URI[] paths) {
@ -107,7 +107,7 @@ public class BalancedPathFilter extends RandomPathFilter {
URI path = paths[i]; URI path = paths[i];
Writable label = labelGenerator.getLabelForPath(path); Writable label = labelGenerator.getLabelForPath(path);
if (!acceptLabel(label.toString())) { if (!acceptLabel(label.toString())) {
continue; continue; //we skip label in case it is null, empty or already in the collection
} }
List<URI> pathList = labelPaths.get(label); List<URI> pathList = labelPaths.get(label);
if (pathList == null) { if (pathList == null) {

View File

@ -27,6 +27,7 @@ import org.nd4j.common.base.Preconditions;
import java.io.File; import java.io.File;
import java.net.MalformedURLException; import java.net.MalformedURLException;
import java.net.URL; import java.net.URL;
import java.nio.file.Path;
public class DL4JResources { public class DL4JResources {
@ -128,6 +129,10 @@ public class DL4JResources {
baseDirectory = f; baseDirectory = f;
} }
public static void setBaseDirectory(@NonNull Path p) {
setBaseDirectory(p.toFile());
}
/** /**
* @return The base storage directory for DL4J resources * @return The base storage directory for DL4J resources
*/ */

View File

@ -24,10 +24,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.base.MnistFetcher; import org.deeplearning4j.datasets.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.*;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
@ -37,6 +34,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
import java.io.File; import java.io.File;
import java.nio.file.Path;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
@ -45,24 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@org.junit.jupiter.api.Timeout(300) @org.junit.jupiter.api.Timeout(300)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class MnistFetcherTest extends BaseDL4JTest { public class MnistFetcherTest extends BaseDL4JTest {
@TempDir @TempDir
public File testDir; public Path testDir;
@BeforeAll
public void setup() throws Exception {
DL4JResources.setBaseDirectory(testDir);
}
@AfterAll
public void after() {
DL4JResources.resetBaseDirectoryLocation();
}
@Test @Test
public void testMnist() throws Exception { public void testMnist() throws Exception {
DL4JResources.setBaseDirectory(testDir);
DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1); DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
int count = 0; int count = 0;
while(iter.hasNext()){ while(iter.hasNext()){