parent
6960418295
commit
21e7f1c8b8
|
@ -87,14 +87,14 @@ public class BalancedPathFilter extends RandomPathFilter {
|
||||||
|
|
||||||
protected boolean acceptLabel(String name) {
|
protected boolean acceptLabel(String name) {
|
||||||
if (labels == null || labels.length == 0) {
|
if (labels == null || labels.length == 0) {
|
||||||
return true;
|
return false;
|
||||||
}
|
}
|
||||||
for (String label : labels) {
|
for (String label : labels) {
|
||||||
if (name.equals(label)) {
|
if (name.equals(label)) {
|
||||||
return true;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -107,7 +107,7 @@ public class BalancedPathFilter extends RandomPathFilter {
|
||||||
URI path = paths[i];
|
URI path = paths[i];
|
||||||
Writable label = labelGenerator.getLabelForPath(path);
|
Writable label = labelGenerator.getLabelForPath(path);
|
||||||
if (!acceptLabel(label.toString())) {
|
if (!acceptLabel(label.toString())) {
|
||||||
continue;
|
continue; //we skip label in case it is null, empty or already in the collection
|
||||||
}
|
}
|
||||||
List<URI> pathList = labelPaths.get(label);
|
List<URI> pathList = labelPaths.get(label);
|
||||||
if (pathList == null) {
|
if (pathList == null) {
|
||||||
|
|
|
@ -27,6 +27,7 @@ import org.nd4j.common.base.Preconditions;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.net.MalformedURLException;
|
import java.net.MalformedURLException;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
|
||||||
public class DL4JResources {
|
public class DL4JResources {
|
||||||
|
|
||||||
|
@ -128,6 +129,10 @@ public class DL4JResources {
|
||||||
baseDirectory = f;
|
baseDirectory = f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static void setBaseDirectory(@NonNull Path p) {
|
||||||
|
setBaseDirectory(p.toFile());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return The base storage directory for DL4J resources
|
* @return The base storage directory for DL4J resources
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -24,10 +24,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.base.MnistFetcher;
|
import org.deeplearning4j.datasets.base.MnistFetcher;
|
||||||
import org.deeplearning4j.common.resources.DL4JResources;
|
import org.deeplearning4j.common.resources.DL4JResources;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.junit.jupiter.api.AfterAll;
|
import org.junit.jupiter.api.*;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.api.TestInstance;
|
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||||
|
@ -37,6 +34,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
@ -45,24 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
@org.junit.jupiter.api.Timeout(300)
|
@org.junit.jupiter.api.Timeout(300)
|
||||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
|
||||||
public class MnistFetcherTest extends BaseDL4JTest {
|
public class MnistFetcherTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@TempDir
|
@TempDir
|
||||||
public File testDir;
|
public Path testDir;
|
||||||
|
|
||||||
@BeforeAll
|
|
||||||
public void setup() throws Exception {
|
|
||||||
DL4JResources.setBaseDirectory(testDir);
|
|
||||||
}
|
|
||||||
|
|
||||||
@AfterAll
|
|
||||||
public void after() {
|
|
||||||
DL4JResources.resetBaseDirectoryLocation();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMnist() throws Exception {
|
public void testMnist() throws Exception {
|
||||||
|
DL4JResources.setBaseDirectory(testDir);
|
||||||
DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
|
DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
|
||||||
int count = 0;
|
int count = 0;
|
||||||
while(iter.hasNext()){
|
while(iter.hasNext()){
|
||||||
|
|
Loading…
Reference in New Issue