All tests compile
parent
b1229432d6
commit
82bdcc21d2
|
@ -25,7 +25,7 @@ import org.datavec.api.records.reader.impl.csv.CSVLineSequenceRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.datavec.api.records.reader.impl.csv.CSVMultiSequenceRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||||
import org.datavec.api.split.InputSplit;
|
import org.datavec.api.split.InputSplit;
|
||||||
import org.datavec.api.split.NumberedFileInputSplit;
|
import org.datavec.api.split.NumberedFileInputSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
|
import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
|
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.split.CollectionInputSplit;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
|
@ -30,7 +30,7 @@ import org.datavec.api.split.NumberedFileInputSplit;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.split.InputSplit;
|
import org.datavec.api.split.InputSplit;
|
||||||
import org.datavec.api.split.InputStreamInputSplit;
|
import org.datavec.api.split.InputStreamInputSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.datavec.api.split.InputSplit;
|
||||||
import org.datavec.api.split.NumberedFileInputSplit;
|
import org.datavec.api.split.NumberedFileInputSplit;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
|
@ -26,14 +26,14 @@ import org.datavec.api.records.reader.SequenceRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
|
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestCollectionRecordReaders extends BaseND4JTest {
|
public class TestCollectionRecordReaders extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -23,11 +23,11 @@ package org.datavec.api.records.reader.impl;
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
import org.datavec.api.records.reader.RecordReader;
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestConcatenatingRecordReader extends BaseND4JTest {
|
public class TestConcatenatingRecordReader extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ import org.datavec.api.transform.TransformProcess;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||||
|
@ -47,7 +47,7 @@ import java.io.*;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestSerialization extends BaseND4JTest {
|
public class TestSerialization extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.LongWritable;
|
import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
|
@ -38,8 +38,8 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TransformProcessRecordReaderTests extends BaseND4JTest {
|
public class TransformProcessRecordReaderTests extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.datavec.api.io.filters.BalancedPathFilter;
|
||||||
import org.datavec.api.io.filters.RandomPathFilter;
|
import org.datavec.api.io.filters.RandomPathFilter;
|
||||||
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
||||||
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
@ -35,7 +35,7 @@ import java.util.ArrayList;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static junit.framework.TestCase.assertTrue;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.datavec.api.split;
|
package org.datavec.api.split;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
public class NumberedFileInputSplitTests extends BaseND4JTest {
|
public class NumberedFileInputSplitTests extends BaseND4JTest {
|
||||||
@Test
|
@Test
|
||||||
|
@ -69,60 +68,81 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithLeadingSpaces() {
|
public void testNumberedFileInputSplitWithLeadingSpaces() {
|
||||||
String baseString = "/path/to/files/prefix-%5d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix-%5d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() {
|
public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%5d.suffix";
|
assertThrows(IllegalArgumentException.class, () -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%5d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithLeadingPlusInPadding() {
|
public void testNumberedFileInputSplitWithLeadingPlusInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%+5d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%+5d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithLeadingMinusInPadding() {
|
public void testNumberedFileInputSplitWithLeadingMinusInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%-5d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%-5d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithTwoDigitsInPadding() {
|
public void testNumberedFileInputSplitWithTwoDigitsInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%011d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%011d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithInnerZerosInPadding() {
|
public void testNumberedFileInputSplitWithInnerZerosInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%101d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%101d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() {
|
public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() {
|
||||||
String baseString = "/path/to/files/prefix%0505d.suffix";
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int minIdx = 0;
|
String baseString = "/path/to/files/prefix%0505d.suffix";
|
||||||
int maxIdx = 10;
|
int minIdx = 0;
|
||||||
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
int maxIdx = 10;
|
||||||
|
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -135,7 +155,7 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
|
||||||
String path = locs[j++].getPath();
|
String path = locs[j++].getPath();
|
||||||
String exp = String.format(baseString, i);
|
String exp = String.format(baseString, i);
|
||||||
String msg = exp + " vs " + path;
|
String msg = exp + " vs " + path;
|
||||||
assertTrue(msg, path.endsWith(exp)); //Note: on Windows, Java can prepend drive to path - "/C:/"
|
assertTrue(path.endsWith(exp),msg); //Note: on Windows, Java can prepend drive to path - "/C:/"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,9 +25,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.function.Function;
|
import org.nd4j.common.function.Function;
|
||||||
|
|
||||||
|
@ -37,22 +38,22 @@ import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertNotEquals;
|
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||||
|
|
||||||
public class TestStreamInputSplit extends BaseND4JTest {
|
public class TestStreamInputSplit extends BaseND4JTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCsvSimple() throws Exception {
|
public void testCsvSimple(@TempDir Path testDir) throws Exception {
|
||||||
File dir = testDir.newFolder();
|
File dir = testDir.toFile();
|
||||||
File f1 = new File(dir, "file1.txt");
|
File f1 = new File(dir, "file1.txt");
|
||||||
File f2 = new File(dir, "file2.txt");
|
File f2 = new File(dir, "file2.txt");
|
||||||
|
|
||||||
|
@ -93,9 +94,9 @@ public class TestStreamInputSplit extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCsvSequenceSimple() throws Exception {
|
public void testCsvSequenceSimple(@TempDir Path testDir) throws Exception {
|
||||||
|
|
||||||
File dir = testDir.newFolder();
|
File dir = testDir.toFile();
|
||||||
File f1 = new File(dir, "file1.txt");
|
File f1 = new File(dir, "file1.txt");
|
||||||
File f2 = new File(dir, "file2.txt");
|
File f2 = new File(dir, "file2.txt");
|
||||||
|
|
||||||
|
@ -137,8 +138,8 @@ public class TestStreamInputSplit extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testShuffle() throws Exception {
|
public void testShuffle(@TempDir Path testDir) throws Exception {
|
||||||
File dir = testDir.newFolder();
|
File dir = testDir.toFile();
|
||||||
File f1 = new File(dir, "file1.txt");
|
File f1 = new File(dir, "file1.txt");
|
||||||
File f2 = new File(dir, "file2.txt");
|
File f2 = new File(dir, "file2.txt");
|
||||||
File f3 = new File(dir, "file3.txt");
|
File f3 = new File(dir, "file3.txt");
|
||||||
|
|
|
@ -27,14 +27,14 @@ import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||||
import org.datavec.api.split.partition.PartitionMetaData;
|
import org.datavec.api.split.partition.PartitionMetaData;
|
||||||
import org.datavec.api.split.partition.Partitioner;
|
import org.datavec.api.split.partition.Partitioner;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static junit.framework.TestCase.assertTrue;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
|
||||||
public class PartitionerTests extends BaseND4JTest {
|
public class PartitionerTests extends BaseND4JTest {
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -29,12 +29,12 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestTransformProcess extends BaseND4JTest {
|
public class TestTransformProcess extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -27,13 +27,13 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.transform.transform.TestTransforms;
|
import org.datavec.api.transform.transform.TestTransforms;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestConditions extends BaseND4JTest {
|
public class TestConditions extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -36,8 +36,8 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestFilters extends BaseND4JTest {
|
public class TestFilters extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -26,19 +26,22 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.NullWritable;
|
import org.datavec.api.writable.NullWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
|
||||||
public class TestJoin extends BaseND4JTest {
|
public class TestJoin extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testJoin() {
|
public void testJoin(@TempDir Path testDir) {
|
||||||
|
|
||||||
Schema firstSchema =
|
Schema firstSchema =
|
||||||
new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build();
|
new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build();
|
||||||
|
@ -46,20 +49,20 @@ public class TestJoin extends BaseND4JTest {
|
||||||
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build();
|
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build();
|
||||||
|
|
||||||
List<List<Writable>> first = new ArrayList<>();
|
List<List<Writable>> first = new ArrayList<>();
|
||||||
first.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1)));
|
first.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1)));
|
||||||
first.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11)));
|
first.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11)));
|
||||||
|
|
||||||
List<List<Writable>> second = new ArrayList<>();
|
List<List<Writable>> second = new ArrayList<>();
|
||||||
second.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(100)));
|
second.add(Arrays.asList(new Text("key0"), new IntWritable(100)));
|
||||||
second.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(110)));
|
second.add(Arrays.asList(new Text("key1"), new IntWritable(110)));
|
||||||
|
|
||||||
Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn")
|
Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn")
|
||||||
.setSchemas(firstSchema, secondSchema).build();
|
.setSchemas(firstSchema, secondSchema).build();
|
||||||
|
|
||||||
List<List<Writable>> expected = new ArrayList<>();
|
List<List<Writable>> expected = new ArrayList<>();
|
||||||
expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1),
|
expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1),
|
||||||
new IntWritable(100)));
|
new IntWritable(100)));
|
||||||
expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11),
|
expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11),
|
||||||
new IntWritable(110)));
|
new IntWritable(110)));
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,27 +97,31 @@ public class TestJoin extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testJoinValidation() {
|
public void testJoinValidation() {
|
||||||
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
||||||
|
.build();
|
||||||
|
|
||||||
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
|
||||||
.build();
|
|
||||||
|
|
||||||
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
|
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
|
||||||
|
.setSchemas(firstSchema, secondSchema).build();
|
||||||
|
});
|
||||||
|
|
||||||
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
|
|
||||||
.setSchemas(firstSchema, secondSchema).build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testJoinValidation2() {
|
public void testJoinValidation2() {
|
||||||
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
||||||
|
.build();
|
||||||
|
|
||||||
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
|
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
|
||||||
.build();
|
|
||||||
|
|
||||||
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
|
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
|
||||||
|
.build();
|
||||||
|
});
|
||||||
|
|
||||||
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
|
|
||||||
.build();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,17 +19,18 @@
|
||||||
*/
|
*/
|
||||||
package org.datavec.api.transform.ops;
|
package org.datavec.api.transform.ops;
|
||||||
|
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
import org.junit.rules.ExpectedException;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@DisplayName("Aggregator Impls Test")
|
@DisplayName("Aggregator Impls Test")
|
||||||
class AggregatorImplsTest extends BaseND4JTest {
|
class AggregatorImplsTest extends BaseND4JTest {
|
||||||
|
|
||||||
|
@ -265,23 +266,25 @@ class AggregatorImplsTest extends BaseND4JTest {
|
||||||
assertEquals(9, cu.get().toInt());
|
assertEquals(9, cu.get().toInt());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Rule
|
|
||||||
public final ExpectedException exception = ExpectedException.none();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Incompatible Aggregator Test")
|
@DisplayName("Incompatible Aggregator Test")
|
||||||
void incompatibleAggregatorTest() {
|
void incompatibleAggregatorTest() {
|
||||||
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
|
assertThrows(UnsupportedOperationException.class,() -> {
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
|
||||||
sm.accept(intList.get(i));
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
}
|
sm.accept(intList.get(i));
|
||||||
assertEquals(45, sm.get().toInt());
|
}
|
||||||
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
|
assertEquals(45, sm.get().toInt());
|
||||||
for (int i = 0; i < intList.size(); i++) {
|
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
|
||||||
reverse.accept(intList.get(intList.size() - i - 1));
|
for (int i = 0; i < intList.size(); i++) {
|
||||||
}
|
reverse.accept(intList.get(intList.size() - i - 1));
|
||||||
exception.expect(UnsupportedOperationException.class);
|
}
|
||||||
sm.combine(reverse);
|
|
||||||
assertEquals(45, sm.get().toInt());
|
sm.combine(reverse);
|
||||||
|
assertEquals(45, sm.get().toInt());
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,13 +32,13 @@ import org.datavec.api.transform.ops.AggregableMultiOp;
|
||||||
import org.datavec.api.transform.ops.IAggregableReduceOp;
|
import org.datavec.api.transform.ops.IAggregableReduceOp;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
public class TestMultiOpReduce extends BaseND4JTest {
|
public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
|
|
||||||
|
@ -46,10 +46,10 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
public void testMultiOpReducerDouble() {
|
public void testMultiOpReducerDouble() {
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
List<List<Writable>> inputs = new ArrayList<>();
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(0)));
|
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(0)));
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(1)));
|
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(1)));
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
|
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
|
||||||
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
|
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
|
||||||
|
|
||||||
Map<ReduceOp, Double> exp = new LinkedHashMap<>();
|
Map<ReduceOp, Double> exp = new LinkedHashMap<>();
|
||||||
exp.put(ReduceOp.Min, 0.0);
|
exp.put(ReduceOp.Min, 0.0);
|
||||||
|
@ -82,7 +82,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
assertEquals(out.get(0), new Text("someKey"));
|
assertEquals(out.get(0), new Text("someKey"));
|
||||||
|
|
||||||
String msg = op.toString();
|
String msg = op.toString();
|
||||||
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
|
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
assertEquals(out.get(0), new Text("someKey"));
|
assertEquals(out.get(0), new Text("someKey"));
|
||||||
|
|
||||||
String msg = op.toString();
|
String msg = op.toString();
|
||||||
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
|
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
|
||||||
assertEquals(out.get(0), new Text("someKey"));
|
assertEquals(out.get(0), new Text("someKey"));
|
||||||
|
|
||||||
String msg = op.toString();
|
String msg = op.toString();
|
||||||
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
|
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean,
|
for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean,
|
||||||
|
|
|
@ -24,13 +24,13 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
|
||||||
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
|
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestReductions extends BaseND4JTest {
|
public class TestReductions extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -22,10 +22,10 @@ package org.datavec.api.transform.schema;
|
||||||
|
|
||||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestJsonYaml extends BaseND4JTest {
|
public class TestJsonYaml extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -21,10 +21,10 @@
|
||||||
package org.datavec.api.transform.schema;
|
package org.datavec.api.transform.schema;
|
||||||
|
|
||||||
import org.datavec.api.transform.ColumnType;
|
import org.datavec.api.transform.ColumnType;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestSchemaMethods extends BaseND4JTest {
|
public class TestSchemaMethods extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.NullWritable;
|
import org.datavec.api.writable.NullWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -41,7 +41,7 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
|
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -35,7 +35,7 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestSequenceSplit extends BaseND4JTest {
|
public class TestSequenceSplit extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.LongWritable;
|
import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -37,7 +37,7 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestWindowFunctions extends BaseND4JTest {
|
public class TestWindowFunctions extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -26,10 +26,10 @@ import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.transform.serde.testClasses.CustomCondition;
|
import org.datavec.api.transform.serde.testClasses.CustomCondition;
|
||||||
import org.datavec.api.transform.serde.testClasses.CustomFilter;
|
import org.datavec.api.transform.serde.testClasses.CustomFilter;
|
||||||
import org.datavec.api.transform.serde.testClasses.CustomTransform;
|
import org.datavec.api.transform.serde.testClasses.CustomTransform;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestCustomTransformJsonYaml extends BaseND4JTest {
|
public class TestCustomTransformJsonYaml extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -64,13 +64,13 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
|
||||||
import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
||||||
import org.joda.time.DateTimeFieldType;
|
import org.joda.time.DateTimeFieldType;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestYamlJsonSerde extends BaseND4JTest {
|
public class TestYamlJsonSerde extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -24,12 +24,12 @@ import org.datavec.api.transform.StringReduceOp;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestReduce extends BaseND4JTest {
|
public class TestReduce extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.comparator.LongWritableComparator;
|
import org.datavec.api.writable.comparator.LongWritableComparator;
|
||||||
import org.joda.time.DateTimeFieldType;
|
import org.joda.time.DateTimeFieldType;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class RegressionTestJson extends BaseND4JTest {
|
public class RegressionTestJson extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -50,13 +50,13 @@ import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.comparator.LongWritableComparator;
|
import org.datavec.api.writable.comparator.LongWritableComparator;
|
||||||
import org.joda.time.DateTimeFieldType;
|
import org.joda.time.DateTimeFieldType;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestJsonYaml extends BaseND4JTest {
|
public class TestJsonYaml extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,7 @@ import org.datavec.api.writable.*;
|
||||||
import org.joda.time.DateTimeFieldType;
|
import org.joda.time.DateTimeFieldType;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -72,7 +72,7 @@ import java.util.*;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertEquals;
|
import static junit.framework.TestCase.assertEquals;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestTransforms extends BaseND4JTest {
|
public class TestTransforms extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.NDArrayWritable;
|
import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -39,7 +39,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestNDArrayWritableTransforms extends BaseND4JTest {
|
public class TestNDArrayWritableTransforms extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -30,13 +30,13 @@ import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.transform.serde.JsonSerializer;
|
import org.datavec.api.transform.serde.JsonSerializer;
|
||||||
import org.datavec.api.transform.serde.YamlSerializer;
|
import org.datavec.api.transform.serde.YamlSerializer;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestYamlJsonSerde extends BaseND4JTest {
|
public class TestYamlJsonSerde extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -35,26 +35,26 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestUI extends BaseND4JTest {
|
public class TestUI extends BaseND4JTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUI() throws Exception {
|
public void testUI(@TempDir Path testDir) throws Exception {
|
||||||
Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn")
|
Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn")
|
||||||
.addColumnInteger("IntColumn2").addColumnInteger("IntColumn3")
|
.addColumnInteger("IntColumn2").addColumnInteger("IntColumn3")
|
||||||
.addColumnTime("TimeColumn", DateTimeZone.UTC).build();
|
.addColumnTime("TimeColumn", DateTimeZone.UTC).build();
|
||||||
|
@ -92,7 +92,7 @@ public class TestUI extends BaseND4JTest {
|
||||||
|
|
||||||
DataAnalysis da = new DataAnalysis(schema, list);
|
DataAnalysis da = new DataAnalysis(schema, list);
|
||||||
|
|
||||||
File fDir = testDir.newFolder();
|
File fDir = testDir.toFile();
|
||||||
String tempDir = fDir.getAbsolutePath();
|
String tempDir = fDir.getAbsolutePath();
|
||||||
String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html");
|
String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html");
|
||||||
System.out.println(outPath);
|
System.out.println(outPath);
|
||||||
|
@ -143,7 +143,7 @@ public class TestUI extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Disabled
|
||||||
public void testSequencePlot() throws Exception {
|
public void testSequencePlot() throws Exception {
|
||||||
|
|
||||||
Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx")
|
Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx")
|
||||||
|
|
|
@ -21,14 +21,14 @@
|
||||||
package org.datavec.api.writable;
|
package org.datavec.api.writable;
|
||||||
|
|
||||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
|
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
||||||
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
|
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
|
@ -29,16 +29,16 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.arrow.ArrowConverter;
|
import org.datavec.arrow.ArrowConverter;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
|
||||||
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||||
assertEquals(3,fieldVectors.size());
|
assertEquals(3,fieldVectors.size());
|
||||||
for(FieldVector fieldVector : fieldVectors) {
|
for(FieldVector fieldVector : fieldVectors) {
|
||||||
for(int i = 0; i < fieldVector.getValueCount(); i++) {
|
for(int i = 0; i < fieldVector.getValueCount(); i++) {
|
||||||
assertFalse("Index " + i + " was null for field vector " + fieldVector, fieldVector.isNull(i));
|
assertFalse( fieldVector.isNull(i),"Index " + i + " was null for field vector " + fieldVector);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
//not worried about this till after next release
|
//not worried about this till after next release
|
||||||
@Ignore
|
@Disabled
|
||||||
public void testVariableLengthTS() {
|
public void testVariableLengthTS() {
|
||||||
Schema.Builder schema = new Schema.Builder()
|
Schema.Builder schema = new Schema.Builder()
|
||||||
.addColumnString("str")
|
.addColumnString("str")
|
||||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.commons.io.FileUtils;
|
||||||
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
import org.datavec.api.io.labels.ParentPathLabelGenerator;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
|
@ -22,8 +22,8 @@ package org.datavec.image.loader;
|
||||||
|
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
import org.datavec.api.records.reader.RecordReader;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -32,9 +32,9 @@ import java.io.InputStream;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -182,7 +182,7 @@ public class LoaderTests {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Ignore // Use when confirming data is getting stored
|
@Disabled // Use when confirming data is getting stored
|
||||||
@Test
|
@Test
|
||||||
public void testProcessCifar() {
|
public void testProcessCifar() {
|
||||||
int row = 32;
|
int row = 32;
|
||||||
|
@ -208,15 +208,15 @@ public class LoaderTests {
|
||||||
int minibatch = 100;
|
int minibatch = 100;
|
||||||
int nMinibatches = 50000 / minibatch;
|
int nMinibatches = 50000 / minibatch;
|
||||||
|
|
||||||
for( int i=0; i<nMinibatches; i++ ){
|
for( int i=0; i < nMinibatches; i++) {
|
||||||
DataSet ds = loader.next(minibatch);
|
DataSet ds = loader.next(minibatch);
|
||||||
String s = String.valueOf(i);
|
String s = String.valueOf(i);
|
||||||
assertNotNull(s, ds.getFeatures());
|
assertNotNull(ds.getFeatures(),s);
|
||||||
assertNotNull(s, ds.getLabels());
|
assertNotNull(ds.getLabels(),s);
|
||||||
|
|
||||||
assertEquals(s, minibatch, ds.getFeatures().size(0));
|
assertEquals(minibatch, ds.getFeatures().size(0),s);
|
||||||
assertEquals(s, minibatch, ds.getLabels().size(0));
|
assertEquals(minibatch, ds.getLabels().size(0),s);
|
||||||
assertEquals(s, 10, ds.getLabels().size(1));
|
assertEquals(10, ds.getLabels().size(1),s);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
import org.datavec.image.data.Image;
|
import org.datavec.image.data.Image;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.resources.Resources;
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ import java.io.FileInputStream;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
|
||||||
public class TestImageLoader {
|
public class TestImageLoader {
|
||||||
|
|
|
@ -30,9 +30,10 @@ import org.bytedeco.javacv.Java2DFrameConverter;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
import org.datavec.image.data.Image;
|
import org.datavec.image.data.Image;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.resources.Resources;
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -42,16 +43,17 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import org.bytedeco.leptonica.*;
|
import org.bytedeco.leptonica.*;
|
||||||
import org.bytedeco.opencv.opencv_core.*;
|
import org.bytedeco.opencv.opencv_core.*;
|
||||||
import static org.bytedeco.leptonica.global.lept.*;
|
import static org.bytedeco.leptonica.global.lept.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertNotEquals;
|
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -62,8 +64,6 @@ public class TestNativeImageLoader {
|
||||||
static final long seed = 10;
|
static final long seed = 10;
|
||||||
static final Random rng = new Random(seed);
|
static final Random rng = new Random(seed);
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConvertPix() throws Exception {
|
public void testConvertPix() throws Exception {
|
||||||
|
@ -566,8 +566,8 @@ public class TestNativeImageLoader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNativeImageLoaderEmptyStreams() throws Exception {
|
public void testNativeImageLoaderEmptyStreams(@TempDir Path testDir) throws Exception {
|
||||||
File dir = testDir.newFolder();
|
File dir = testDir.toFile();
|
||||||
File f = new File(dir, "myFile.jpg");
|
File f = new File(dir, "myFile.jpg");
|
||||||
f.createNewFile();
|
f.createNewFile();
|
||||||
|
|
||||||
|
@ -578,7 +578,7 @@ public class TestNativeImageLoader {
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg.contains("decode image"),msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
|
@ -586,7 +586,7 @@ public class TestNativeImageLoader {
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg.contains("decode image"),msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
|
@ -594,7 +594,7 @@ public class TestNativeImageLoader {
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue(msg.contains("decode image"),msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
try(InputStream is = new FileInputStream(f)){
|
try(InputStream is = new FileInputStream(f)){
|
||||||
|
@ -603,7 +603,7 @@ public class TestNativeImageLoader {
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (IOException e){
|
} catch (IOException e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("decode image"));
|
assertTrue( msg.contains("decode image"),msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.NDArrayWritable;
|
import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.image.loader.NativeImageLoader;
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.loader.FileBatch;
|
import org.nd4j.common.loader.FileBatch;
|
||||||
|
|
|
@ -36,9 +36,10 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.NDArrayWritable;
|
import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -46,28 +47,30 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestImageRecordReader {
|
public class TestImageRecordReader {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test()
|
||||||
public void testEmptySplit() throws IOException {
|
public void testEmptySplit() throws IOException {
|
||||||
InputSplit data = new CollectionInputSplit(new ArrayList<URI>());
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
new ImageRecordReader().initialize(data, null);
|
InputSplit data = new CollectionInputSplit(new ArrayList<>());
|
||||||
|
new ImageRecordReader().initialize(data, null);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMetaData() throws IOException {
|
public void testMetaData(@TempDir Path testDir) throws IOException {
|
||||||
|
|
||||||
File parentDir = testDir.newFolder();
|
File parentDir = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir);
|
||||||
// System.out.println(f.getAbsolutePath());
|
// System.out.println(f.getAbsolutePath());
|
||||||
// System.out.println(f.getParentFile().getParentFile().getAbsolutePath());
|
// System.out.println(f.getParentFile().getParentFile().getAbsolutePath());
|
||||||
|
@ -104,11 +107,11 @@ public class TestImageRecordReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testImageRecordReaderLabelsOrder() throws Exception {
|
public void testImageRecordReaderLabelsOrder(@TempDir Path testDir) throws Exception {
|
||||||
//Labels order should be consistent, regardless of file iteration order
|
//Labels order should be consistent, regardless of file iteration order
|
||||||
|
|
||||||
//Idea: labels order should be consistent regardless of input file order
|
//Idea: labels order should be consistent regardless of input file order
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
||||||
File f0 = new File(f, "/class0/0.jpg");
|
File f0 = new File(f, "/class0/0.jpg");
|
||||||
File f1 = new File(f, "/class1/A.jpg");
|
File f1 = new File(f, "/class1/A.jpg");
|
||||||
|
@ -135,11 +138,11 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testImageRecordReaderRandomization() throws Exception {
|
public void testImageRecordReaderRandomization(@TempDir Path testDir) throws Exception {
|
||||||
//Order of FileSplit+ImageRecordReader should be different after reset
|
//Order of FileSplit+ImageRecordReader should be different after reset
|
||||||
|
|
||||||
//Idea: labels order should be consistent regardless of input file order
|
//Idea: labels order should be consistent regardless of input file order
|
||||||
File f0 = testDir.newFolder();
|
File f0 = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||||
|
|
||||||
FileSplit fs = new FileSplit(f0, new Random(12345));
|
FileSplit fs = new FileSplit(f0, new Random(12345));
|
||||||
|
@ -189,13 +192,13 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testImageRecordReaderRegression() throws Exception {
|
public void testImageRecordReaderRegression(@TempDir Path testDir) throws Exception {
|
||||||
|
|
||||||
PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen();
|
PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen();
|
||||||
|
|
||||||
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen);
|
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen);
|
||||||
|
|
||||||
File rootDir = testDir.newFolder();
|
File rootDir = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
||||||
FileSplit fs = new FileSplit(rootDir);
|
FileSplit fs = new FileSplit(rootDir);
|
||||||
rr.initialize(fs);
|
rr.initialize(fs);
|
||||||
|
@ -244,10 +247,10 @@ public class TestImageRecordReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testListenerInvocationBatch() throws IOException {
|
public void testListenerInvocationBatch(@TempDir Path testDir) throws IOException {
|
||||||
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
||||||
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
|
||||||
|
|
||||||
File parent = f;
|
File parent = f;
|
||||||
|
@ -260,10 +263,10 @@ public class TestImageRecordReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testListenerInvocationSingle() throws IOException {
|
public void testListenerInvocationSingle(@TempDir Path testDir) throws IOException {
|
||||||
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
|
||||||
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
|
||||||
File parent = testDir.newFolder();
|
File parent = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent);
|
new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent);
|
||||||
int numFiles = parent.list().length;
|
int numFiles = parent.list().length;
|
||||||
rr.initialize(new FileSplit(parent));
|
rr.initialize(new FileSplit(parent));
|
||||||
|
@ -315,7 +318,7 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testImageRecordReaderPathMultiLabelGenerator() throws Exception {
|
public void testImageRecordReaderPathMultiLabelGenerator(@TempDir Path testDir) throws Exception {
|
||||||
Nd4j.setDataType(DataType.FLOAT);
|
Nd4j.setDataType(DataType.FLOAT);
|
||||||
//Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively
|
//Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively
|
||||||
// PLUS single value (Writable) regression label
|
// PLUS single value (Writable) regression label
|
||||||
|
@ -324,7 +327,7 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen);
|
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen);
|
||||||
|
|
||||||
File rootDir = testDir.newFolder();
|
File rootDir = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
|
||||||
FileSplit fs = new FileSplit(rootDir);
|
FileSplit fs = new FileSplit(rootDir);
|
||||||
rr.initialize(fs);
|
rr.initialize(fs);
|
||||||
|
@ -471,9 +474,9 @@ public class TestImageRecordReader {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNCHW_NCHW() throws Exception {
|
public void testNCHW_NCHW(@TempDir Path testDir) throws Exception {
|
||||||
//Idea: labels order should be consistent regardless of input file order
|
//Idea: labels order should be consistent regardless of input file order
|
||||||
File f0 = testDir.newFolder();
|
File f0 = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||||
|
|
||||||
FileSplit fs0 = new FileSplit(f0, new Random(12345));
|
FileSplit fs0 = new FileSplit(f0, new Random(12345));
|
||||||
|
|
|
@ -35,9 +35,10 @@ import org.datavec.image.transform.FlipImageTransform;
|
||||||
import org.datavec.image.transform.ImageTransform;
|
import org.datavec.image.transform.ImageTransform;
|
||||||
import org.datavec.image.transform.PipelineImageTransform;
|
import org.datavec.image.transform.PipelineImageTransform;
|
||||||
import org.datavec.image.transform.ResizeImageTransform;
|
import org.datavec.image.transform.ResizeImageTransform;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||||
|
@ -46,24 +47,24 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestObjectDetectionRecordReader {
|
public class TestObjectDetectionRecordReader {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test() throws Exception {
|
public void test(@TempDir Path testDir) throws Exception {
|
||||||
for(boolean nchw : new boolean[]{true, false}) {
|
for(boolean nchw : new boolean[]{true, false}) {
|
||||||
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
|
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
|
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
|
||||||
|
|
||||||
String path = new File(f, "000012.jpg").getParent();
|
String path = new File(f, "000012.jpg").getParent();
|
||||||
|
|
|
@ -21,27 +21,27 @@
|
||||||
package org.datavec.image.recordreader.objdetect;
|
package org.datavec.image.recordreader.objdetect;
|
||||||
|
|
||||||
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
|
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestVocLabelProvider {
|
public class TestVocLabelProvider {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVocLabelProvider() throws Exception {
|
public void testVocLabelProvider(@TempDir Path testDir) throws Exception {
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f);
|
new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f);
|
||||||
|
|
||||||
String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent();
|
String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent();
|
||||||
|
|
|
@ -28,8 +28,8 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.datavec.image.loader.NativeImageLoader;
|
import org.datavec.image.loader.NativeImageLoader;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
import java.util.LinkedList;
|
import java.util.LinkedList;
|
||||||
|
@ -40,7 +40,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
||||||
|
|
||||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -255,7 +255,7 @@ public class TestImageTransform {
|
||||||
assertEquals(22, transformed[1], 0);
|
assertEquals(22, transformed[1], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Ignore
|
@Disabled
|
||||||
@Test
|
@Test
|
||||||
public void testFilterImageTransform() throws Exception {
|
public void testFilterImageTransform() throws Exception {
|
||||||
ImageWritable writable = makeRandomImage(0, 0, 4);
|
ImageWritable writable = makeRandomImage(0, 0, 4);
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.IntWritable;
|
import org.datavec.api.writable.IntWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.primitives.Triple;
|
import org.nd4j.common.primitives.Triple;
|
||||||
|
|
|
@ -49,7 +49,7 @@ import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
|
|
|
@ -36,14 +36,14 @@ import org.datavec.api.writable.LongWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class LocalTransformProcessRecordReaderTests {
|
public class LocalTransformProcessRecordReaderTests {
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,9 @@ import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.util.ndarray.RecordConverter;
|
import org.datavec.api.util.ndarray.RecordConverter;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.local.transforms.AnalyzeLocal;
|
import org.datavec.local.transforms.AnalyzeLocal;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
@ -39,12 +39,11 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestAnalyzeLocal {
|
public class TestAnalyzeLocal {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAnalysisBasic() throws Exception {
|
public void testAnalysisBasic() throws Exception {
|
||||||
|
@ -72,7 +71,7 @@ public class TestAnalyzeLocal {
|
||||||
INDArray mean = arr.mean(0);
|
INDArray mean = arr.mean(0);
|
||||||
INDArray std = arr.std(0);
|
INDArray std = arr.std(0);
|
||||||
|
|
||||||
for( int i=0; i<5; i++ ){
|
for( int i = 0; i < 5; i++) {
|
||||||
double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
|
double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
|
||||||
double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
|
double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
|
||||||
assertEquals(mean.getDouble(i), m, 1e-3);
|
assertEquals(mean.getDouble(i), m, 1e-3);
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -36,8 +36,8 @@ import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestLineRecordReaderFunction {
|
public class TestLineRecordReaderFunction {
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
|
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestNDArrayToWritablesFunction {
|
public class TestNDArrayToWritablesFunction {
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
|
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestWritablesToNDArrayFunction {
|
public class TestWritablesToNDArrayFunction {
|
||||||
|
|
||||||
|
|
|
@ -30,12 +30,12 @@ import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
|
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
|
||||||
import org.datavec.local.transforms.misc.WritablesToStringFunction;
|
import org.datavec.local.transforms.misc.WritablesToStringFunction;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestWritablesToStringFunctions {
|
public class TestWritablesToStringFunctions {
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,8 @@ import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.AfterClass;
|
import org.junit.AfterClass;
|
||||||
import org.junit.BeforeClass;
|
import org.junit.BeforeClass;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
@ -40,14 +41,14 @@ import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author saudet
|
* @author saudet
|
||||||
*/
|
*/
|
||||||
public class TestGeoTransforms {
|
public class TestGeoTransforms {
|
||||||
|
|
||||||
@BeforeClass
|
@BeforeAll
|
||||||
public static void beforeClass() throws Exception {
|
public static void beforeClass() throws Exception {
|
||||||
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
|
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
|
||||||
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
|
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
|
||||||
|
@ -63,7 +64,7 @@ public class TestGeoTransforms {
|
||||||
@Test
|
@Test
|
||||||
public void testCoordinatesDistanceTransform() throws Exception {
|
public void testCoordinatesDistanceTransform() throws Exception {
|
||||||
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
|
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
|
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
|
||||||
transform.setInputSchema(schema);
|
transform.setInputSchema(schema);
|
||||||
|
@ -72,14 +73,14 @@ public class TestGeoTransforms {
|
||||||
assertEquals(4, out.numColumns());
|
assertEquals(4, out.numColumns());
|
||||||
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
|
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
|
||||||
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
|
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
|
||||||
out.getColumnTypes());
|
out.getColumnTypes());
|
||||||
|
|
||||||
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
|
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
|
||||||
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
|
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
|
||||||
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
|
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
|
||||||
new DoubleWritable(Math.sqrt(160))),
|
new DoubleWritable(Math.sqrt(160))),
|
||||||
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
|
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
|
||||||
new Text("10|5"))));
|
new Text("10|5"))));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -30,7 +30,8 @@ import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.datavec.python.PythonCondition;
|
import org.datavec.python.PythonCondition;
|
||||||
import org.datavec.python.PythonTransform;
|
import org.datavec.python.PythonTransform;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.Timeout;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -43,7 +44,7 @@ import java.util.List;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static junit.framework.TestCase.assertTrue;
|
||||||
import static org.datavec.api.transform.schema.Schema.Builder;
|
import static org.datavec.api.transform.schema.Schema.Builder;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@NotThreadSafe
|
@NotThreadSafe
|
||||||
public class TestPythonTransformProcess {
|
public class TestPythonTransformProcess {
|
||||||
|
@ -77,8 +78,9 @@ public class TestPythonTransformProcess {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
@Test()
|
||||||
public void testMixedTypes() throws Exception{
|
@Timeout(60000L)
|
||||||
|
public void testMixedTypes() throws Exception {
|
||||||
Builder schemaBuilder = new Builder();
|
Builder schemaBuilder = new Builder();
|
||||||
schemaBuilder
|
schemaBuilder
|
||||||
.addColumnInteger("col1")
|
.addColumnInteger("col1")
|
||||||
|
@ -99,7 +101,7 @@ public class TestPythonTransformProcess {
|
||||||
.inputSchema(initialSchema)
|
.inputSchema(initialSchema)
|
||||||
.build() ).build();
|
.build() ).build();
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
|
List<Writable> inputs = Arrays.asList(new IntWritable(10),
|
||||||
new FloatWritable(3.5f),
|
new FloatWritable(3.5f),
|
||||||
new Text("5"),
|
new Text("5"),
|
||||||
new DoubleWritable(2.0)
|
new DoubleWritable(2.0)
|
||||||
|
@ -109,8 +111,9 @@ public class TestPythonTransformProcess {
|
||||||
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
|
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
@Test()
|
||||||
public void testNDArray() throws Exception{
|
@Timeout(60000L)
|
||||||
|
public void testNDArray() throws Exception {
|
||||||
long[] shape = new long[]{3, 2};
|
long[] shape = new long[]{3, 2};
|
||||||
INDArray arr1 = Nd4j.rand(shape);
|
INDArray arr1 = Nd4j.rand(shape);
|
||||||
INDArray arr2 = Nd4j.rand(shape);
|
INDArray arr2 = Nd4j.rand(shape);
|
||||||
|
@ -145,8 +148,9 @@ public class TestPythonTransformProcess {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
@Test()
|
||||||
public void testNDArray2() throws Exception{
|
@Timeout(60000L)
|
||||||
|
public void testNDArray2() throws Exception {
|
||||||
long[] shape = new long[]{3, 2};
|
long[] shape = new long[]{3, 2};
|
||||||
INDArray arr1 = Nd4j.rand(shape);
|
INDArray arr1 = Nd4j.rand(shape);
|
||||||
INDArray arr2 = Nd4j.rand(shape);
|
INDArray arr2 = Nd4j.rand(shape);
|
||||||
|
@ -181,7 +185,8 @@ public class TestPythonTransformProcess {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
@Test()
|
||||||
|
@Timeout(60000L)
|
||||||
public void testNDArrayMixed() throws Exception{
|
public void testNDArrayMixed() throws Exception{
|
||||||
long[] shape = new long[]{3, 2};
|
long[] shape = new long[]{3, 2};
|
||||||
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
|
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
|
||||||
|
@ -217,7 +222,8 @@ public class TestPythonTransformProcess {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
@Test()
|
||||||
|
@Timeout(60000L)
|
||||||
public void testPythonFilter() {
|
public void testPythonFilter() {
|
||||||
Schema schema = new Builder().addColumnInteger("column").build();
|
Schema schema = new Builder().addColumnInteger("column").build();
|
||||||
|
|
||||||
|
@ -237,8 +243,9 @@ public class TestPythonTransformProcess {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
@Test()
|
||||||
public void testPythonFilterAndTransform() throws Exception{
|
@Timeout(60000L)
|
||||||
|
public void testPythonFilterAndTransform() throws Exception {
|
||||||
Builder schemaBuilder = new Builder();
|
Builder schemaBuilder = new Builder();
|
||||||
schemaBuilder
|
schemaBuilder
|
||||||
.addColumnInteger("col1")
|
.addColumnInteger("col1")
|
||||||
|
|
|
@ -28,11 +28,11 @@ import org.datavec.api.writable.*;
|
||||||
|
|
||||||
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestJoin {
|
public class TestJoin {
|
||||||
|
|
||||||
|
|
|
@ -31,13 +31,13 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
||||||
|
|
||||||
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestCalculateSortedRank {
|
public class TestCalculateSortedRank {
|
||||||
|
|
||||||
|
|
|
@ -31,14 +31,14 @@ import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
|
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestConvertToSequence {
|
public class TestConvertToSequence {
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,12 @@
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.tdunning</groupId>
|
||||||
|
<artifactId>t-digest</artifactId>
|
||||||
|
<version>3.2</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.scala-lang</groupId>
|
<groupId>org.scala-lang</groupId>
|
||||||
<artifactId>scala-library</artifactId>
|
<artifactId>scala-library</artifactId>
|
||||||
|
|
|
@ -25,15 +25,15 @@ import org.apache.spark.serializer.SerializerInstance;
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
import org.datavec.api.records.reader.RecordReader;
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestKryoSerialization extends BaseSparkTest {
|
public class TestKryoSerialization extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||||
import org.datavec.api.split.FileSplit;
|
import org.datavec.api.split.FileSplit;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -35,8 +35,8 @@ import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestLineRecordReaderFunction extends BaseSparkTest {
|
public class TestLineRecordReaderFunction extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.NDArrayWritable;
|
import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.spark.transform.misc.NDArrayToWritablesFunction;
|
import org.datavec.spark.transform.misc.NDArrayToWritablesFunction;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestNDArrayToWritablesFunction {
|
public class TestNDArrayToWritablesFunction {
|
||||||
|
|
||||||
|
|
|
@ -38,9 +38,10 @@ import org.datavec.spark.functions.pairdata.PairSequenceRecordReaderBytesFunctio
|
||||||
import org.datavec.spark.functions.pairdata.PathToKeyConverter;
|
import org.datavec.spark.functions.pairdata.PathToKeyConverter;
|
||||||
import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename;
|
import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename;
|
||||||
import org.datavec.spark.util.DataVecSparkUtil;
|
import org.datavec.spark.util.DataVecSparkUtil;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
import scala.Tuple2;
|
import scala.Tuple2;
|
||||||
|
|
||||||
|
@ -50,16 +51,13 @@ import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
|
public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test() throws Exception {
|
public void test(@TempDir Path testDir) throws Exception {
|
||||||
//Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader
|
//Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader
|
||||||
//For example: use to combine input and labels data from separate files for training a RNN
|
//For example: use to combine input and labels data from separate files for training a RNN
|
||||||
if(Platform.isWindows()) {
|
if(Platform.isWindows()) {
|
||||||
|
@ -67,7 +65,7 @@ public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
|
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
|
||||||
String path = f.getAbsolutePath() + "/*";
|
String path = f.getAbsolutePath() + "/*";
|
||||||
|
|
||||||
|
|
|
@ -36,9 +36,10 @@ import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.datavec.spark.functions.data.FilesAsBytesFunction;
|
import org.datavec.spark.functions.data.FilesAsBytesFunction;
|
||||||
import org.datavec.spark.functions.data.RecordReaderBytesFunction;
|
import org.datavec.spark.functions.data.RecordReaderBytesFunction;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -48,23 +49,22 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
public class TestRecordReaderBytesFunction extends BaseSparkTest {
|
public class TestRecordReaderBytesFunction extends BaseSparkTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRecordReaderBytesFunction() throws Exception {
|
public void testRecordReaderBytesFunction(@TempDir Path testDir) throws Exception {
|
||||||
if(Platform.isWindows()) {
|
if(Platform.isWindows()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
||||||
//Local file path
|
//Local file path
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
|
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
|
||||||
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call
|
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call
|
||||||
String path = f.getAbsolutePath() + "/*";
|
String path = f.getAbsolutePath() + "/*";
|
||||||
|
|
|
@ -31,30 +31,29 @@ import org.datavec.api.writable.ArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
public class TestRecordReaderFunction extends BaseSparkTest {
|
public class TestRecordReaderFunction extends BaseSparkTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRecordReaderFunction() throws Exception {
|
public void testRecordReaderFunction(@TempDir Path testDir) throws Exception {
|
||||||
if(Platform.isWindows()) {
|
if(Platform.isWindows()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
|
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
|
||||||
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call
|
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call
|
||||||
|
|
||||||
|
|
|
@ -36,9 +36,10 @@ import org.datavec.codec.reader.CodecRecordReader;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.datavec.spark.functions.data.FilesAsBytesFunction;
|
import org.datavec.spark.functions.data.FilesAsBytesFunction;
|
||||||
import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction;
|
import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -47,21 +48,20 @@ import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {
|
public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRecordReaderBytesFunction() throws Exception {
|
public void testRecordReaderBytesFunction(@TempDir Path testDir) throws Exception {
|
||||||
if(Platform.isWindows()) {
|
if(Platform.isWindows()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
//Local file path
|
//Local file path
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
|
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
|
||||||
String path = f.getAbsolutePath() + "/*";
|
String path = f.getAbsolutePath() + "/*";
|
||||||
|
|
||||||
|
|
|
@ -33,28 +33,29 @@ import org.datavec.api.writable.ArrayWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.codec.reader.CodecRecordReader;
|
import org.datavec.codec.reader.CodecRecordReader;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.jupiter.api.Assertions.fail;
|
||||||
|
|
||||||
public class TestSequenceRecordReaderFunction extends BaseSparkTest {
|
public class TestSequenceRecordReaderFunction extends BaseSparkTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSequenceRecordReaderFunctionCSV() throws Exception {
|
public void testSequenceRecordReaderFunctionCSV(@TempDir Path testDir) throws Exception {
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-spark/csvsequence/").copyDirectory(f);
|
new ClassPathResource("datavec-spark/csvsequence/").copyDirectory(f);
|
||||||
|
|
||||||
String path = f.getAbsolutePath() + "/*";
|
String path = f.getAbsolutePath() + "/*";
|
||||||
|
@ -120,10 +121,10 @@ public class TestSequenceRecordReaderFunction extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSequenceRecordReaderFunctionVideo() throws Exception {
|
public void testSequenceRecordReaderFunctionVideo(@TempDir Path testDir) throws Exception {
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
||||||
File f = testDir.newFolder();
|
File f = testDir.toFile();
|
||||||
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
|
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
|
||||||
|
|
||||||
String path = f.getAbsolutePath() + "/*";
|
String path = f.getAbsolutePath() + "/*";
|
||||||
|
|
|
@ -22,7 +22,7 @@ package org.datavec.spark.functions;
|
||||||
|
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.datavec.spark.transform.misc.WritablesToNDArrayFunction;
|
import org.datavec.spark.transform.misc.WritablesToNDArrayFunction;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -30,7 +30,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestWritablesToNDArrayFunction {
|
public class TestWritablesToNDArrayFunction {
|
||||||
|
|
||||||
|
|
|
@ -29,14 +29,14 @@ import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction;
|
import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction;
|
||||||
import org.datavec.spark.transform.misc.WritablesToStringFunction;
|
import org.datavec.spark.transform.misc.WritablesToStringFunction;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import scala.Tuple2;
|
import scala.Tuple2;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestWritablesToStringFunctions extends BaseSparkTest {
|
public class TestWritablesToStringFunctions extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaPairRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -35,8 +35,8 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestSparkStorageUtils extends BaseSparkTest {
|
public class TestSparkStorageUtils extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -30,13 +30,13 @@ import org.datavec.api.util.ndarray.RecordConverter;
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class DataFramesTests extends BaseSparkTest {
|
public class DataFramesTests extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.datavec.api.util.ndarray.RecordConverter;
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
import org.datavec.api.writable.DoubleWritable;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
@ -41,7 +41,7 @@ import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static junit.framework.TestCase.assertTrue;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class NormalizationTests extends BaseSparkTest {
|
public class NormalizationTests extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ import org.datavec.local.transforms.AnalyzeLocal;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.datavec.spark.transform.AnalyzeSpark;
|
import org.datavec.spark.transform.AnalyzeSpark;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
@ -47,7 +47,7 @@ import java.io.File;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestAnalysis extends BaseSparkTest {
|
public class TestAnalysis extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -27,11 +27,11 @@ import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.datavec.spark.transform.SparkTransformExecutor;
|
import org.datavec.spark.transform.SparkTransformExecutor;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestJoin extends BaseSparkTest {
|
public class TestJoin extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -30,13 +30,13 @@ import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.datavec.spark.transform.SparkTransformExecutor;
|
import org.datavec.spark.transform.SparkTransformExecutor;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestCalculateSortedRank extends BaseSparkTest {
|
public class TestCalculateSortedRank extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -29,14 +29,14 @@ import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.datavec.spark.transform.SparkTransformExecutor;
|
import org.datavec.spark.transform.SparkTransformExecutor;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestConvertToSequence extends BaseSparkTest {
|
public class TestConvertToSequence extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.spark.BaseSparkTest;
|
import org.datavec.spark.BaseSparkTest;
|
||||||
import org.datavec.spark.transform.utils.SparkUtils;
|
import org.datavec.spark.transform.utils.SparkUtils;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.FileInputStream;
|
import java.io.FileInputStream;
|
||||||
|
@ -36,7 +36,7 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class TestSparkUtil extends BaseSparkTest {
|
public class TestSparkUtil extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
package org.deeplearning4j;
|
package org.deeplearning4j;
|
||||||
|
|
||||||
import ch.qos.logback.classic.LoggerContext;
|
import ch.qos.logback.classic.LoggerContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
|
|
||||||
|
@ -32,6 +31,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
import org.slf4j.ILoggerFactory;
|
import org.slf4j.ILoggerFactory;
|
||||||
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import java.lang.management.ManagementFactory;
|
import java.lang.management.ManagementFactory;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -39,13 +39,11 @@ import java.util.Map;
|
||||||
import java.util.Properties;
|
import java.util.Properties;
|
||||||
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
||||||
|
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
@DisplayName("Base DL 4 J Test")
|
@DisplayName("Base DL 4 J Test")
|
||||||
public abstract class BaseDL4JTest {
|
public abstract class BaseDL4JTest {
|
||||||
|
|
||||||
|
private static Logger log = LoggerFactory.getLogger(BaseDL4JTest.class.getName());
|
||||||
|
|
||||||
protected long startTime;
|
protected long startTime;
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ import java.lang.reflect.Field;
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LayerHelperValidationUtil {
|
public class LayerHelperValidationUtil {
|
||||||
|
@ -145,7 +145,7 @@ public class LayerHelperValidationUtil {
|
||||||
System.out.println(p1);
|
System.out.println(p1);
|
||||||
System.out.println(p2);
|
System.out.println(p2);
|
||||||
}
|
}
|
||||||
assertTrue(s + " - param changed during forward pass: " + p, maxRE < t.getMaxRelError());
|
assertTrue(maxRE < t.getMaxRelError(),s + " - param changed during forward pass: " + p);
|
||||||
}
|
}
|
||||||
|
|
||||||
for( int i=0; i<ff1.size(); i++ ){
|
for( int i=0; i<ff1.size(); i++ ){
|
||||||
|
@ -163,7 +163,7 @@ public class LayerHelperValidationUtil {
|
||||||
double d2 = arr2.dup('c').getDouble(idx);
|
double d2 = arr2.dup('c').getDouble(idx);
|
||||||
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
|
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
|
||||||
}
|
}
|
||||||
assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
|
assertTrue(maxRE < t.getMaxRelError(), s + layerName + " activations - max RE: " + maxRE);
|
||||||
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
|
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,7 +180,7 @@ public class LayerHelperValidationUtil {
|
||||||
log.info(s + "Output, max relative error: " + maxRE);
|
log.info(s + "Output, max relative error: " + maxRE);
|
||||||
|
|
||||||
assertEquals(net1NoHelper.params(), net2With.params()); //Check that forward pass does not modify params
|
assertEquals(net1NoHelper.params(), net2With.params()); //Check that forward pass does not modify params
|
||||||
assertTrue(s + "Max RE: " + maxRE, maxRE < t.getMaxRelError());
|
assertTrue(maxRE < t.getMaxRelError(), s + "Max RE: " + maxRE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -201,7 +201,7 @@ public class LayerHelperValidationUtil {
|
||||||
|
|
||||||
double re = relError(s1, s2);
|
double re = relError(s1, s2);
|
||||||
String s = "Relative error: " + re;
|
String s = "Relative error: " + re;
|
||||||
assertTrue(s, re < t.getMaxRelError());
|
assertTrue(re < t.getMaxRelError(), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(t.isTestBackward()) {
|
if(t.isTestBackward()) {
|
||||||
|
@ -243,8 +243,8 @@ public class LayerHelperValidationUtil {
|
||||||
} else {
|
} else {
|
||||||
System.out.println("OK: " + p);
|
System.out.println("OK: " + p);
|
||||||
}
|
}
|
||||||
assertTrue(t.getTestName() + " - Gradients are not equal: " + p + " - highest relative error = " + maxRE + " > max relative error = " + t.getMaxRelError(),
|
assertTrue(maxRE < t.getMaxRelError(),
|
||||||
maxRE < t.getMaxRelError());
|
t.getTestName() + " - Gradients are not equal: " + p + " - highest relative error = " + maxRE + " > max relative error = " + t.getMaxRelError());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -283,7 +283,7 @@ public class LayerHelperValidationUtil {
|
||||||
double d2 = listNew.get(j);
|
double d2 = listNew.get(j);
|
||||||
double re = relError(d1, d2);
|
double re = relError(d1, d2);
|
||||||
String msg = "Scores at iteration " + j + " - relError = " + re + ", score1 = " + d1 + ", score2 = " + d2;
|
String msg = "Scores at iteration " + j + " - relError = " + re + ", score1 = " + d1 + ", score2 = " + d2;
|
||||||
assertTrue(msg, re < t.getMaxRelError());
|
assertTrue(re < t.getMaxRelError(), msg);
|
||||||
System.out.println("j=" + j + ", d1 = " + d1 + ", d2 = " + d2);
|
System.out.println("j=" + j + ", d1 = " + d1 + ", d2 = " + d2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -315,7 +315,7 @@ public class LayerHelperValidationUtil {
|
||||||
try {
|
try {
|
||||||
if (keepAndAssertPresent) {
|
if (keepAndAssertPresent) {
|
||||||
Object o = f.get(l);
|
Object o = f.get(l);
|
||||||
assertNotNull("Expect helper to be present for layer: " + l.getClass(), o);
|
assertNotNull(o,"Expect helper to be present for layer: " + l.getClass());
|
||||||
} else {
|
} else {
|
||||||
f.set(l, null);
|
f.set(l, null);
|
||||||
Integer i = map.get(l.getClass());
|
Integer i = map.get(l.getClass());
|
||||||
|
|
|
@ -26,8 +26,8 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.resources.Resources;
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
@ -38,7 +38,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.util.concurrent.CountDownLatch;
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
@Ignore
|
@Disabled
|
||||||
public class RandomTests extends BaseDL4JTest {
|
public class RandomTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -50,8 +50,8 @@ import java.lang.reflect.Field;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
|
||||||
public class TestUtils {
|
public class TestUtils {
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.junit.jupiter.api.AfterAll;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.junit.rules.Timeout;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
|
|
@ -23,7 +23,7 @@ package org.deeplearning4j.datasets;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
|
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
|
||||||
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
|
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
public class TestDataSets extends BaseDL4JTest {
|
public class TestDataSets extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
*/
|
*/
|
||||||
package org.deeplearning4j.datasets.datavec;
|
package org.deeplearning4j.datasets.datavec;
|
||||||
|
|
||||||
import org.junit.rules.Timeout;
|
|
||||||
import org.nd4j.shade.guava.io.Files;
|
import org.nd4j.shade.guava.io.Files;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -47,7 +47,7 @@ import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException
|
||||||
import org.deeplearning4j.datasets.datavec.tools.SpecialImageRecordReader;
|
import org.deeplearning4j.datasets.datavec.tools.SpecialImageRecordReader;
|
||||||
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
|
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -74,9 +74,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
@DisplayName("Record Reader Data Setiterator Test")
|
@DisplayName("Record Reader Data Setiterator Test")
|
||||||
class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public Timeout timeout = Timeout.seconds(300);
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType() {
|
public DataType getDataType() {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
*/
|
*/
|
||||||
package org.deeplearning4j.datasets.datavec;
|
package org.deeplearning4j.datasets.datavec;
|
||||||
|
|
||||||
import org.junit.rules.Timeout;
|
|
||||||
import org.nd4j.shade.guava.io.Files;
|
import org.nd4j.shade.guava.io.Files;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
|
@ -44,7 +44,7 @@ import org.datavec.api.writable.Writable;
|
||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
import org.datavec.image.recordreader.ImageRecordReader;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -73,8 +73,7 @@ class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
|
||||||
@TempDir
|
@TempDir
|
||||||
public Path temporaryFolder;
|
public Path temporaryFolder;
|
||||||
|
|
||||||
@Rule
|
|
||||||
public Timeout timeout = Timeout.seconds(300);
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Tests Basic")
|
@DisplayName("Tests Basic")
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
package org.deeplearning4j.datasets.fetchers;
|
package org.deeplearning4j.datasets.fetchers;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.Timeout;
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
||||||
|
|
|
@ -21,14 +21,14 @@
|
||||||
package org.deeplearning4j.datasets.iterator;
|
package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
|
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
|
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class CombinedPreProcessorTests extends BaseDL4JTest {
|
public class CombinedPreProcessorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -32,7 +32,7 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class DataSetSplitterTests extends BaseDL4JTest {
|
public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
|
@ -54,7 +54,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (train.hasNext()) {
|
while (train.hasNext()) {
|
||||||
val data = train.next().getFeatures();
|
val data = train.next().getFeatures();
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTrain++;
|
gcntTrain++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -64,7 +64,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
|
|
||||||
while (test.hasNext()) {
|
while (test.hasNext()) {
|
||||||
val data = test.next().getFeatures();
|
val data = test.next().getFeatures();
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTest++;
|
gcntTest++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -94,7 +94,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (train.hasNext()) {
|
while (train.hasNext()) {
|
||||||
val data = train.next().getFeatures();
|
val data = train.next().getFeatures();
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTrain++;
|
gcntTrain++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
if (e % 2 == 0)
|
if (e % 2 == 0)
|
||||||
while (test.hasNext()) {
|
while (test.hasNext()) {
|
||||||
val data = test.next().getFeatures();
|
val data = test.next().getFeatures();
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTest++;
|
gcntTest++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -113,46 +113,50 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
assertEquals(700 * numEpochs + (300 * numEpochs / 2), global);
|
assertEquals(700 * numEpochs + (300 * numEpochs / 2), global);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = ND4JIllegalStateException.class)
|
@Test()
|
||||||
public void testSplitter_3() throws Exception {
|
public void testSplitter_3() throws Exception {
|
||||||
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
assertThrows(ND4JIllegalStateException.class, () -> {
|
||||||
|
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
val splitter = new DataSetIteratorSplitter(back, 1000, 0.7);
|
val splitter = new DataSetIteratorSplitter(back, 1000, 0.7);
|
||||||
|
|
||||||
val train = splitter.getTrainIterator();
|
val train = splitter.getTrainIterator();
|
||||||
val test = splitter.getTestIterator();
|
val test = splitter.getTestIterator();
|
||||||
val numEpochs = 10;
|
val numEpochs = 10;
|
||||||
|
|
||||||
int gcntTrain = 0;
|
int gcntTrain = 0;
|
||||||
int gcntTest = 0;
|
int gcntTest = 0;
|
||||||
int global = 0;
|
int global = 0;
|
||||||
// emulating epochs here
|
// emulating epochs here
|
||||||
for (int e = 0; e < numEpochs; e++) {
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
while (train.hasNext()) {
|
while (train.hasNext()) {
|
||||||
val data = train.next().getFeatures();
|
val data = train.next().getFeatures();
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTrain++;
|
gcntTrain++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
|
||||||
train.reset();
|
train.reset();
|
||||||
|
|
||||||
|
|
||||||
while (test.hasNext()) {
|
while (test.hasNext()) {
|
||||||
val data = test.next().getFeatures();
|
val data = test.next().getFeatures();
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTest++;
|
gcntTest++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shifting underlying iterator by one
|
||||||
|
train.hasNext();
|
||||||
|
back.shift();
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(1000 * numEpochs, global);
|
||||||
|
});
|
||||||
|
|
||||||
// shifting underlying iterator by one
|
|
||||||
train.hasNext();
|
|
||||||
back.shift();
|
|
||||||
}
|
|
||||||
|
|
||||||
assertEquals(1000 * numEpochs, global);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -172,8 +176,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
partIterator.reset();
|
partIterator.reset();
|
||||||
while (partIterator.hasNext()) {
|
while (partIterator.hasNext()) {
|
||||||
val data = partIterator.next().getFeatures();
|
val data = partIterator.next().getFeatures();
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
|
assertEquals((float) perEpoch, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
(float) perEpoch, data.getFloat(0), 1e-5);
|
|
||||||
//gcntTrain++;
|
//gcntTrain++;
|
||||||
global++;
|
global++;
|
||||||
cnt++;
|
cnt++;
|
||||||
|
@ -206,8 +209,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
val data = partIterator.next().getFeatures();
|
val data = partIterator.next().getFeatures();
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
|
assertEquals((float) perEpoch, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
(float) perEpoch, data.getFloat(0), 1e-5);
|
|
||||||
//gcntTrain++;
|
//gcntTrain++;
|
||||||
global++;
|
global++;
|
||||||
cnt++;
|
cnt++;
|
||||||
|
@ -247,10 +249,10 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
val ds = trainIter.next();
|
val ds = trainIter.next();
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
|
|
||||||
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
|
assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
|
||||||
globalIter++;
|
globalIter++;
|
||||||
}
|
}
|
||||||
assertTrue("Failed at epoch [" + e + "]", trained);
|
assertTrue(trained,"Failed at epoch [" + e + "]");
|
||||||
assertEquals(800, globalIter);
|
assertEquals(800, globalIter);
|
||||||
|
|
||||||
|
|
||||||
|
@ -262,10 +264,10 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
val ds = testIter.next();
|
val ds = testIter.next();
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
|
|
||||||
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
|
assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
|
||||||
globalIter++;
|
globalIter++;
|
||||||
}
|
}
|
||||||
assertTrue("Failed at epoch [" + e + "]", tested);
|
assertTrue(tested,"Failed at epoch [" + e + "]");
|
||||||
assertEquals(900, globalIter);
|
assertEquals(900, globalIter);
|
||||||
|
|
||||||
// validation set is used every 5 epochs
|
// validation set is used every 5 epochs
|
||||||
|
@ -277,10 +279,10 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
val ds = validationIter.next();
|
val ds = validationIter.next();
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
|
|
||||||
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
|
assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
|
||||||
globalIter++;
|
globalIter++;
|
||||||
}
|
}
|
||||||
assertTrue("Failed at epoch [" + e + "]", validated);
|
assertTrue(validated,"Failed at epoch [" + e + "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
// all 3 iterators have exactly 1000 elements combined
|
// all 3 iterators have exactly 1000 elements combined
|
||||||
|
@ -312,7 +314,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
int farCnt = (1000 / 2) * (partNumber) + cnt;
|
int farCnt = (1000 / 2) * (partNumber) + cnt;
|
||||||
val data = iteratorList.get(partNumber).next().getFeatures();
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data.getFloat(0), 1e-5);
|
assertEquals((float) farCnt, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
cnt++;
|
cnt++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -322,7 +324,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (iteratorList.get(0).hasNext()) {
|
while (iteratorList.get(0).hasNext()) {
|
||||||
val data = iteratorList.get(0).next().getFeatures();
|
val data = iteratorList.get(0).next().getFeatures();
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -341,7 +343,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (iteratorList.get(partNumber).hasNext()) {
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
val data = iteratorList.get(partNumber).next().getFeatures();
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
|
assertEquals( (float) (500*partNumber + cnt), data.getFloat(0), 1e-5,"Train failed on iteration " + cnt);
|
||||||
cnt++;
|
cnt++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -365,7 +367,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (iteratorList.get(partNumber).hasNext()) {
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
val data = iteratorList.get(partNumber).next().getFeatures();
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
|
assertEquals( (float) (500*partNumber + cnt), data.getFloat(0), 1e-5,"Train failed on iteration " + cnt);
|
||||||
cnt++;
|
cnt++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -390,7 +392,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
val ds = validationIter.next();
|
val ds = validationIter.next();
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
|
|
||||||
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5);
|
assertEquals((float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5,"Validation failed on iteration " + valCnt);
|
||||||
valCnt++;
|
valCnt++;
|
||||||
}
|
}
|
||||||
assertEquals(5, valCnt);
|
assertEquals(5, valCnt);
|
||||||
|
|
|
@ -25,15 +25,15 @@ import lombok.val;
|
||||||
import lombok.var;
|
import lombok.var;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class DummyBlockDataSetIteratorTests extends BaseDL4JTest {
|
public class DummyBlockDataSetIteratorTests extends BaseDL4JTest {
|
||||||
|
|
|
@ -21,7 +21,7 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
import org.junit.rules.ExpectedException;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
@ -43,8 +43,7 @@ class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
int numExamples = 105;
|
int numExamples = 105;
|
||||||
|
|
||||||
@Rule
|
|
||||||
public final ExpectedException exception = ExpectedException.none();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Test Next And Reset")
|
@DisplayName("Test Next And Reset")
|
||||||
|
@ -86,14 +85,16 @@ class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Test Callsto Next Not Allowed")
|
@DisplayName("Test calls to Next Not Allowed")
|
||||||
void testCallstoNextNotAllowed() throws IOException {
|
void testCallstoNextNotAllowed() throws IOException {
|
||||||
int terminateAfter = 1;
|
assertThrows(RuntimeException.class,() -> {
|
||||||
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
|
int terminateAfter = 1;
|
||||||
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
|
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
|
||||||
earlyEndIter.next(10);
|
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
|
||||||
iter.reset();
|
earlyEndIter.next(10);
|
||||||
exception.expect(RuntimeException.class);
|
iter.reset();
|
||||||
earlyEndIter.next(10);
|
earlyEndIter.next(10);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
import org.junit.rules.ExpectedException;
|
||||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||||
|
@ -30,11 +30,12 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@DisplayName("Early Termination Multi Data Set Iterator Test")
|
@DisplayName("Early Termination Multi Data Set Iterator Test")
|
||||||
class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
|
class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@ -42,8 +43,7 @@ class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
int numExamples = 105;
|
int numExamples = 105;
|
||||||
|
|
||||||
@Rule
|
|
||||||
public final ExpectedException exception = ExpectedException.none();
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Test Next And Reset")
|
@DisplayName("Test Next And Reset")
|
||||||
|
@ -91,14 +91,16 @@ class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Test Callsto Next Not Allowed")
|
@DisplayName("Test calls to Next Not Allowed")
|
||||||
void testCallstoNextNotAllowed() throws IOException {
|
void testCallstoNextNotAllowed() throws IOException {
|
||||||
int terminateAfter = 1;
|
assertThrows(RuntimeException.class,() -> {
|
||||||
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
|
int terminateAfter = 1;
|
||||||
EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
|
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
|
||||||
earlyEndIter.next(10);
|
EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
|
||||||
iter.reset();
|
earlyEndIter.next(10);
|
||||||
exception.expect(RuntimeException.class);
|
iter.reset();
|
||||||
earlyEndIter.next(10);
|
earlyEndIter.next(10);
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,14 +24,16 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.Timeout;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
|
public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test (timeout = 20000L)
|
@Test ()
|
||||||
|
@Timeout(20000L)
|
||||||
public void testJMDSI_1() {
|
public void testJMDSI_1() {
|
||||||
val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2});
|
val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2});
|
||||||
val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2});
|
val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2});
|
||||||
|
@ -75,7 +77,8 @@ public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test (timeout = 20000L)
|
@Test ()
|
||||||
|
@Timeout(20000L)
|
||||||
public void testJMDSI_2() {
|
public void testJMDSI_2() {
|
||||||
val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2});
|
val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2});
|
||||||
val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2});
|
val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2});
|
||||||
|
|
|
@ -23,7 +23,7 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
|
import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator;
|
import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.loader.Loader;
|
import org.nd4j.common.loader.Loader;
|
||||||
import org.nd4j.common.loader.LocalFileSourceFactory;
|
import org.nd4j.common.loader.LocalFileSourceFactory;
|
||||||
import org.nd4j.common.loader.Source;
|
import org.nd4j.common.loader.Source;
|
||||||
|
@ -39,8 +39,8 @@ import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class LoaderIteratorTests extends BaseDL4JTest {
|
public class LoaderIteratorTests extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ import lombok.val;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
||||||
import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
@ -32,7 +32,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (train.hasNext()) {
|
while (train.hasNext()) {
|
||||||
val data = train.next().getFeatures(0);
|
val data = train.next().getFeatures(0);
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTrain++;
|
gcntTrain++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -65,7 +65,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
|
|
||||||
while (test.hasNext()) {
|
while (test.hasNext()) {
|
||||||
val data = test.next().getFeatures(0);
|
val data = test.next().getFeatures(0);
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTest++;
|
gcntTest++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -96,7 +96,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (train.hasNext()) {
|
while (train.hasNext()) {
|
||||||
val data = train.next().getFeatures(0);
|
val data = train.next().getFeatures(0);
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTrain++;
|
gcntTrain++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -106,7 +106,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
if (e % 2 == 0)
|
if (e % 2 == 0)
|
||||||
while (test.hasNext()) {
|
while (test.hasNext()) {
|
||||||
val data = test.next().getFeatures(0);
|
val data = test.next().getFeatures(0);
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTest++;
|
gcntTest++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -115,46 +115,49 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
assertEquals(700 * numEpochs + (300 * numEpochs / 2), global);
|
assertEquals(700 * numEpochs + (300 * numEpochs / 2), global);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = ND4JIllegalStateException.class)
|
@Test()
|
||||||
public void testSplitter_3() throws Exception {
|
public void testSplitter_3() throws Exception {
|
||||||
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
|
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
val splitter = new MultiDataSetIteratorSplitter(back, 1000, 0.7);
|
val splitter = new MultiDataSetIteratorSplitter(back, 1000, 0.7);
|
||||||
|
|
||||||
val train = splitter.getTrainIterator();
|
val train = splitter.getTrainIterator();
|
||||||
val test = splitter.getTestIterator();
|
val test = splitter.getTestIterator();
|
||||||
val numEpochs = 10;
|
val numEpochs = 10;
|
||||||
|
|
||||||
int gcntTrain = 0;
|
int gcntTrain = 0;
|
||||||
int gcntTest = 0;
|
int gcntTest = 0;
|
||||||
int global = 0;
|
int global = 0;
|
||||||
// emulating epochs here
|
// emulating epochs here
|
||||||
for (int e = 0; e < numEpochs; e++){
|
for (int e = 0; e < numEpochs; e++){
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
while (train.hasNext()) {
|
while (train.hasNext()) {
|
||||||
val data = train.next().getFeatures(0);
|
val data = train.next().getFeatures(0);
|
||||||
|
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTrain++;
|
gcntTrain++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
|
||||||
train.reset();
|
train.reset();
|
||||||
|
|
||||||
|
|
||||||
while (test.hasNext()) {
|
while (test.hasNext()) {
|
||||||
val data = test.next().getFeatures(0);
|
val data = test.next().getFeatures(0);
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
gcntTest++;
|
gcntTest++;
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
|
||||||
// shifting underlying iterator by one
|
// shifting underlying iterator by one
|
||||||
train.hasNext();
|
train.hasNext();
|
||||||
back.shift();
|
back.shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assertEquals(1000 * numEpochs, global);
|
||||||
|
});
|
||||||
|
|
||||||
assertEquals(1000 * numEpochs, global);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -185,11 +188,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
|
|
||||||
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
|
assertEquals( (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
|
||||||
}
|
}
|
||||||
globalIter++;
|
globalIter++;
|
||||||
}
|
}
|
||||||
assertTrue("Failed at epoch [" + e + "]", trained);
|
assertTrue(trained,"Failed at epoch [" + e + "]");
|
||||||
assertEquals(800, globalIter);
|
assertEquals(800, globalIter);
|
||||||
|
|
||||||
|
|
||||||
|
@ -202,11 +205,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
|
|
||||||
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
|
assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
|
||||||
}
|
}
|
||||||
globalIter++;
|
globalIter++;
|
||||||
}
|
}
|
||||||
assertTrue("Failed at epoch [" + e + "]", tested);
|
assertTrue(tested,"Failed at epoch [" + e + "]");
|
||||||
assertEquals(900, globalIter);
|
assertEquals(900, globalIter);
|
||||||
|
|
||||||
// validation set is used every 5 epochs
|
// validation set is used every 5 epochs
|
||||||
|
@ -219,11 +222,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
|
|
||||||
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
|
assertEquals( (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
|
||||||
}
|
}
|
||||||
globalIter++;
|
globalIter++;
|
||||||
}
|
}
|
||||||
assertTrue("Failed at epoch [" + e + "]", validated);
|
assertTrue(validated,"Failed at epoch [" + e + "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
// all 3 iterators have exactly 1000 elements combined
|
// all 3 iterators have exactly 1000 elements combined
|
||||||
|
@ -256,8 +259,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
val data = partIterator.next().getFeatures();
|
val data = partIterator.next().getFeatures();
|
||||||
|
|
||||||
for (int i = 0; i < data.length; ++i) {
|
for (int i = 0; i < data.length; ++i) {
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
|
assertEquals((float) perEpoch, data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
(float) perEpoch, data[i].getFloat(0), 1e-5);
|
|
||||||
}
|
}
|
||||||
//gcntTrain++;
|
//gcntTrain++;
|
||||||
global++;
|
global++;
|
||||||
|
@ -299,12 +301,12 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
|
|
||||||
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
|
assertEquals((double) globalIter,
|
||||||
ds.getFeatures()[i].getDouble(0), 1e-5f);
|
ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
|
||||||
}
|
}
|
||||||
globalIter++;
|
globalIter++;
|
||||||
}
|
}
|
||||||
assertTrue("Failed at epoch [" + e + "]", trained);
|
assertTrue(trained,"Failed at epoch [" + e + "]");
|
||||||
assertEquals(800, globalIter);
|
assertEquals(800, globalIter);
|
||||||
|
|
||||||
|
|
||||||
|
@ -316,11 +318,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
val ds = testIter.next();
|
val ds = testIter.next();
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
|
assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
|
||||||
}
|
}
|
||||||
globalIter++;
|
globalIter++;
|
||||||
}
|
}
|
||||||
assertTrue("Failed at epoch [" + e + "]", tested);
|
assertTrue(tested,"Failed at epoch [" + e + "]");
|
||||||
assertEquals(900, globalIter);
|
assertEquals(900, globalIter);
|
||||||
|
|
||||||
// validation set is used every 5 epochs
|
// validation set is used every 5 epochs
|
||||||
|
@ -333,12 +335,12 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
|
|
||||||
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
|
assertEquals((double) globalIter,
|
||||||
ds.getFeatures()[i].getDouble(0), 1e-5f);
|
ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
|
||||||
}
|
}
|
||||||
globalIter++;
|
globalIter++;
|
||||||
}
|
}
|
||||||
assertTrue("Failed at epoch [" + e + "]", validated);
|
assertTrue(validated,"Failed at epoch [" + e + "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
// all 3 iterators have exactly 1000 elements combined
|
// all 3 iterators have exactly 1000 elements combined
|
||||||
|
@ -370,7 +372,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
int farCnt = (1000 / 2) * (partNumber) + cnt;
|
int farCnt = (1000 / 2) * (partNumber) + cnt;
|
||||||
val data = iteratorList.get(partNumber).next().getFeatures();
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
for (int i = 0; i < data.length; ++i) {
|
for (int i = 0; i < data.length; ++i) {
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data[i].getFloat(0), 1e-5);
|
assertEquals( (float) farCnt, data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
}
|
}
|
||||||
cnt++;
|
cnt++;
|
||||||
global++;
|
global++;
|
||||||
|
@ -381,8 +383,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (iteratorList.get(0).hasNext()) {
|
while (iteratorList.get(0).hasNext()) {
|
||||||
val data = iteratorList.get(0).next().getFeatures();
|
val data = iteratorList.get(0).next().getFeatures();
|
||||||
for (int i = 0; i < data.length; ++i) {
|
for (int i = 0; i < data.length; ++i) {
|
||||||
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++,
|
assertEquals((float) cnt++,
|
||||||
data[i].getFloat(0), 1e-5);
|
data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
|
||||||
}
|
}
|
||||||
global++;
|
global++;
|
||||||
}
|
}
|
||||||
|
@ -402,7 +404,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (iteratorList.get(partNumber).hasNext()) {
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
val data = iteratorList.get(partNumber).next().getFeatures();
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
for (int i = 0; i < data.length; ++i) {
|
for (int i = 0; i < data.length; ++i) {
|
||||||
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5);
|
assertEquals( (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt);
|
||||||
}
|
}
|
||||||
cnt++;
|
cnt++;
|
||||||
}
|
}
|
||||||
|
@ -427,8 +429,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
while (iteratorList.get(partNumber).hasNext()) {
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
val data = iteratorList.get(partNumber).next().getFeatures();
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
for (int i = 0; i < data.length; ++i) {
|
for (int i = 0; i < data.length; ++i) {
|
||||||
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt),
|
assertEquals( (float) (500 * partNumber + cnt),
|
||||||
data[i].getFloat(0), 1e-5);
|
data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt);
|
||||||
}
|
}
|
||||||
cnt++;
|
cnt++;
|
||||||
}
|
}
|
||||||
|
@ -454,8 +456,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
val ds = validationIter.next();
|
val ds = validationIter.next();
|
||||||
assertNotNull(ds);
|
assertNotNull(ds);
|
||||||
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90,
|
assertEquals((float) valCnt + 90,
|
||||||
ds.getFeatures()[i].getFloat(0), 1e-5);
|
ds.getFeatures()[i].getFloat(0), 1e-5,"Validation failed on iteration " + valCnt);
|
||||||
}
|
}
|
||||||
valCnt++;
|
valCnt++;
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,9 +25,9 @@ import org.datavec.api.split.FileSplit;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
import org.deeplearning4j.nn.util.TestDataSetConsumer;
|
import org.deeplearning4j.nn.util.TestDataSetConsumer;
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.Timeout;
|
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -42,9 +42,6 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
@DisplayName("Multiple Epochs Iterator Test")
|
@DisplayName("Multiple Epochs Iterator Test")
|
||||||
class MultipleEpochsIteratorTest extends BaseDL4JTest {
|
class MultipleEpochsIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
|
||||||
public Timeout timeout = Timeout.seconds(300);
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Test Next And Reset")
|
@DisplayName("Test Next And Reset")
|
||||||
void testNextAndReset() throws Exception {
|
void testNextAndReset() throws Exception {
|
||||||
|
|
|
@ -22,8 +22,8 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
import org.junit.Ignore;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
|
@ -33,9 +33,9 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static junit.framework.TestCase.assertTrue;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@Ignore
|
@Disabled
|
||||||
public class TestAsyncIterator extends BaseDL4JTest {
|
public class TestAsyncIterator extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue