All tests compile

master
agibsonccc 2021-03-16 11:57:24 +09:00
parent b1229432d6
commit 82bdcc21d2
729 changed files with 6080 additions and 5619 deletions

View File

@ -25,7 +25,7 @@ import org.datavec.api.records.reader.impl.csv.CSVLineSequenceRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;

View File

@ -25,7 +25,7 @@ import org.datavec.api.records.reader.impl.csv.CSVMultiSequenceRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;

View File

@ -26,7 +26,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;

View File

@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;

View File

@ -27,7 +27,7 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;

View File

@ -30,7 +30,7 @@ import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;

View File

@ -29,7 +29,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;

View File

@ -32,7 +32,7 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;

View File

@ -26,14 +26,14 @@ import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader; import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestCollectionRecordReaders extends BaseND4JTest { public class TestCollectionRecordReaders extends BaseND4JTest {

View File

@ -23,11 +23,11 @@ package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestConcatenatingRecordReader extends BaseND4JTest { public class TestConcatenatingRecordReader extends BaseND4JTest {

View File

@ -37,7 +37,7 @@ import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
@ -47,7 +47,7 @@ import java.io.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSerialization extends BaseND4JTest { public class TestSerialization extends BaseND4JTest {

View File

@ -30,7 +30,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -38,8 +38,8 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TransformProcessRecordReaderTests extends BaseND4JTest { public class TransformProcessRecordReaderTests extends BaseND4JTest {

View File

@ -26,7 +26,7 @@ import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter; import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.io.labels.PatternPathLabelGenerator; import org.datavec.api.io.labels.PatternPathLabelGenerator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.io.*; import java.io.*;
import java.net.URI; import java.net.URI;
@ -35,7 +35,7 @@ import java.util.ArrayList;
import java.util.Random; import java.util.Random;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
/** /**
* *

View File

@ -20,13 +20,12 @@
package org.datavec.api.split; package org.datavec.api.split;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI; import java.net.URI;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.Assert.assertTrue;
public class NumberedFileInputSplitTests extends BaseND4JTest { public class NumberedFileInputSplitTests extends BaseND4JTest {
@Test @Test
@ -69,60 +68,81 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithLeadingSpaces() { public void testNumberedFileInputSplitWithLeadingSpaces() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix-%5d.suffix"; String baseString = "/path/to/files/prefix-%5d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() { public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() {
assertThrows(IllegalArgumentException.class, () -> {
String baseString = "/path/to/files/prefix%5d.suffix"; String baseString = "/path/to/files/prefix%5d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithLeadingPlusInPadding() { public void testNumberedFileInputSplitWithLeadingPlusInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%+5d.suffix"; String baseString = "/path/to/files/prefix%+5d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithLeadingMinusInPadding() { public void testNumberedFileInputSplitWithLeadingMinusInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%-5d.suffix"; String baseString = "/path/to/files/prefix%-5d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithTwoDigitsInPadding() { public void testNumberedFileInputSplitWithTwoDigitsInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%011d.suffix"; String baseString = "/path/to/files/prefix%011d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithInnerZerosInPadding() { public void testNumberedFileInputSplitWithInnerZerosInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%101d.suffix"; String baseString = "/path/to/files/prefix%101d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() { public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%0505d.suffix"; String baseString = "/path/to/files/prefix%0505d.suffix";
int minIdx = 0; int minIdx = 0;
int maxIdx = 10; int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx); runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
} }
@ -135,7 +155,7 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
String path = locs[j++].getPath(); String path = locs[j++].getPath();
String exp = String.format(baseString, i); String exp = String.format(baseString, i);
String msg = exp + " vs " + path; String msg = exp + " vs " + path;
assertTrue(msg, path.endsWith(exp)); //Note: on Windows, Java can prepend drive to path - "/C:/" assertTrue(path.endsWith(exp),msg); //Note: on Windows, Java can prepend drive to path - "/C:/"
} }
} }
} }

View File

@ -25,9 +25,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.function.Function; import org.nd4j.common.function.Function;
@ -37,22 +38,22 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.URI; import java.net.URI;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
public class TestStreamInputSplit extends BaseND4JTest { public class TestStreamInputSplit extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testCsvSimple() throws Exception { public void testCsvSimple(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder(); File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt"); File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt"); File f2 = new File(dir, "file2.txt");
@ -93,9 +94,9 @@ public class TestStreamInputSplit extends BaseND4JTest {
@Test @Test
public void testCsvSequenceSimple() throws Exception { public void testCsvSequenceSimple(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder(); File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt"); File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt"); File f2 = new File(dir, "file2.txt");
@ -137,8 +138,8 @@ public class TestStreamInputSplit extends BaseND4JTest {
} }
@Test @Test
public void testShuffle() throws Exception { public void testShuffle(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder(); File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt"); File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt"); File f2 = new File(dir, "file2.txt");
File f3 = new File(dir, "file3.txt"); File f3 = new File(dir, "file3.txt");

View File

@ -27,14 +27,14 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.split.partition.PartitionMetaData; import org.datavec.api.split.partition.PartitionMetaData;
import org.datavec.api.split.partition.Partitioner; import org.datavec.api.split.partition.Partitioner;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
import java.io.OutputStream; import java.io.OutputStream;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
public class PartitionerTests extends BaseND4JTest { public class PartitionerTests extends BaseND4JTest {
@Test @Test

View File

@ -29,12 +29,12 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestTransformProcess extends BaseND4JTest { public class TestTransformProcess extends BaseND4JTest {

View File

@ -27,13 +27,13 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.TestTransforms; import org.datavec.api.transform.transform.TestTransforms;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConditions extends BaseND4JTest { public class TestConditions extends BaseND4JTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
@ -36,8 +36,8 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestFilters extends BaseND4JTest { public class TestFilters extends BaseND4JTest {

View File

@ -26,19 +26,22 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class TestJoin extends BaseND4JTest { public class TestJoin extends BaseND4JTest {
@Test @Test
public void testJoin() { public void testJoin(@TempDir Path testDir) {
Schema firstSchema = Schema firstSchema =
new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build(); new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build();
@ -46,20 +49,20 @@ public class TestJoin extends BaseND4JTest {
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build(); Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build();
List<List<Writable>> first = new ArrayList<>(); List<List<Writable>> first = new ArrayList<>();
first.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1))); first.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1)));
first.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11))); first.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11)));
List<List<Writable>> second = new ArrayList<>(); List<List<Writable>> second = new ArrayList<>();
second.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(100))); second.add(Arrays.asList(new Text("key0"), new IntWritable(100)));
second.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(110))); second.add(Arrays.asList(new Text("key1"), new IntWritable(110)));
Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn") Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn")
.setSchemas(firstSchema, secondSchema).build(); .setSchemas(firstSchema, secondSchema).build();
List<List<Writable>> expected = new ArrayList<>(); List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1), expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1),
new IntWritable(100))); new IntWritable(100)));
expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11), expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11),
new IntWritable(110))); new IntWritable(110)));
@ -94,9 +97,9 @@ public class TestJoin extends BaseND4JTest {
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testJoinValidation() { public void testJoinValidation() {
assertThrows(IllegalArgumentException.class,() -> {
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1") Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build(); .build();
@ -104,11 +107,13 @@ public class TestJoin extends BaseND4JTest {
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist") new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
.setSchemas(firstSchema, secondSchema).build(); .setSchemas(firstSchema, secondSchema).build();
});
} }
@Test(expected = IllegalArgumentException.class) @Test()
public void testJoinValidation2() { public void testJoinValidation2() {
assertThrows(IllegalArgumentException.class,() -> {
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1") Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build(); .build();
@ -116,5 +121,7 @@ public class TestJoin extends BaseND4JTest {
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema) new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
.build(); .build();
});
} }
} }

View File

@ -19,17 +19,18 @@
*/ */
package org.datavec.api.transform.ops; package org.datavec.api.transform.ops;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Aggregator Impls Test") @DisplayName("Aggregator Impls Test")
class AggregatorImplsTest extends BaseND4JTest { class AggregatorImplsTest extends BaseND4JTest {
@ -265,12 +266,12 @@ class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, cu.get().toInt()); assertEquals(9, cu.get().toInt());
} }
@Rule
public final ExpectedException exception = ExpectedException.none();
@Test @Test
@DisplayName("Incompatible Aggregator Test") @DisplayName("Incompatible Aggregator Test")
void incompatibleAggregatorTest() { void incompatibleAggregatorTest() {
assertThrows(UnsupportedOperationException.class,() -> {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>(); AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i)); sm.accept(intList.get(i));
@ -280,8 +281,10 @@ class AggregatorImplsTest extends BaseND4JTest {
for (int i = 0; i < intList.size(); i++) { for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1)); reverse.accept(intList.get(intList.size() - i - 1));
} }
exception.expect(UnsupportedOperationException.class);
sm.combine(reverse); sm.combine(reverse);
assertEquals(45, sm.get().toInt()); assertEquals(45, sm.get().toInt());
});
} }
} }

View File

@ -32,13 +32,13 @@ import org.datavec.api.transform.ops.AggregableMultiOp;
import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
public class TestMultiOpReduce extends BaseND4JTest { public class TestMultiOpReduce extends BaseND4JTest {
@ -46,10 +46,10 @@ public class TestMultiOpReduce extends BaseND4JTest {
public void testMultiOpReducerDouble() { public void testMultiOpReducerDouble() {
List<List<Writable>> inputs = new ArrayList<>(); List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(0))); inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(0)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(1))); inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(1)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2))); inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2))); inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
Map<ReduceOp, Double> exp = new LinkedHashMap<>(); Map<ReduceOp, Double> exp = new LinkedHashMap<>();
exp.put(ReduceOp.Min, 0.0); exp.put(ReduceOp.Min, 0.0);
@ -82,7 +82,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey")); assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString(); String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5); assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
} }
} }
@ -126,7 +126,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey")); assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString(); String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5); assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
} }
} }
@ -210,7 +210,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey")); assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString(); String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5); assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
} }
for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean, for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean,

