More test fixes

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2022-10-11 09:29:26 +02:00
parent 6960418295
commit 21e7f1c8b8
3 changed files with 13 additions and 20 deletions

View File

@ -87,14 +87,14 @@ public class BalancedPathFilter extends RandomPathFilter {
protected boolean acceptLabel(String name) {
if (labels == null || labels.length == 0) {
return true;
return false;
}
for (String label : labels) {
if (name.equals(label)) {
return true;
return false;
}
}
return false;
return true;
}
@Override
@ -107,7 +107,7 @@ public class BalancedPathFilter extends RandomPathFilter {
URI path = paths[i];
Writable label = labelGenerator.getLabelForPath(path);
if (!acceptLabel(label.toString())) {
continue;
continue; //we skip label in case it is null, empty or already in the collection
}
List<URI> pathList = labelPaths.get(label);
if (pathList == null) {

View File

@ -27,6 +27,7 @@ import org.nd4j.common.base.Preconditions;
import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.file.Path;
public class DL4JResources {
@ -128,6 +129,10 @@ public class DL4JResources {
baseDirectory = f;
}
public static void setBaseDirectory(@NonNull Path p) {
setBaseDirectory(p.toFile());
}
/**
* @return The base storage directory for DL4J resources
*/

View File

@ -24,10 +24,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
@ -37,6 +34,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import java.io.File;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Set;
@ -45,24 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@org.junit.jupiter.api.Timeout(300)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class MnistFetcherTest extends BaseDL4JTest {
@TempDir
public File testDir;
@BeforeAll
public void setup() throws Exception {
DL4JResources.setBaseDirectory(testDir);
}
@AfterAll
public void after() {
DL4JResources.resetBaseDirectoryLocation();
}
public Path testDir;
@Test
public void testMnist() throws Exception {
DL4JResources.setBaseDirectory(testDir);
DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
int count = 0;
while(iter.hasNext()){