parent
6960418295
commit
21e7f1c8b8
|
@ -87,15 +87,15 @@ public class BalancedPathFilter extends RandomPathFilter {
|
|||
|
||||
protected boolean acceptLabel(String name) {
|
||||
if (labels == null || labels.length == 0) {
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
for (String label : labels) {
|
||||
if (name.equals(label)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public URI[] filter(URI[] paths) {
|
||||
|
@ -107,7 +107,7 @@ public class BalancedPathFilter extends RandomPathFilter {
|
|||
URI path = paths[i];
|
||||
Writable label = labelGenerator.getLabelForPath(path);
|
||||
if (!acceptLabel(label.toString())) {
|
||||
continue;
|
||||
continue; //we skip label in case it is null, empty or already in the collection
|
||||
}
|
||||
List<URI> pathList = labelPaths.get(label);
|
||||
if (pathList == null) {
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.nd4j.common.base.Preconditions;
|
|||
import java.io.File;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URL;
|
||||
import java.nio.file.Path;
|
||||
|
||||
public class DL4JResources {
|
||||
|
||||
|
@ -128,6 +129,10 @@ public class DL4JResources {
|
|||
baseDirectory = f;
|
||||
}
|
||||
|
||||
public static void setBaseDirectory(@NonNull Path p) {
|
||||
setBaseDirectory(p.toFile());
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The base storage directory for DL4J resources
|
||||
*/
|
||||
|
|
|
@ -24,10 +24,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
|||
import org.deeplearning4j.datasets.base.MnistFetcher;
|
||||
import org.deeplearning4j.common.resources.DL4JResources;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
import org.junit.jupiter.api.*;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
|
@ -37,6 +34,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
|
@ -45,24 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@org.junit.jupiter.api.Timeout(300)
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
public class MnistFetcherTest extends BaseDL4JTest {
|
||||
|
||||
@TempDir
|
||||
public File testDir;
|
||||
|
||||
@BeforeAll
|
||||
public void setup() throws Exception {
|
||||
DL4JResources.setBaseDirectory(testDir);
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
public void after() {
|
||||
DL4JResources.resetBaseDirectoryLocation();
|
||||
}
|
||||
public Path testDir;
|
||||
|
||||
@Test
|
||||
public void testMnist() throws Exception {
|
||||
DL4JResources.setBaseDirectory(testDir);
|
||||
DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
|
||||
int count = 0;
|
||||
while(iter.hasNext()){
|
||||
|
|
Loading…
Reference in New Issue