View File

@ -24,13 +24,13 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction; import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReductions extends BaseND4JTest { public class TestReductions extends BaseND4JTest {

View File

@ -22,10 +22,10 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJsonYaml extends BaseND4JTest { public class TestJsonYaml extends BaseND4JTest {

View File

@ -21,10 +21,10 @@
package org.datavec.api.transform.schema; package org.datavec.api.transform.schema;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSchemaMethods extends BaseND4JTest { public class TestSchemaMethods extends BaseND4JTest {

View File

@ -33,7 +33,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
@ -41,7 +41,7 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReduceSequenceByWindowFunction extends BaseND4JTest { public class TestReduceSequenceByWindowFunction extends BaseND4JTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
@ -35,7 +35,7 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSequenceSplit extends BaseND4JTest { public class TestSequenceSplit extends BaseND4JTest {

View File

@ -29,7 +29,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
@ -37,7 +37,7 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWindowFunctions extends BaseND4JTest { public class TestWindowFunctions extends BaseND4JTest {

View File

@ -26,10 +26,10 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.testClasses.CustomCondition; import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter; import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform; import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCustomTransformJsonYaml extends BaseND4JTest { public class TestCustomTransformJsonYaml extends BaseND4JTest {

View File

@ -64,13 +64,13 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestYamlJsonSerde extends BaseND4JTest { public class TestYamlJsonSerde extends BaseND4JTest {

View File

@ -24,12 +24,12 @@ import org.datavec.api.transform.StringReduceOp;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReduce extends BaseND4JTest { public class TestReduce extends BaseND4JTest {

View File

@ -50,7 +50,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.comparator.LongWritableComparator; import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -61,7 +61,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class RegressionTestJson extends BaseND4JTest { public class RegressionTestJson extends BaseND4JTest {

View File

@ -50,13 +50,13 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.comparator.LongWritableComparator; import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJsonYaml extends BaseND4JTest { public class TestJsonYaml extends BaseND4JTest {

View File

@ -59,7 +59,7 @@ import org.datavec.api.writable.*;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -72,7 +72,7 @@ import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.assertEquals;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestTransforms extends BaseND4JTest { public class TestTransforms extends BaseND4JTest {

View File

@ -29,7 +29,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,7 +39,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayWritableTransforms extends BaseND4JTest { public class TestNDArrayWritableTransforms extends BaseND4JTest {

View File

@ -30,13 +30,13 @@ import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer; import org.datavec.api.transform.serde.YamlSerializer;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestYamlJsonSerde extends BaseND4JTest { public class TestYamlJsonSerde extends BaseND4JTest {

View File

@ -35,26 +35,26 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.File; import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestUI extends BaseND4JTest { public class TestUI extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testUI() throws Exception { public void testUI(@TempDir Path testDir) throws Exception {
Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn") Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn")
.addColumnInteger("IntColumn2").addColumnInteger("IntColumn3") .addColumnInteger("IntColumn2").addColumnInteger("IntColumn3")
.addColumnTime("TimeColumn", DateTimeZone.UTC).build(); .addColumnTime("TimeColumn", DateTimeZone.UTC).build();
@ -92,7 +92,7 @@ public class TestUI extends BaseND4JTest {
DataAnalysis da = new DataAnalysis(schema, list); DataAnalysis da = new DataAnalysis(schema, list);
File fDir = testDir.newFolder(); File fDir = testDir.toFile();
String tempDir = fDir.getAbsolutePath(); String tempDir = fDir.getAbsolutePath();
String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html"); String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html");
System.out.println(outPath); System.out.println(outPath);
@ -143,7 +143,7 @@ public class TestUI extends BaseND4JTest {
@Test @Test
@Ignore @Disabled
public void testSequencePlot() throws Exception { public void testSequencePlot() throws Exception {
Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx") Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx")

View File

@ -21,14 +21,14 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.io.*; import java.io.*;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestNDArrayWritableAndSerialization extends BaseND4JTest { public class TestNDArrayWritableAndSerialization extends BaseND4JTest {

View File

@ -41,7 +41,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;

View File

@ -29,16 +29,16 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter; import org.datavec.arrow.ArrowConverter;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@ -69,7 +69,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
assertEquals(3,fieldVectors.size()); assertEquals(3,fieldVectors.size());
for(FieldVector fieldVector : fieldVectors) { for(FieldVector fieldVector : fieldVectors) {
for(int i = 0; i < fieldVector.getValueCount(); i++) { for(int i = 0; i < fieldVector.getValueCount(); i++) {
assertFalse("Index " + i + " was null for field vector " + fieldVector, fieldVector.isNull(i)); assertFalse( fieldVector.isNull(i),"Index " + i + " was null for field vector " + fieldVector);
} }
} }
@ -79,7 +79,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@Test @Test
//not worried about this till after next release //not worried about this till after next release
@Ignore @Disabled
public void testVariableLengthTS() { public void testVariableLengthTS() {
Schema.Builder schema = new Schema.Builder() Schema.Builder schema = new Schema.Builder()
.addColumnString("str") .addColumnString("str")

View File

@ -23,7 +23,7 @@ import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;

View File

@ -22,8 +22,8 @@ package org.datavec.image.loader;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import java.io.File; import java.io.File;
@ -32,9 +32,9 @@ import java.io.InputStream;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* *
@ -182,7 +182,7 @@ public class LoaderTests {
} }
@Ignore // Use when confirming data is getting stored @Disabled // Use when confirming data is getting stored
@Test @Test
public void testProcessCifar() { public void testProcessCifar() {
int row = 32; int row = 32;
@ -208,15 +208,15 @@ public class LoaderTests {
int minibatch = 100; int minibatch = 100;
int nMinibatches = 50000 / minibatch; int nMinibatches = 50000 / minibatch;
for( int i=0; i<nMinibatches; i++ ){ for( int i=0; i < nMinibatches; i++) {
DataSet ds = loader.next(minibatch); DataSet ds = loader.next(minibatch);
String s = String.valueOf(i); String s = String.valueOf(i);
assertNotNull(s, ds.getFeatures()); assertNotNull(ds.getFeatures(),s);
assertNotNull(s, ds.getLabels()); assertNotNull(ds.getLabels(),s);
assertEquals(s, minibatch, ds.getFeatures().size(0)); assertEquals(minibatch, ds.getFeatures().size(0),s);
assertEquals(s, minibatch, ds.getLabels().size(0)); assertEquals(minibatch, ds.getLabels().size(0),s);
assertEquals(s, 10, ds.getLabels().size(1)); assertEquals(10, ds.getLabels().size(1),s);
} }
} }

View File

@ -21,7 +21,7 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -32,7 +32,7 @@ import java.io.FileInputStream;
import java.io.InputStream; import java.io.InputStream;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestImageLoader { public class TestImageLoader {

View File

@ -30,9 +30,10 @@ import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,16 +43,17 @@ import org.nd4j.common.io.ClassPathResource;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.*; import java.io.*;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.Random; import java.util.Random;
import org.bytedeco.leptonica.*; import org.bytedeco.leptonica.*;
import org.bytedeco.opencv.opencv_core.*; import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.leptonica.global.lept.*; import static org.bytedeco.leptonica.global.lept.*;
import static org.bytedeco.opencv.global.opencv_core.*; import static org.bytedeco.opencv.global.opencv_core.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
/** /**
* *
@ -62,8 +64,6 @@ public class TestNativeImageLoader {
static final long seed = 10; static final long seed = 10;
static final Random rng = new Random(seed); static final Random rng = new Random(seed);
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testConvertPix() throws Exception { public void testConvertPix() throws Exception {
@ -566,8 +566,8 @@ public class TestNativeImageLoader {
@Test @Test
public void testNativeImageLoaderEmptyStreams() throws Exception { public void testNativeImageLoaderEmptyStreams(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder(); File dir = testDir.toFile();
File f = new File(dir, "myFile.jpg"); File f = new File(dir, "myFile.jpg");
f.createNewFile(); f.createNewFile();
@ -578,7 +578,7 @@ public class TestNativeImageLoader {
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg.contains("decode image"),msg);
} }
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
@ -586,7 +586,7 @@ public class TestNativeImageLoader {
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg.contains("decode image"),msg);
} }
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
@ -594,7 +594,7 @@ public class TestNativeImageLoader {
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue(msg.contains("decode image"),msg);
} }
try(InputStream is = new FileInputStream(f)){ try(InputStream is = new FileInputStream(f)){
@ -603,7 +603,7 @@ public class TestNativeImageLoader {
fail("Expected exception"); fail("Expected exception");
} catch (IOException e){ } catch (IOException e){
String msg = e.getMessage(); String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image")); assertTrue( msg.contains("decode image"),msg);
} }
} }

View File

@ -27,7 +27,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.loader.NativeImageLoader;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.loader.FileBatch; import org.nd4j.common.loader.FileBatch;

View File

@ -36,9 +36,10 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -46,28 +47,30 @@ import org.nd4j.common.io.ClassPathResource;
import java.io.*; import java.io.*;
import java.net.URI; import java.net.URI;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestImageRecordReader { public class TestImageRecordReader {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test(expected = IllegalArgumentException.class) @Test()
public void testEmptySplit() throws IOException { public void testEmptySplit() throws IOException {
InputSplit data = new CollectionInputSplit(new ArrayList<URI>()); assertThrows(IllegalArgumentException.class,() -> {
InputSplit data = new CollectionInputSplit(new ArrayList<>());
new ImageRecordReader().initialize(data, null); new ImageRecordReader().initialize(data, null);
});
} }
@Test @Test
public void testMetaData() throws IOException { public void testMetaData(@TempDir Path testDir) throws IOException {
File parentDir = testDir.newFolder(); File parentDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir);
// System.out.println(f.getAbsolutePath()); // System.out.println(f.getAbsolutePath());
// System.out.println(f.getParentFile().getParentFile().getAbsolutePath()); // System.out.println(f.getParentFile().getParentFile().getAbsolutePath());
@ -104,11 +107,11 @@ public class TestImageRecordReader {
} }
@Test @Test
public void testImageRecordReaderLabelsOrder() throws Exception { public void testImageRecordReaderLabelsOrder(@TempDir Path testDir) throws Exception {
//Labels order should be consistent, regardless of file iteration order //Labels order should be consistent, regardless of file iteration order
//Idea: labels order should be consistent regardless of input file order //Idea: labels order should be consistent regardless of input file order
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
File f0 = new File(f, "/class0/0.jpg"); File f0 = new File(f, "/class0/0.jpg");
File f1 = new File(f, "/class1/A.jpg"); File f1 = new File(f, "/class1/A.jpg");
@ -135,11 +138,11 @@ public class TestImageRecordReader {
@Test @Test
public void testImageRecordReaderRandomization() throws Exception { public void testImageRecordReaderRandomization(@TempDir Path testDir) throws Exception {
//Order of FileSplit+ImageRecordReader should be different after reset //Order of FileSplit+ImageRecordReader should be different after reset
//Idea: labels order should be consistent regardless of input file order //Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder(); File f0 = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs = new FileSplit(f0, new Random(12345)); FileSplit fs = new FileSplit(f0, new Random(12345));
@ -189,13 +192,13 @@ public class TestImageRecordReader {
@Test @Test
public void testImageRecordReaderRegression() throws Exception { public void testImageRecordReaderRegression(@TempDir Path testDir) throws Exception {
PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen(); PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen();
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen); ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen);
File rootDir = testDir.newFolder(); File rootDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
FileSplit fs = new FileSplit(rootDir); FileSplit fs = new FileSplit(rootDir);
rr.initialize(fs); rr.initialize(fs);
@ -244,10 +247,10 @@ public class TestImageRecordReader {
} }
@Test @Test
public void testListenerInvocationBatch() throws IOException { public void testListenerInvocationBatch(@TempDir Path testDir) throws IOException {
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker); ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
File parent = f; File parent = f;
@ -260,10 +263,10 @@ public class TestImageRecordReader {
} }
@Test @Test
public void testListenerInvocationSingle() throws IOException { public void testListenerInvocationSingle(@TempDir Path testDir) throws IOException {
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker); ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
File parent = testDir.newFolder(); File parent = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent); new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent);
int numFiles = parent.list().length; int numFiles = parent.list().length;
rr.initialize(new FileSplit(parent)); rr.initialize(new FileSplit(parent));
@ -315,7 +318,7 @@ public class TestImageRecordReader {
@Test @Test
public void testImageRecordReaderPathMultiLabelGenerator() throws Exception { public void testImageRecordReaderPathMultiLabelGenerator(@TempDir Path testDir) throws Exception {
Nd4j.setDataType(DataType.FLOAT); Nd4j.setDataType(DataType.FLOAT);
//Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively //Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively
// PLUS single value (Writable) regression label // PLUS single value (Writable) regression label
@ -324,7 +327,7 @@ public class TestImageRecordReader {
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen); ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen);
File rootDir = testDir.newFolder(); File rootDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
FileSplit fs = new FileSplit(rootDir); FileSplit fs = new FileSplit(rootDir);
rr.initialize(fs); rr.initialize(fs);
@ -471,9 +474,9 @@ public class TestImageRecordReader {
@Test @Test
public void testNCHW_NCHW() throws Exception { public void testNCHW_NCHW(@TempDir Path testDir) throws Exception {
//Idea: labels order should be consistent regardless of input file order //Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder(); File f0 = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0); new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs0 = new FileSplit(f0, new Random(12345)); FileSplit fs0 = new FileSplit(f0, new Random(12345));

View File

@ -35,9 +35,10 @@ import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform; import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform; import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform; import org.datavec.image.transform.ResizeImageTransform;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.BooleanIndexing;
@ -46,24 +47,24 @@ import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.net.URI; import java.net.URI;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestObjectDetectionRecordReader { public class TestObjectDetectionRecordReader {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void test() throws Exception { public void test(@TempDir Path testDir) throws Exception {
for(boolean nchw : new boolean[]{true, false}) { for(boolean nchw : new boolean[]{true, false}) {
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider(); ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f); new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
String path = new File(f, "000012.jpg").getParent(); String path = new File(f, "000012.jpg").getParent();

View File

@ -21,27 +21,27 @@
package org.datavec.image.recordreader.objdetect; package org.datavec.image.recordreader.objdetect;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider; import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.nio.file.Path;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestVocLabelProvider { public class TestVocLabelProvider {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testVocLabelProvider() throws Exception { public void testVocLabelProvider(@TempDir Path testDir) throws Exception {
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f); new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f);
String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent(); String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent();

View File

@ -28,8 +28,8 @@ import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.loader.NativeImageLoader;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.awt.*; import java.awt.*;
import java.util.LinkedList; import java.util.LinkedList;
@ -40,7 +40,7 @@ import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_core.*; import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*; import static org.bytedeco.opencv.global.opencv_imgproc.*;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
/** /**
* *
@ -255,7 +255,7 @@ public class TestImageTransform {
assertEquals(22, transformed[1], 0); assertEquals(22, transformed[1], 0);
} }
@Ignore @Disabled
@Test @Test
public void testFilterImageTransform() throws Exception { public void testFilterImageTransform() throws Exception {
ImageWritable writable = makeRandomImage(0, 0, 4); ImageWritable writable = makeRandomImage(0, 0, 4);

View File

@ -25,7 +25,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.primitives.Triple; import org.nd4j.common.primitives.Triple;

View File

@ -49,7 +49,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;

View File

@ -36,14 +36,14 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class LocalTransformProcessRecordReaderTests { public class LocalTransformProcessRecordReaderTests {

View File

@ -29,9 +29,9 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.AnalyzeLocal; import org.datavec.local.transforms.AnalyzeLocal;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -39,12 +39,11 @@ import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestAnalyzeLocal { public class TestAnalyzeLocal {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testAnalysisBasic() throws Exception { public void testAnalysisBasic() throws Exception {
@ -72,7 +71,7 @@ public class TestAnalyzeLocal {
INDArray mean = arr.mean(0); INDArray mean = arr.mean(0);
INDArray std = arr.std(0); INDArray std = arr.std(0);
for( int i=0; i<5; i++ ){ for( int i = 0; i < 5; i++) {
double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean(); double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev(); double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
assertEquals(mean.getDouble(i), m, 1e-3); assertEquals(mean.getDouble(i), m, 1e-3);

View File

@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -36,8 +36,8 @@ import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLineRecordReaderFunction { public class TestLineRecordReaderFunction {

View File

@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction; import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +33,7 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayToWritablesFunction { public class TestNDArrayToWritablesFunction {

View File

@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction; import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToNDArrayFunction { public class TestWritablesToNDArrayFunction {

View File

@ -30,12 +30,12 @@ import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction; import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
import org.datavec.local.transforms.misc.WritablesToStringFunction; import org.datavec.local.transforms.misc.WritablesToStringFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToStringFunctions { public class TestWritablesToStringFunctions {

View File

@ -32,7 +32,8 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.*; import java.io.*;
@ -40,14 +41,14 @@ import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
/** /**
* @author saudet * @author saudet
*/ */
public class TestGeoTransforms { public class TestGeoTransforms {
@BeforeClass @BeforeAll
public static void beforeClass() throws Exception { public static void beforeClass() throws Exception {
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing //Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile(); File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();

View File

@ -30,7 +30,8 @@ import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.python.PythonCondition; import org.datavec.python.PythonCondition;
import org.datavec.python.PythonTransform; import org.datavec.python.PythonTransform;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -43,7 +44,7 @@ import java.util.List;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.datavec.api.transform.schema.Schema.Builder; import static org.datavec.api.transform.schema.Schema.Builder;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
@NotThreadSafe @NotThreadSafe
public class TestPythonTransformProcess { public class TestPythonTransformProcess {
@ -77,8 +78,9 @@ public class TestPythonTransformProcess {
} }
@Test(timeout = 60000L) @Test()
public void testMixedTypes() throws Exception{ @Timeout(60000L)
public void testMixedTypes() throws Exception {
Builder schemaBuilder = new Builder(); Builder schemaBuilder = new Builder();
schemaBuilder schemaBuilder
.addColumnInteger("col1") .addColumnInteger("col1")
@ -99,7 +101,7 @@ public class TestPythonTransformProcess {
.inputSchema(initialSchema) .inputSchema(initialSchema)
.build() ).build(); .build() ).build();
List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10), List<Writable> inputs = Arrays.asList(new IntWritable(10),
new FloatWritable(3.5f), new FloatWritable(3.5f),
new Text("5"), new Text("5"),
new DoubleWritable(2.0) new DoubleWritable(2.0)
@ -109,8 +111,9 @@ public class TestPythonTransformProcess {
assertEquals(((LongWritable)outputs.get(4)).get(), 36); assertEquals(((LongWritable)outputs.get(4)).get(), 36);
} }
@Test(timeout = 60000L) @Test()
public void testNDArray() throws Exception{ @Timeout(60000L)
public void testNDArray() throws Exception {
long[] shape = new long[]{3, 2}; long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape); INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape); INDArray arr2 = Nd4j.rand(shape);
@ -145,8 +148,9 @@ public class TestPythonTransformProcess {
} }
@Test(timeout = 60000L) @Test()
public void testNDArray2() throws Exception{ @Timeout(60000L)
public void testNDArray2() throws Exception {
long[] shape = new long[]{3, 2}; long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape); INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape); INDArray arr2 = Nd4j.rand(shape);
@ -181,7 +185,8 @@ public class TestPythonTransformProcess {
} }
@Test(timeout = 60000L) @Test()
@Timeout(60000L)
public void testNDArrayMixed() throws Exception{ public void testNDArrayMixed() throws Exception{
long[] shape = new long[]{3, 2}; long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape); INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
@ -217,7 +222,8 @@ public class TestPythonTransformProcess {
} }
@Test(timeout = 60000L) @Test()
@Timeout(60000L)
public void testPythonFilter() { public void testPythonFilter() {
Schema schema = new Builder().addColumnInteger("column").build(); Schema schema = new Builder().addColumnInteger("column").build();
@ -237,8 +243,9 @@ public class TestPythonTransformProcess {
} }
@Test(timeout = 60000L) @Test()
public void testPythonFilterAndTransform() throws Exception{ @Timeout(60000L)
public void testPythonFilterAndTransform() throws Exception {
Builder schemaBuilder = new Builder(); Builder schemaBuilder = new Builder();
schemaBuilder schemaBuilder
.addColumnInteger("col1") .addColumnInteger("col1")

View File

@ -28,11 +28,11 @@ import org.datavec.api.writable.*;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJoin { public class TestJoin {

View File

@ -31,13 +31,13 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCalculateSortedRank { public class TestCalculateSortedRank {

View File

@ -31,14 +31,14 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch; import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConvertToSequence { public class TestConvertToSequence {

View File

@ -41,6 +41,12 @@
</properties> </properties>
<dependencies> <dependencies>
<dependency>
<groupId>com.tdunning</groupId>
<artifactId>t-digest</artifactId>
<version>3.2</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.scala-lang</groupId> <groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId> <artifactId>scala-library</artifactId>

View File

@ -25,15 +25,15 @@ import org.apache.spark.serializer.SerializerInstance;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestKryoSerialization extends BaseSparkTest { public class TestKryoSerialization extends BaseSparkTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -35,8 +35,8 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLineRecordReaderFunction extends BaseSparkTest { public class TestLineRecordReaderFunction extends BaseSparkTest {

View File

@ -24,7 +24,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.misc.NDArrayToWritablesFunction; import org.datavec.spark.transform.misc.NDArrayToWritablesFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -32,7 +32,7 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayToWritablesFunction { public class TestNDArrayToWritablesFunction {

View File

@ -38,9 +38,10 @@ import org.datavec.spark.functions.pairdata.PairSequenceRecordReaderBytesFunctio
import org.datavec.spark.functions.pairdata.PathToKeyConverter; import org.datavec.spark.functions.pairdata.PathToKeyConverter;
import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename; import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename;
import org.datavec.spark.util.DataVecSparkUtil; import org.datavec.spark.util.DataVecSparkUtil;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import scala.Tuple2; import scala.Tuple2;
@ -50,16 +51,13 @@ import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest { public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void test() throws Exception { public void test(@TempDir Path testDir) throws Exception {
//Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader //Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader
//For example: use to combine input and labels data from separate files for training a RNN //For example: use to combine input and labels data from separate files for training a RNN
if(Platform.isWindows()) { if(Platform.isWindows()) {
@ -67,7 +65,7 @@ public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
} }
JavaSparkContext sc = getContext(); JavaSparkContext sc = getContext();
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-spark/video/").copyDirectory(f); new ClassPathResource("datavec-spark/video/").copyDirectory(f);
String path = f.getAbsolutePath() + "/*"; String path = f.getAbsolutePath() + "/*";

View File

@ -36,9 +36,10 @@ import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.functions.data.FilesAsBytesFunction; import org.datavec.spark.functions.data.FilesAsBytesFunction;
import org.datavec.spark.functions.data.RecordReaderBytesFunction; import org.datavec.spark.functions.data.RecordReaderBytesFunction;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -48,23 +49,22 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
public class TestRecordReaderBytesFunction extends BaseSparkTest { public class TestRecordReaderBytesFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testRecordReaderBytesFunction() throws Exception { public void testRecordReaderBytesFunction(@TempDir Path testDir) throws Exception {
if(Platform.isWindows()) { if(Platform.isWindows()) {
return; return;
} }
JavaSparkContext sc = getContext(); JavaSparkContext sc = getContext();
//Local file path //Local file path
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f); new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call
String path = f.getAbsolutePath() + "/*"; String path = f.getAbsolutePath() + "/*";

View File

@ -31,30 +31,29 @@ import org.datavec.api.writable.ArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
public class TestRecordReaderFunction extends BaseSparkTest { public class TestRecordReaderFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testRecordReaderFunction() throws Exception { public void testRecordReaderFunction(@TempDir Path testDir) throws Exception {
if(Platform.isWindows()) { if(Platform.isWindows()) {
return; return;
} }
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f); new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call

View File

@ -36,9 +36,10 @@ import org.datavec.codec.reader.CodecRecordReader;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.functions.data.FilesAsBytesFunction; import org.datavec.spark.functions.data.FilesAsBytesFunction;
import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction; import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -47,21 +48,20 @@ import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest { public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testRecordReaderBytesFunction() throws Exception { public void testRecordReaderBytesFunction(@TempDir Path testDir) throws Exception {
if(Platform.isWindows()) { if(Platform.isWindows()) {
return; return;
} }
//Local file path //Local file path
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-spark/video/").copyDirectory(f); new ClassPathResource("datavec-spark/video/").copyDirectory(f);
String path = f.getAbsolutePath() + "/*"; String path = f.getAbsolutePath() + "/*";

View File

@ -33,28 +33,29 @@ import org.datavec.api.writable.ArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.codec.reader.CodecRecordReader; import org.datavec.codec.reader.CodecRecordReader;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.Rule;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.fail;
public class TestSequenceRecordReaderFunction extends BaseSparkTest { public class TestSequenceRecordReaderFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
public void testSequenceRecordReaderFunctionCSV() throws Exception { public void testSequenceRecordReaderFunctionCSV(@TempDir Path testDir) throws Exception {
JavaSparkContext sc = getContext(); JavaSparkContext sc = getContext();
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-spark/csvsequence/").copyDirectory(f); new ClassPathResource("datavec-spark/csvsequence/").copyDirectory(f);
String path = f.getAbsolutePath() + "/*"; String path = f.getAbsolutePath() + "/*";
@ -120,10 +121,10 @@ public class TestSequenceRecordReaderFunction extends BaseSparkTest {
@Test @Test
public void testSequenceRecordReaderFunctionVideo() throws Exception { public void testSequenceRecordReaderFunctionVideo(@TempDir Path testDir) throws Exception {
JavaSparkContext sc = getContext(); JavaSparkContext sc = getContext();
File f = testDir.newFolder(); File f = testDir.toFile();
new ClassPathResource("datavec-spark/video/").copyDirectory(f); new ClassPathResource("datavec-spark/video/").copyDirectory(f);
String path = f.getAbsolutePath() + "/*"; String path = f.getAbsolutePath() + "/*";

View File

@ -22,7 +22,7 @@ package org.datavec.spark.functions;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.spark.transform.misc.WritablesToNDArrayFunction; import org.datavec.spark.transform.misc.WritablesToNDArrayFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -30,7 +30,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToNDArrayFunction { public class TestWritablesToNDArrayFunction {

View File

@ -29,14 +29,14 @@ import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction; import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction;
import org.datavec.spark.transform.misc.WritablesToStringFunction; import org.datavec.spark.transform.misc.WritablesToStringFunction;
import org.junit.Test; import org.junit.jupiter.api.Test;
import scala.Tuple2; import scala.Tuple2;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToStringFunctions extends BaseSparkTest { public class TestWritablesToStringFunctions extends BaseSparkTest {

View File

@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.io.File; import java.io.File;
@ -35,8 +35,8 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestSparkStorageUtils extends BaseSparkTest { public class TestSparkStorageUtils extends BaseSparkTest {

View File

@ -30,13 +30,13 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class DataFramesTests extends BaseSparkTest { public class DataFramesTests extends BaseSparkTest {

View File

@ -28,7 +28,7 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -41,7 +41,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class NormalizationTests extends BaseSparkTest { public class NormalizationTests extends BaseSparkTest {

View File

@ -38,7 +38,7 @@ import org.datavec.local.transforms.AnalyzeLocal;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.AnalyzeSpark; import org.datavec.spark.transform.AnalyzeSpark;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -47,7 +47,7 @@ import java.io.File;
import java.nio.file.Files; import java.nio.file.Files;
import java.util.*; import java.util.*;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class TestAnalysis extends BaseSparkTest { public class TestAnalysis extends BaseSparkTest {

View File

@ -27,11 +27,11 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.*; import java.util.*;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJoin extends BaseSparkTest { public class TestJoin extends BaseSparkTest {

View File

@ -30,13 +30,13 @@ import org.datavec.api.writable.Writable;
import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCalculateSortedRank extends BaseSparkTest { public class TestCalculateSortedRank extends BaseSparkTest {

View File

@ -29,14 +29,14 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConvertToSequence extends BaseSparkTest { public class TestConvertToSequence extends BaseSparkTest {

View File

@ -28,7 +28,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.utils.SparkUtils; import org.datavec.spark.transform.utils.SparkUtils;
import org.junit.Test; import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
@ -36,7 +36,7 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSparkUtil extends BaseSparkTest { public class TestSparkUtil extends BaseSparkTest {

View File

@ -20,7 +20,6 @@
package org.deeplearning4j; package org.deeplearning4j;
import ch.qos.logback.classic.LoggerContext; import ch.qos.logback.classic.LoggerContext;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
@ -32,6 +31,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig; import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory; import org.slf4j.ILoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory; import java.lang.management.ManagementFactory;
import java.util.List; import java.util.List;
@ -39,13 +39,11 @@ import java.util.Map;
import java.util.Properties; import java.util.Properties;
import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
@DisplayName("Base DL 4 J Test") @DisplayName("Base DL 4 J Test")
public abstract class BaseDL4JTest { public abstract class BaseDL4JTest {
private static Logger log = LoggerFactory.getLogger(BaseDL4JTest.class.getName());
protected long startTime; protected long startTime;

View File

@ -43,7 +43,7 @@ import java.lang.reflect.Field;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.*; import java.util.*;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class LayerHelperValidationUtil { public class LayerHelperValidationUtil {
@ -145,7 +145,7 @@ public class LayerHelperValidationUtil {
System.out.println(p1); System.out.println(p1);
System.out.println(p2); System.out.println(p2);
} }
assertTrue(s + " - param changed during forward pass: " + p, maxRE < t.getMaxRelError()); assertTrue(maxRE < t.getMaxRelError(),s + " - param changed during forward pass: " + p);
} }
for( int i=0; i<ff1.size(); i++ ){ for( int i=0; i<ff1.size(); i++ ){
@ -163,7 +163,7 @@ public class LayerHelperValidationUtil {
double d2 = arr2.dup('c').getDouble(idx); double d2 = arr2.dup('c').getDouble(idx);
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE); System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
} }
assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError()); assertTrue(maxRE < t.getMaxRelError(), s + layerName + " activations - max RE: " + maxRE);
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE); log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
} }
@ -180,7 +180,7 @@ public class LayerHelperValidationUtil {
log.info(s + "Output, max relative error: " + maxRE); log.info(s + "Output, max relative error: " + maxRE);
assertEquals(net1NoHelper.params(), net2With.params()); //Check that forward pass does not modify params assertEquals(net1NoHelper.params(), net2With.params()); //Check that forward pass does not modify params
assertTrue(s + "Max RE: " + maxRE, maxRE < t.getMaxRelError()); assertTrue(maxRE < t.getMaxRelError(), s + "Max RE: " + maxRE);
} }
} }
@ -201,7 +201,7 @@ public class LayerHelperValidationUtil {
double re = relError(s1, s2); double re = relError(s1, s2);
String s = "Relative error: " + re; String s = "Relative error: " + re;
assertTrue(s, re < t.getMaxRelError()); assertTrue(re < t.getMaxRelError(), s);
} }
if(t.isTestBackward()) { if(t.isTestBackward()) {
@ -243,8 +243,8 @@ public class LayerHelperValidationUtil {
} else { } else {
System.out.println("OK: " + p); System.out.println("OK: " + p);
} }
assertTrue(t.getTestName() + " - Gradients are not equal: " + p + " - highest relative error = " + maxRE + " > max relative error = " + t.getMaxRelError(), assertTrue(maxRE < t.getMaxRelError(),
maxRE < t.getMaxRelError()); t.getTestName() + " - Gradients are not equal: " + p + " - highest relative error = " + maxRE + " > max relative error = " + t.getMaxRelError());
} }
} }
@ -283,7 +283,7 @@ public class LayerHelperValidationUtil {
double d2 = listNew.get(j); double d2 = listNew.get(j);
double re = relError(d1, d2); double re = relError(d1, d2);
String msg = "Scores at iteration " + j + " - relError = " + re + ", score1 = " + d1 + ", score2 = " + d2; String msg = "Scores at iteration " + j + " - relError = " + re + ", score1 = " + d1 + ", score2 = " + d2;
assertTrue(msg, re < t.getMaxRelError()); assertTrue(re < t.getMaxRelError(), msg);
System.out.println("j=" + j + ", d1 = " + d1 + ", d2 = " + d2); System.out.println("j=" + j + ", d1 = " + d1 + ", d2 = " + d2);
} }
} }
@ -315,7 +315,7 @@ public class LayerHelperValidationUtil {
try { try {
if (keepAndAssertPresent) { if (keepAndAssertPresent) {
Object o = f.get(l); Object o = f.get(l);
assertNotNull("Expect helper to be present for layer: " + l.getClass(), o); assertNotNull(o,"Expect helper to be present for layer: " + l.getClass());
} else { } else {
f.set(l, null); f.set(l, null);
Integer i = map.get(l.getClass()); Integer i = map.get(l.getClass());

View File

@ -26,8 +26,8 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -38,7 +38,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.nio.file.Files; import java.nio.file.Files;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@Ignore @Disabled
public class RandomTests extends BaseDL4JTest { public class RandomTests extends BaseDL4JTest {
@Test @Test

View File

@ -50,8 +50,8 @@ import java.lang.reflect.Field;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
public class TestUtils { public class TestUtils {

View File

@ -27,7 +27,7 @@ import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;

View File

@ -23,7 +23,7 @@ package org.deeplearning4j.datasets;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher; import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher; import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
import org.junit.Test; import org.junit.jupiter.api.Test;
public class TestDataSets extends BaseDL4JTest { public class TestDataSets extends BaseDL4JTest {

View File

@ -19,7 +19,7 @@
*/ */
package org.deeplearning4j.datasets.datavec; package org.deeplearning4j.datasets.datavec;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -47,7 +47,7 @@ import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException
import org.deeplearning4j.datasets.datavec.tools.SpecialImageRecordReader; import org.deeplearning4j.datasets.datavec.tools.SpecialImageRecordReader;
import org.nd4j.linalg.dataset.AsyncDataSetIterator; import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -74,9 +74,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Record Reader Data Setiterator Test") @DisplayName("Record Reader Data Setiterator Test")
class RecordReaderDataSetiteratorTest extends BaseDL4JTest { class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Override @Override
public DataType getDataType() { public DataType getDataType() {
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -19,7 +19,7 @@
*/ */
package org.deeplearning4j.datasets.datavec; package org.deeplearning4j.datasets.datavec;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
@ -44,7 +44,7 @@ import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -73,8 +73,7 @@ class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@TempDir @TempDir
public Path temporaryFolder; public Path temporaryFolder;
@Rule
public Timeout timeout = Timeout.seconds(300);
@Test @Test
@DisplayName("Tests Basic") @DisplayName("Tests Basic")

View File

@ -20,9 +20,9 @@
package org.deeplearning4j.datasets.fetchers; package org.deeplearning4j.datasets.fetchers;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.Timeout;
import java.io.File; import java.io.File;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue;

View File

@ -21,14 +21,14 @@
package org.deeplearning4j.datasets.iterator; package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler; import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class CombinedPreProcessorTests extends BaseDL4JTest { public class CombinedPreProcessorTests extends BaseDL4JTest {

View File

@ -23,7 +23,7 @@ package org.deeplearning4j.datasets.iterator;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -32,7 +32,7 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class DataSetSplitterTests extends BaseDL4JTest { public class DataSetSplitterTests extends BaseDL4JTest {
@Test @Test
@ -54,7 +54,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) { while (train.hasNext()) {
val data = train.next().getFeatures(); val data = train.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++; gcntTrain++;
global++; global++;
} }
@ -64,7 +64,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (test.hasNext()) { while (test.hasNext()) {
val data = test.next().getFeatures(); val data = test.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++; gcntTest++;
global++; global++;
} }
@ -94,7 +94,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) { while (train.hasNext()) {
val data = train.next().getFeatures(); val data = train.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++; gcntTrain++;
global++; global++;
} }
@ -104,7 +104,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
if (e % 2 == 0) if (e % 2 == 0)
while (test.hasNext()) { while (test.hasNext()) {
val data = test.next().getFeatures(); val data = test.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++; gcntTest++;
global++; global++;
} }
@ -113,8 +113,9 @@ public class DataSetSplitterTests extends BaseDL4JTest {
assertEquals(700 * numEpochs + (300 * numEpochs / 2), global); assertEquals(700 * numEpochs + (300 * numEpochs / 2), global);
} }
@Test(expected = ND4JIllegalStateException.class) @Test()
public void testSplitter_3() throws Exception { public void testSplitter_3() throws Exception {
assertThrows(ND4JIllegalStateException.class, () -> {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, 1000, 0.7); val splitter = new DataSetIteratorSplitter(back, 1000, 0.7);
@ -132,7 +133,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) { while (train.hasNext()) {
val data = train.next().getFeatures(); val data = train.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++; gcntTrain++;
global++; global++;
} }
@ -142,7 +143,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (test.hasNext()) { while (test.hasNext()) {
val data = test.next().getFeatures(); val data = test.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++; gcntTest++;
global++; global++;
} }
@ -153,6 +154,9 @@ public class DataSetSplitterTests extends BaseDL4JTest {
} }
assertEquals(1000 * numEpochs, global); assertEquals(1000 * numEpochs, global);
});
} }
@Test @Test
@ -172,8 +176,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
partIterator.reset(); partIterator.reset();
while (partIterator.hasNext()) { while (partIterator.hasNext()) {
val data = partIterator.next().getFeatures(); val data = partIterator.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, assertEquals((float) perEpoch, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
(float) perEpoch, data.getFloat(0), 1e-5);
//gcntTrain++; //gcntTrain++;
global++; global++;
cnt++; cnt++;
@ -206,8 +209,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
int cnt = 0; int cnt = 0;
val data = partIterator.next().getFeatures(); val data = partIterator.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, assertEquals((float) perEpoch, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
(float) perEpoch, data.getFloat(0), 1e-5);
//gcntTrain++; //gcntTrain++;
global++; global++;
cnt++; cnt++;
@ -247,10 +249,10 @@ public class DataSetSplitterTests extends BaseDL4JTest {
val ds = trainIter.next(); val ds = trainIter.next();
assertNotNull(ds); assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f); assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
globalIter++; globalIter++;
} }
assertTrue("Failed at epoch [" + e + "]", trained); assertTrue(trained,"Failed at epoch [" + e + "]");
assertEquals(800, globalIter); assertEquals(800, globalIter);
@ -262,10 +264,10 @@ public class DataSetSplitterTests extends BaseDL4JTest {
val ds = testIter.next(); val ds = testIter.next();
assertNotNull(ds); assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f); assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
globalIter++; globalIter++;
} }
assertTrue("Failed at epoch [" + e + "]", tested); assertTrue(tested,"Failed at epoch [" + e + "]");
assertEquals(900, globalIter); assertEquals(900, globalIter);
// validation set is used every 5 epochs // validation set is used every 5 epochs
@ -277,10 +279,10 @@ public class DataSetSplitterTests extends BaseDL4JTest {
val ds = validationIter.next(); val ds = validationIter.next();
assertNotNull(ds); assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f); assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
globalIter++; globalIter++;
} }
assertTrue("Failed at epoch [" + e + "]", validated); assertTrue(validated,"Failed at epoch [" + e + "]");
} }
// all 3 iterators have exactly 1000 elements combined // all 3 iterators have exactly 1000 elements combined
@ -312,7 +314,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
int farCnt = (1000 / 2) * (partNumber) + cnt; int farCnt = (1000 / 2) * (partNumber) + cnt;
val data = iteratorList.get(partNumber).next().getFeatures(); val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data.getFloat(0), 1e-5); assertEquals((float) farCnt, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
cnt++; cnt++;
global++; global++;
} }
@ -322,7 +324,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(0).hasNext()) { while (iteratorList.get(0).hasNext()) {
val data = iteratorList.get(0).next().getFeatures(); val data = iteratorList.get(0).next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
global++; global++;
} }
} }
@ -341,7 +343,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(partNumber).hasNext()) { while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures(); val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5); assertEquals( (float) (500*partNumber + cnt), data.getFloat(0), 1e-5,"Train failed on iteration " + cnt);
cnt++; cnt++;
} }
} }
@ -365,7 +367,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(partNumber).hasNext()) { while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures(); val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5); assertEquals( (float) (500*partNumber + cnt), data.getFloat(0), 1e-5,"Train failed on iteration " + cnt);
cnt++; cnt++;
} }
} }
@ -390,7 +392,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
val ds = validationIter.next(); val ds = validationIter.next();
assertNotNull(ds); assertNotNull(ds);
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5); assertEquals((float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5,"Validation failed on iteration " + valCnt);
valCnt++; valCnt++;
} }
assertEquals(5, valCnt); assertEquals(5, valCnt);

View File

@ -25,15 +25,15 @@ import lombok.val;
import lombok.var; import lombok.var;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator; import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class DummyBlockDataSetIteratorTests extends BaseDL4JTest { public class DummyBlockDataSetIteratorTests extends BaseDL4JTest {

View File

@ -21,7 +21,7 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -43,8 +43,7 @@ class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
int numExamples = 105; int numExamples = 105;
@Rule
public final ExpectedException exception = ExpectedException.none();
@Test @Test
@DisplayName("Test Next And Reset") @DisplayName("Test Next And Reset")
@ -86,14 +85,16 @@ class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
@DisplayName("Test Callsto Next Not Allowed") @DisplayName("Test calls to Next Not Allowed")
void testCallstoNextNotAllowed() throws IOException { void testCallstoNextNotAllowed() throws IOException {
assertThrows(RuntimeException.class,() -> {
int terminateAfter = 1; int terminateAfter = 1;
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples); DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10); earlyEndIter.next(10);
iter.reset(); iter.reset();
exception.expect(RuntimeException.class);
earlyEndIter.next(10); earlyEndIter.next(10);
});
} }
} }

View File

@ -21,7 +21,7 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
@ -30,11 +30,12 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Early Termination Multi Data Set Iterator Test") @DisplayName("Early Termination Multi Data Set Iterator Test")
class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
@ -42,8 +43,7 @@ class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
int numExamples = 105; int numExamples = 105;
@Rule
public final ExpectedException exception = ExpectedException.none();
@Test @Test
@DisplayName("Test Next And Reset") @DisplayName("Test Next And Reset")
@ -91,14 +91,16 @@ class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
} }
@Test @Test
@DisplayName("Test Callsto Next Not Allowed") @DisplayName("Test calls to Next Not Allowed")
void testCallstoNextNotAllowed() throws IOException { void testCallstoNextNotAllowed() throws IOException {
assertThrows(RuntimeException.class,() -> {
int terminateAfter = 1; int terminateAfter = 1;
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples)); MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10); earlyEndIter.next(10);
iter.reset(); iter.reset();
exception.expect(RuntimeException.class);
earlyEndIter.next(10); earlyEndIter.next(10);
});
} }
} }

View File

@ -24,14 +24,16 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class JointMultiDataSetIteratorTests extends BaseDL4JTest { public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
@Test (timeout = 20000L) @Test ()
@Timeout(20000L)
public void testJMDSI_1() { public void testJMDSI_1() {
val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2}); val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2});
val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2}); val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2});
@ -75,7 +77,8 @@ public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
} }
@Test (timeout = 20000L) @Test ()
@Timeout(20000L)
public void testJMDSI_2() { public void testJMDSI_2() {
val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2}); val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2});
val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2}); val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2});

View File

@ -23,7 +23,7 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator; import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator; import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.loader.Loader; import org.nd4j.common.loader.Loader;
import org.nd4j.common.loader.LocalFileSourceFactory; import org.nd4j.common.loader.LocalFileSourceFactory;
import org.nd4j.common.loader.Source; import org.nd4j.common.loader.Source;
@ -39,8 +39,8 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class LoaderIteratorTests extends BaseDL4JTest { public class LoaderIteratorTests extends BaseDL4JTest {

View File

@ -24,7 +24,7 @@ import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator; import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -32,7 +32,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
public class MultiDataSetSplitterTests extends BaseDL4JTest { public class MultiDataSetSplitterTests extends BaseDL4JTest {
@ -55,7 +55,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) { while (train.hasNext()) {
val data = train.next().getFeatures(0); val data = train.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++; gcntTrain++;
global++; global++;
} }
@ -65,7 +65,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (test.hasNext()) { while (test.hasNext()) {
val data = test.next().getFeatures(0); val data = test.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++; gcntTest++;
global++; global++;
} }
@ -96,7 +96,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) { while (train.hasNext()) {
val data = train.next().getFeatures(0); val data = train.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++; gcntTrain++;
global++; global++;
} }
@ -106,7 +106,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
if (e % 2 == 0) if (e % 2 == 0)
while (test.hasNext()) { while (test.hasNext()) {
val data = test.next().getFeatures(0); val data = test.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++; gcntTest++;
global++; global++;
} }
@ -115,8 +115,9 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertEquals(700 * numEpochs + (300 * numEpochs / 2), global); assertEquals(700 * numEpochs + (300 * numEpochs / 2), global);
} }
@Test(expected = ND4JIllegalStateException.class) @Test()
public void testSplitter_3() throws Exception { public void testSplitter_3() throws Exception {
assertThrows(ND4JIllegalStateException.class,() -> {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5}); val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, 1000, 0.7); val splitter = new MultiDataSetIteratorSplitter(back, 1000, 0.7);
@ -134,7 +135,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) { while (train.hasNext()) {
val data = train.next().getFeatures(0); val data = train.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++; gcntTrain++;
global++; global++;
} }
@ -144,7 +145,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (test.hasNext()) { while (test.hasNext()) {
val data = test.next().getFeatures(0); val data = test.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5); assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++; gcntTest++;
global++; global++;
} }
@ -155,6 +156,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
} }
assertEquals(1000 * numEpochs, global); assertEquals(1000 * numEpochs, global);
});
} }
@Test @Test
@ -185,11 +188,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds); assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) { for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); assertEquals( (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
} }
globalIter++; globalIter++;
} }
assertTrue("Failed at epoch [" + e + "]", trained); assertTrue(trained,"Failed at epoch [" + e + "]");
assertEquals(800, globalIter); assertEquals(800, globalIter);
@ -202,11 +205,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds); assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) { for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
} }
globalIter++; globalIter++;
} }
assertTrue("Failed at epoch [" + e + "]", tested); assertTrue(tested,"Failed at epoch [" + e + "]");
assertEquals(900, globalIter); assertEquals(900, globalIter);
// validation set is used every 5 epochs // validation set is used every 5 epochs
@ -219,11 +222,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds); assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) { for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); assertEquals( (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
} }
globalIter++; globalIter++;
} }
assertTrue("Failed at epoch [" + e + "]", validated); assertTrue(validated,"Failed at epoch [" + e + "]");
} }
// all 3 iterators have exactly 1000 elements combined // all 3 iterators have exactly 1000 elements combined
@ -256,8 +259,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
val data = partIterator.next().getFeatures(); val data = partIterator.next().getFeatures();
for (int i = 0; i < data.length; ++i) { for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, assertEquals((float) perEpoch, data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
(float) perEpoch, data[i].getFloat(0), 1e-5);
} }
//gcntTrain++; //gcntTrain++;
global++; global++;
@ -299,12 +301,12 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds); assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) { for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, assertEquals((double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f); ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
} }
globalIter++; globalIter++;
} }
assertTrue("Failed at epoch [" + e + "]", trained); assertTrue(trained,"Failed at epoch [" + e + "]");
assertEquals(800, globalIter); assertEquals(800, globalIter);
@ -316,11 +318,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
val ds = testIter.next(); val ds = testIter.next();
assertNotNull(ds); assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) { for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f); assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
} }
globalIter++; globalIter++;
} }
assertTrue("Failed at epoch [" + e + "]", tested); assertTrue(tested,"Failed at epoch [" + e + "]");
assertEquals(900, globalIter); assertEquals(900, globalIter);
// validation set is used every 5 epochs // validation set is used every 5 epochs
@ -333,12 +335,12 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds); assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) { for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, assertEquals((double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f); ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
} }
globalIter++; globalIter++;
} }
assertTrue("Failed at epoch [" + e + "]", validated); assertTrue(validated,"Failed at epoch [" + e + "]");
} }
// all 3 iterators have exactly 1000 elements combined // all 3 iterators have exactly 1000 elements combined
@ -370,7 +372,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
int farCnt = (1000 / 2) * (partNumber) + cnt; int farCnt = (1000 / 2) * (partNumber) + cnt;
val data = iteratorList.get(partNumber).next().getFeatures(); val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) { for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data[i].getFloat(0), 1e-5); assertEquals( (float) farCnt, data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
} }
cnt++; cnt++;
global++; global++;
@ -381,8 +383,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(0).hasNext()) { while (iteratorList.get(0).hasNext()) {
val data = iteratorList.get(0).next().getFeatures(); val data = iteratorList.get(0).next().getFeatures();
for (int i = 0; i < data.length; ++i) { for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, assertEquals((float) cnt++,
data[i].getFloat(0), 1e-5); data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
} }
global++; global++;
} }
@ -402,7 +404,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(partNumber).hasNext()) { while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures(); val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) { for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5); assertEquals( (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt);
} }
cnt++; cnt++;
} }
@ -427,8 +429,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(partNumber).hasNext()) { while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures(); val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) { for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), assertEquals( (float) (500 * partNumber + cnt),
data[i].getFloat(0), 1e-5); data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt);
} }
cnt++; cnt++;
} }
@ -454,8 +456,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
val ds = validationIter.next(); val ds = validationIter.next();
assertNotNull(ds); assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) { for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, assertEquals((float) valCnt + 90,
ds.getFeatures()[i].getFloat(0), 1e-5); ds.getFeatures()[i].getFloat(0), 1e-5,"Validation failed on iteration " + valCnt);
} }
valCnt++; valCnt++;
} }

View File

@ -25,9 +25,9 @@ import org.datavec.api.split.FileSplit;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.util.TestDataSetConsumer; import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.Rule;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.Timeout;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -42,9 +42,6 @@ import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Multiple Epochs Iterator Test") @DisplayName("Multiple Epochs Iterator Test")
class MultipleEpochsIteratorTest extends BaseDL4JTest { class MultipleEpochsIteratorTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Test @Test
@DisplayName("Test Next And Reset") @DisplayName("Test Next And Reset")
void testNextAndReset() throws Exception { void testNextAndReset() throws Exception {

View File

@ -22,8 +22,8 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.junit.Ignore; import org.junit.jupiter.api.Disabled;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
@ -33,9 +33,9 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.List; import java.util.List;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
@Ignore @Disabled
public class TestAsyncIterator extends BaseDL4JTest { public class TestAsyncIterator extends BaseDL4JTest {
@Test @Test

Some files were not shown because too many files have changed in this diff Show More