All tests compile

master
agibsonccc 2021-03-16 11:57:24 +09:00
parent b1229432d6
commit 82bdcc21d2
729 changed files with 6080 additions and 5619 deletions

View File

@ -25,7 +25,7 @@ import org.datavec.api.records.reader.impl.csv.CSVLineSequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;

View File

@ -25,7 +25,7 @@ import org.datavec.api.records.reader.impl.csv.CSVMultiSequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;

View File

@ -26,7 +26,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;

View File

@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;

View File

@ -27,7 +27,7 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;

View File

@ -30,7 +30,7 @@ import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;

View File

@ -29,7 +29,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;

View File

@ -32,7 +32,7 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;

View File

@ -26,14 +26,14 @@ import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestCollectionRecordReaders extends BaseND4JTest {

View File

@ -23,11 +23,11 @@ package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestConcatenatingRecordReader extends BaseND4JTest {

View File

@ -37,7 +37,7 @@ import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
@ -47,7 +47,7 @@ import java.io.*;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSerialization extends BaseND4JTest {

View File

@ -30,7 +30,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
@ -38,8 +38,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TransformProcessRecordReaderTests extends BaseND4JTest {

View File

@ -26,7 +26,7 @@ import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.io.labels.PatternPathLabelGenerator;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.io.*;
import java.net.URI;
@ -35,7 +35,7 @@ import java.util.ArrayList;
import java.util.Random;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
*

View File

@ -20,13 +20,12 @@
package org.datavec.api.split;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.*;
public class NumberedFileInputSplitTests extends BaseND4JTest {
@Test
@ -69,60 +68,81 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithLeadingSpaces() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix-%5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() {
assertThrows(IllegalArgumentException.class, () -> {
String baseString = "/path/to/files/prefix%5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithLeadingPlusInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%+5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithLeadingMinusInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%-5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithTwoDigitsInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%011d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithInnerZerosInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%101d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() {
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%0505d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@ -135,7 +155,7 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
String path = locs[j++].getPath();
String exp = String.format(baseString, i);
String msg = exp + " vs " + path;
assertTrue(msg, path.endsWith(exp)); //Note: on Windows, Java can prepend drive to path - "/C:/"
assertTrue(path.endsWith(exp),msg); //Note: on Windows, Java can prepend drive to path - "/C:/"
}
}
}

View File

@ -25,9 +25,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.function.Function;
@ -37,22 +38,22 @@ import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
public class TestStreamInputSplit extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testCsvSimple() throws Exception {
File dir = testDir.newFolder();
public void testCsvSimple(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt");
@ -93,9 +94,9 @@ public class TestStreamInputSplit extends BaseND4JTest {
@Test
public void testCsvSequenceSimple() throws Exception {
public void testCsvSequenceSimple(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder();
File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt");
@ -137,8 +138,8 @@ public class TestStreamInputSplit extends BaseND4JTest {
}
@Test
public void testShuffle() throws Exception {
File dir = testDir.newFolder();
public void testShuffle(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt");
File f3 = new File(dir, "file3.txt");

View File

@ -27,14 +27,14 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.split.partition.PartitionMetaData;
import org.datavec.api.split.partition.Partitioner;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.OutputStream;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
public class PartitionerTests extends BaseND4JTest {
@Test

View File

@ -29,12 +29,12 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestTransformProcess extends BaseND4JTest {

View File

@ -27,13 +27,13 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.TestTransforms;
import org.datavec.api.writable.*;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConditions extends BaseND4JTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
@ -36,8 +36,8 @@ import java.util.Collections;
import java.util.List;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestFilters extends BaseND4JTest {

View File

@ -26,19 +26,22 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class TestJoin extends BaseND4JTest {
@Test
public void testJoin() {
public void testJoin(@TempDir Path testDir) {
Schema firstSchema =
new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build();
@ -46,20 +49,20 @@ public class TestJoin extends BaseND4JTest {
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build();
List<List<Writable>> first = new ArrayList<>();
first.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1)));
first.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11)));
first.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1)));
first.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11)));
List<List<Writable>> second = new ArrayList<>();
second.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(100)));
second.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(110)));
second.add(Arrays.asList(new Text("key0"), new IntWritable(100)));
second.add(Arrays.asList(new Text("key1"), new IntWritable(110)));
Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn")
.setSchemas(firstSchema, secondSchema).build();
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1),
expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1),
new IntWritable(100)));
expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11),
expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11),
new IntWritable(110)));
@ -94,9 +97,9 @@ public class TestJoin extends BaseND4JTest {
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testJoinValidation() {
assertThrows(IllegalArgumentException.class,() -> {
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build();
@ -104,11 +107,13 @@ public class TestJoin extends BaseND4JTest {
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
.setSchemas(firstSchema, secondSchema).build();
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testJoinValidation2() {
assertThrows(IllegalArgumentException.class,() -> {
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build();
@ -116,5 +121,7 @@ public class TestJoin extends BaseND4JTest {
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
.build();
});
}
}

View File

@ -19,17 +19,18 @@
*/
package org.datavec.api.transform.ops;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Aggregator Impls Test")
class AggregatorImplsTest extends BaseND4JTest {
@ -265,12 +266,12 @@ class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, cu.get().toInt());
}
@Rule
public final ExpectedException exception = ExpectedException.none();
@Test
@DisplayName("Incompatible Aggregator Test")
void incompatibleAggregatorTest() {
assertThrows(UnsupportedOperationException.class,() -> {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i));
@ -280,8 +281,10 @@ class AggregatorImplsTest extends BaseND4JTest {
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
exception.expect(UnsupportedOperationException.class);
sm.combine(reverse);
assertEquals(45, sm.get().toInt());
});
}
}

View File

@ -32,13 +32,13 @@ import org.datavec.api.transform.ops.AggregableMultiOp;
import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public class TestMultiOpReduce extends BaseND4JTest {
@ -46,10 +46,10 @@ public class TestMultiOpReduce extends BaseND4JTest {
public void testMultiOpReducerDouble() {
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(0)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(1)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(0)));
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(1)));
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
Map<ReduceOp, Double> exp = new LinkedHashMap<>();
exp.put(ReduceOp.Min, 0.0);
@ -82,7 +82,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
}
}
@ -126,7 +126,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
}
}
@ -210,7 +210,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
}
for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean,

View File

@ -24,13 +24,13 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReductions extends BaseND4JTest {

View File

@ -22,10 +22,10 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJsonYaml extends BaseND4JTest {

View File

@ -21,10 +21,10 @@
package org.datavec.api.transform.schema;
import org.datavec.api.transform.ColumnType;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSchemaMethods extends BaseND4JTest {

View File

@ -33,7 +33,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
@ -41,7 +41,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
@ -35,7 +35,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSequenceSplit extends BaseND4JTest {

View File

@ -29,7 +29,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
@ -37,7 +37,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWindowFunctions extends BaseND4JTest {

View File

@ -26,10 +26,10 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCustomTransformJsonYaml extends BaseND4JTest {

View File

@ -64,13 +64,13 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestYamlJsonSerde extends BaseND4JTest {

View File

@ -24,12 +24,12 @@ import org.datavec.api.transform.StringReduceOp;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReduce extends BaseND4JTest {

View File

@ -50,7 +50,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
@ -61,7 +61,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class RegressionTestJson extends BaseND4JTest {

View File

@ -50,13 +50,13 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJsonYaml extends BaseND4JTest {

View File

@ -59,7 +59,7 @@ import org.datavec.api.writable.*;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Assert;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -72,7 +72,7 @@ import java.util.*;
import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertEquals;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestTransforms extends BaseND4JTest {

View File

@ -29,7 +29,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,7 +39,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayWritableTransforms extends BaseND4JTest {

View File

@ -30,13 +30,13 @@ import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestYamlJsonSerde extends BaseND4JTest {

View File

@ -35,26 +35,26 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestUI extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testUI() throws Exception {
public void testUI(@TempDir Path testDir) throws Exception {
Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn")
.addColumnInteger("IntColumn2").addColumnInteger("IntColumn3")
.addColumnTime("TimeColumn", DateTimeZone.UTC).build();
@ -92,7 +92,7 @@ public class TestUI extends BaseND4JTest {
DataAnalysis da = new DataAnalysis(schema, list);
File fDir = testDir.newFolder();
File fDir = testDir.toFile();
String tempDir = fDir.getAbsolutePath();
String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html");
System.out.println(outPath);
@ -143,7 +143,7 @@ public class TestUI extends BaseND4JTest {
@Test
@Ignore
@Disabled
public void testSequencePlot() throws Exception {
Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx")

View File

@ -21,14 +21,14 @@
package org.datavec.api.writable;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.*;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {

View File

@ -41,7 +41,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;

View File

@ -29,16 +29,16 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@ -69,7 +69,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
assertEquals(3,fieldVectors.size());
for(FieldVector fieldVector : fieldVectors) {
for(int i = 0; i < fieldVector.getValueCount(); i++) {
assertFalse("Index " + i + " was null for field vector " + fieldVector, fieldVector.isNull(i));
assertFalse( fieldVector.isNull(i),"Index " + i + " was null for field vector " + fieldVector);
}
}
@ -79,7 +79,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@Test
//not worried about this till after next release
@Ignore
@Disabled
public void testVariableLengthTS() {
Schema.Builder schema = new Schema.Builder()
.addColumnString("str")

View File

@ -23,7 +23,7 @@ import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;

View File

@ -22,8 +22,8 @@ package org.datavec.image.loader;
import org.apache.commons.io.FilenameUtils;
import org.datavec.api.records.reader.RecordReader;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.DataSet;
import java.io.File;
@ -32,9 +32,9 @@ import java.io.InputStream;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
*
@ -182,7 +182,7 @@ public class LoaderTests {
}
@Ignore // Use when confirming data is getting stored
@Disabled // Use when confirming data is getting stored
@Test
public void testProcessCifar() {
int row = 32;
@ -208,15 +208,15 @@ public class LoaderTests {
int minibatch = 100;
int nMinibatches = 50000 / minibatch;
for( int i=0; i<nMinibatches; i++ ){
for( int i=0; i < nMinibatches; i++) {
DataSet ds = loader.next(minibatch);
String s = String.valueOf(i);
assertNotNull(s, ds.getFeatures());
assertNotNull(s, ds.getLabels());
assertNotNull(ds.getFeatures(),s);
assertNotNull(ds.getLabels(),s);
assertEquals(s, minibatch, ds.getFeatures().size(0));
assertEquals(s, minibatch, ds.getLabels().size(0));
assertEquals(s, 10, ds.getLabels().size(1));
assertEquals(minibatch, ds.getFeatures().size(0),s);
assertEquals(minibatch, ds.getLabels().size(0),s);
assertEquals(10, ds.getLabels().size(1),s);
}
}

View File

@ -21,7 +21,7 @@
package org.datavec.image.loader;
import org.datavec.image.data.Image;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -32,7 +32,7 @@ import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestImageLoader {

View File

@ -30,9 +30,10 @@ import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,16 +43,17 @@ import org.nd4j.common.io.ClassPathResource;
import java.awt.image.BufferedImage;
import java.io.*;
import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.Random;
import org.bytedeco.leptonica.*;
import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.leptonica.global.lept.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
/**
*
@ -62,8 +64,6 @@ public class TestNativeImageLoader {
static final long seed = 10;
static final Random rng = new Random(seed);
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testConvertPix() throws Exception {
@ -566,8 +566,8 @@ public class TestNativeImageLoader {
@Test
public void testNativeImageLoaderEmptyStreams() throws Exception {
File dir = testDir.newFolder();
public void testNativeImageLoaderEmptyStreams(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
File f = new File(dir, "myFile.jpg");
f.createNewFile();
@ -578,7 +578,7 @@ public class TestNativeImageLoader {
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
assertTrue(msg.contains("decode image"),msg);
}
try(InputStream is = new FileInputStream(f)){
@ -586,7 +586,7 @@ public class TestNativeImageLoader {
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
assertTrue(msg.contains("decode image"),msg);
}
try(InputStream is = new FileInputStream(f)){
@ -594,7 +594,7 @@ public class TestNativeImageLoader {
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
assertTrue(msg.contains("decode image"),msg);
}
try(InputStream is = new FileInputStream(f)){
@ -603,7 +603,7 @@ public class TestNativeImageLoader {
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
assertTrue( msg.contains("decode image"),msg);
}
}

View File

@ -27,7 +27,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.loader.FileBatch;

View File

@ -36,9 +36,10 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -46,28 +47,30 @@ import org.nd4j.common.io.ClassPathResource;
import java.io.*;
import java.net.URI;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestImageRecordReader {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test(expected = IllegalArgumentException.class)
@Test()
public void testEmptySplit() throws IOException {
InputSplit data = new CollectionInputSplit(new ArrayList<URI>());
assertThrows(IllegalArgumentException.class,() -> {
InputSplit data = new CollectionInputSplit(new ArrayList<>());
new ImageRecordReader().initialize(data, null);
});
}
@Test
public void testMetaData() throws IOException {
public void testMetaData(@TempDir Path testDir) throws IOException {
File parentDir = testDir.newFolder();
File parentDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir);
// System.out.println(f.getAbsolutePath());
// System.out.println(f.getParentFile().getParentFile().getAbsolutePath());
@ -104,11 +107,11 @@ public class TestImageRecordReader {
}
@Test
public void testImageRecordReaderLabelsOrder() throws Exception {
public void testImageRecordReaderLabelsOrder(@TempDir Path testDir) throws Exception {
//Labels order should be consistent, regardless of file iteration order
//Idea: labels order should be consistent regardless of input file order
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
File f0 = new File(f, "/class0/0.jpg");
File f1 = new File(f, "/class1/A.jpg");
@ -135,11 +138,11 @@ public class TestImageRecordReader {
@Test
public void testImageRecordReaderRandomization() throws Exception {
public void testImageRecordReaderRandomization(@TempDir Path testDir) throws Exception {
//Order of FileSplit+ImageRecordReader should be different after reset
//Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder();
File f0 = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs = new FileSplit(f0, new Random(12345));
@ -189,13 +192,13 @@ public class TestImageRecordReader {
@Test
public void testImageRecordReaderRegression() throws Exception {
public void testImageRecordReaderRegression(@TempDir Path testDir) throws Exception {
PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen();
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen);
File rootDir = testDir.newFolder();
File rootDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
FileSplit fs = new FileSplit(rootDir);
rr.initialize(fs);
@ -244,10 +247,10 @@ public class TestImageRecordReader {
}
@Test
public void testListenerInvocationBatch() throws IOException {
public void testListenerInvocationBatch(@TempDir Path testDir) throws IOException {
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
File parent = f;
@ -260,10 +263,10 @@ public class TestImageRecordReader {
}
@Test
public void testListenerInvocationSingle() throws IOException {
public void testListenerInvocationSingle(@TempDir Path testDir) throws IOException {
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
File parent = testDir.newFolder();
File parent = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent);
int numFiles = parent.list().length;
rr.initialize(new FileSplit(parent));
@ -315,7 +318,7 @@ public class TestImageRecordReader {
@Test
public void testImageRecordReaderPathMultiLabelGenerator() throws Exception {
public void testImageRecordReaderPathMultiLabelGenerator(@TempDir Path testDir) throws Exception {
Nd4j.setDataType(DataType.FLOAT);
//Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively
// PLUS single value (Writable) regression label
@ -324,7 +327,7 @@ public class TestImageRecordReader {
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen);
File rootDir = testDir.newFolder();
File rootDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
FileSplit fs = new FileSplit(rootDir);
rr.initialize(fs);
@ -471,9 +474,9 @@ public class TestImageRecordReader {
@Test
public void testNCHW_NCHW() throws Exception {
public void testNCHW_NCHW(@TempDir Path testDir) throws Exception {
//Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder();
File f0 = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs0 = new FileSplit(f0, new Random(12345));

View File

@ -35,9 +35,10 @@ import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
@ -46,24 +47,24 @@ import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.net.URI;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestObjectDetectionRecordReader {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void test() throws Exception {
public void test(@TempDir Path testDir) throws Exception {
for(boolean nchw : new boolean[]{true, false}) {
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
String path = new File(f, "000012.jpg").getParent();

View File

@ -21,27 +21,27 @@
package org.datavec.image.recordreader.objdetect;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestVocLabelProvider {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testVocLabelProvider() throws Exception {
public void testVocLabelProvider(@TempDir Path testDir) throws Exception {
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f);
String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent();

View File

@ -28,8 +28,8 @@ import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.primitives.Pair;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.awt.*;
import java.util.LinkedList;
@ -40,7 +40,7 @@ import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
/**
*
@ -255,7 +255,7 @@ public class TestImageTransform {
assertEquals(22, transformed[1], 0);
}
@Ignore
@Disabled
@Test
public void testFilterImageTransform() throws Exception {
ImageWritable writable = makeRandomImage(0, 0, 4);

View File

@ -25,7 +25,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.primitives.Triple;

View File

@ -49,7 +49,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.api.DisplayName;

View File

@ -36,14 +36,14 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class LocalTransformProcessRecordReaderTests {

View File

@ -29,9 +29,9 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.AnalyzeLocal;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
@ -39,12 +39,11 @@ import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestAnalyzeLocal {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testAnalysisBasic() throws Exception {
@ -72,7 +71,7 @@ public class TestAnalyzeLocal {
INDArray mean = arr.mean(0);
INDArray std = arr.std(0);
for( int i=0; i<5; i++ ){
for( int i = 0; i < 5; i++) {
double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
assertEquals(mean.getDouble(i), m, 1e-3);

View File

@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
@ -36,8 +36,8 @@ import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLineRecordReaderFunction {

View File

@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +33,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayToWritablesFunction {

View File

@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToNDArrayFunction {

View File

@ -30,12 +30,12 @@ import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
import org.datavec.local.transforms.misc.WritablesToStringFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToStringFunctions {

View File

@ -32,7 +32,8 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.*;
@ -40,14 +41,14 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author saudet
*/
public class TestGeoTransforms {
@BeforeClass
@BeforeAll
public static void beforeClass() throws Exception {
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();

View File

@ -30,7 +30,8 @@ import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.api.writable.*;
import org.datavec.python.PythonCondition;
import org.datavec.python.PythonTransform;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -43,7 +44,7 @@ import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.datavec.api.transform.schema.Schema.Builder;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
@NotThreadSafe
public class TestPythonTransformProcess {
@ -77,8 +78,9 @@ public class TestPythonTransformProcess {
}
@Test(timeout = 60000L)
public void testMixedTypes() throws Exception{
@Test()
@Timeout(60000L)
public void testMixedTypes() throws Exception {
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
@ -99,7 +101,7 @@ public class TestPythonTransformProcess {
.inputSchema(initialSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
List<Writable> inputs = Arrays.asList(new IntWritable(10),
new FloatWritable(3.5f),
new Text("5"),
new DoubleWritable(2.0)
@ -109,8 +111,9 @@ public class TestPythonTransformProcess {
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
}
@Test(timeout = 60000L)
public void testNDArray() throws Exception{
@Test()
@Timeout(60000L)
public void testNDArray() throws Exception {
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
@ -145,8 +148,9 @@ public class TestPythonTransformProcess {
}
@Test(timeout = 60000L)
public void testNDArray2() throws Exception{
@Test()
@Timeout(60000L)
public void testNDArray2() throws Exception {
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
@ -181,7 +185,8 @@ public class TestPythonTransformProcess {
}
@Test(timeout = 60000L)
@Test()
@Timeout(60000L)
public void testNDArrayMixed() throws Exception{
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
@ -217,7 +222,8 @@ public class TestPythonTransformProcess {
}
@Test(timeout = 60000L)
@Test()
@Timeout(60000L)
public void testPythonFilter() {
Schema schema = new Builder().addColumnInteger("column").build();
@ -237,8 +243,9 @@ public class TestPythonTransformProcess {
}
@Test(timeout = 60000L)
public void testPythonFilterAndTransform() throws Exception{
@Test()
@Timeout(60000L)
public void testPythonFilterAndTransform() throws Exception {
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")

View File

@ -28,11 +28,11 @@ import org.datavec.api.writable.*;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJoin {

View File

@ -31,13 +31,13 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCalculateSortedRank {

View File

@ -31,14 +31,14 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConvertToSequence {

View File

@ -41,6 +41,12 @@
</properties>
<dependencies>
<dependency>
<groupId>com.tdunning</groupId>
<artifactId>t-digest</artifactId>
<version>3.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>

View File

@ -25,15 +25,15 @@ import org.apache.spark.serializer.SerializerInstance;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.nio.ByteBuffer;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestKryoSerialization extends BaseSparkTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
@ -35,8 +35,8 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLineRecordReaderFunction extends BaseSparkTest {

View File

@ -24,7 +24,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.misc.NDArrayToWritablesFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -32,7 +32,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayToWritablesFunction {

View File

@ -38,9 +38,10 @@ import org.datavec.spark.functions.pairdata.PairSequenceRecordReaderBytesFunctio
import org.datavec.spark.functions.pairdata.PathToKeyConverter;
import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename;
import org.datavec.spark.util.DataVecSparkUtil;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import scala.Tuple2;
@ -50,16 +51,13 @@ import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void test() throws Exception {
public void test(@TempDir Path testDir) throws Exception {
//Goal: combine separate files together into a hadoop sequence file, for later parsing by a SequenceRecordReader
//For example: use to combine input and labels data from separate files for training a RNN
if(Platform.isWindows()) {
@ -67,7 +65,7 @@ public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
}
JavaSparkContext sc = getContext();
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
String path = f.getAbsolutePath() + "/*";

View File

@ -36,9 +36,10 @@ import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.functions.data.FilesAsBytesFunction;
import org.datavec.spark.functions.data.RecordReaderBytesFunction;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
@ -48,23 +49,22 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public class TestRecordReaderBytesFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testRecordReaderBytesFunction() throws Exception {
public void testRecordReaderBytesFunction(@TempDir Path testDir) throws Exception {
if(Platform.isWindows()) {
return;
}
JavaSparkContext sc = getContext();
//Local file path
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call
String path = f.getAbsolutePath() + "/*";

View File

@ -31,30 +31,29 @@ import org.datavec.api.writable.ArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.spark.BaseSparkTest;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public class TestRecordReaderFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testRecordReaderFunction() throws Exception {
public void testRecordReaderFunction(@TempDir Path testDir) throws Exception {
if(Platform.isWindows()) {
return;
}
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-spark/imagetest/").copyDirectory(f);
List<String> labelsList = Arrays.asList("0", "1"); //Need this for Spark: can't infer without init call

View File

@ -36,9 +36,10 @@ import org.datavec.codec.reader.CodecRecordReader;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.functions.data.FilesAsBytesFunction;
import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
@ -47,21 +48,20 @@ import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testRecordReaderBytesFunction() throws Exception {
public void testRecordReaderBytesFunction(@TempDir Path testDir) throws Exception {
if(Platform.isWindows()) {
return;
}
//Local file path
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
String path = f.getAbsolutePath() + "/*";

View File

@ -33,28 +33,29 @@ import org.datavec.api.writable.ArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.codec.reader.CodecRecordReader;
import org.datavec.spark.BaseSparkTest;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public class TestSequenceRecordReaderFunction extends BaseSparkTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testSequenceRecordReaderFunctionCSV() throws Exception {
public void testSequenceRecordReaderFunctionCSV(@TempDir Path testDir) throws Exception {
JavaSparkContext sc = getContext();
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-spark/csvsequence/").copyDirectory(f);
String path = f.getAbsolutePath() + "/*";
@ -120,10 +121,10 @@ public class TestSequenceRecordReaderFunction extends BaseSparkTest {
@Test
public void testSequenceRecordReaderFunctionVideo() throws Exception {
public void testSequenceRecordReaderFunctionVideo(@TempDir Path testDir) throws Exception {
JavaSparkContext sc = getContext();
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-spark/video/").copyDirectory(f);
String path = f.getAbsolutePath() + "/*";

View File

@ -22,7 +22,7 @@ package org.datavec.spark.functions;
import org.datavec.api.writable.*;
import org.datavec.spark.transform.misc.WritablesToNDArrayFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -30,7 +30,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToNDArrayFunction {

View File

@ -29,14 +29,14 @@ import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction;
import org.datavec.spark.transform.misc.WritablesToStringFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToStringFunctions extends BaseSparkTest {

View File

@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.writable.*;
import org.datavec.spark.BaseSparkTest;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
@ -35,8 +35,8 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestSparkStorageUtils extends BaseSparkTest {

View File

@ -30,13 +30,13 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class DataFramesTests extends BaseSparkTest {

View File

@ -28,7 +28,7 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -41,7 +41,7 @@ import java.util.ArrayList;
import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class NormalizationTests extends BaseSparkTest {

View File

@ -38,7 +38,7 @@ import org.datavec.local.transforms.AnalyzeLocal;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.AnalyzeSpark;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
@ -47,7 +47,7 @@ import java.io.File;
import java.nio.file.Files;
import java.util.*;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestAnalysis extends BaseSparkTest {

View File

@ -27,11 +27,11 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJoin extends BaseSparkTest {

View File

@ -30,13 +30,13 @@ import org.datavec.api.writable.Writable;
import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCalculateSortedRank extends BaseSparkTest {

View File

@ -29,14 +29,14 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConvertToSequence extends BaseSparkTest {

View File

@ -28,7 +28,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.utils.SparkUtils;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.FileInputStream;
@ -36,7 +36,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSparkUtil extends BaseSparkTest {

View File

@ -20,7 +20,6 @@
package org.deeplearning4j;
import ch.qos.logback.classic.LoggerContext;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.*;
@ -32,6 +31,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory;
import java.util.List;
@ -39,13 +39,11 @@ import java.util.Map;
import java.util.Properties;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
@DisplayName("Base DL 4 J Test")
public abstract class BaseDL4JTest {
private static Logger log = LoggerFactory.getLogger(BaseDL4JTest.class.getName());
protected long startTime;

View File

@ -43,7 +43,7 @@ import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class LayerHelperValidationUtil {
@ -145,7 +145,7 @@ public class LayerHelperValidationUtil {
System.out.println(p1);
System.out.println(p2);
}
assertTrue(s + " - param changed during forward pass: " + p, maxRE < t.getMaxRelError());
assertTrue(maxRE < t.getMaxRelError(),s + " - param changed during forward pass: " + p);
}
for( int i=0; i<ff1.size(); i++ ){
@ -163,7 +163,7 @@ public class LayerHelperValidationUtil {
double d2 = arr2.dup('c').getDouble(idx);
System.out.println("Different values at index " + idx + ": " + d1 + ", " + d2 + " - RE = " + maxRE);
}
assertTrue(s + layerName + " activations - max RE: " + maxRE, maxRE < t.getMaxRelError());
assertTrue(maxRE < t.getMaxRelError(), s + layerName + " activations - max RE: " + maxRE);
log.info("Forward pass, max relative error: " + layerName + " - " + maxRE);
}
@ -180,7 +180,7 @@ public class LayerHelperValidationUtil {
log.info(s + "Output, max relative error: " + maxRE);
assertEquals(net1NoHelper.params(), net2With.params()); //Check that forward pass does not modify params
assertTrue(s + "Max RE: " + maxRE, maxRE < t.getMaxRelError());
assertTrue(maxRE < t.getMaxRelError(), s + "Max RE: " + maxRE);
}
}
@ -201,7 +201,7 @@ public class LayerHelperValidationUtil {
double re = relError(s1, s2);
String s = "Relative error: " + re;
assertTrue(s, re < t.getMaxRelError());
assertTrue(re < t.getMaxRelError(), s);
}
if(t.isTestBackward()) {
@ -243,8 +243,8 @@ public class LayerHelperValidationUtil {
} else {
System.out.println("OK: " + p);
}
assertTrue(t.getTestName() + " - Gradients are not equal: " + p + " - highest relative error = " + maxRE + " > max relative error = " + t.getMaxRelError(),
maxRE < t.getMaxRelError());
assertTrue(maxRE < t.getMaxRelError(),
t.getTestName() + " - Gradients are not equal: " + p + " - highest relative error = " + maxRE + " > max relative error = " + t.getMaxRelError());
}
}
@ -283,7 +283,7 @@ public class LayerHelperValidationUtil {
double d2 = listNew.get(j);
double re = relError(d1, d2);
String msg = "Scores at iteration " + j + " - relError = " + re + ", score1 = " + d1 + ", score2 = " + d2;
assertTrue(msg, re < t.getMaxRelError());
assertTrue(re < t.getMaxRelError(), msg);
System.out.println("j=" + j + ", d1 = " + d1 + ", d2 = " + d2);
}
}
@ -315,7 +315,7 @@ public class LayerHelperValidationUtil {
try {
if (keepAndAssertPresent) {
Object o = f.get(l);
assertNotNull("Expect helper to be present for layer: " + l.getClass(), o);
assertNotNull(o,"Expect helper to be present for layer: " + l.getClass());
} else {
f.set(l, null);
Integer i = map.get(l.getClass());

View File

@ -26,8 +26,8 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -38,7 +38,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.nio.file.Files;
import java.util.concurrent.CountDownLatch;
@Ignore
@Disabled
public class RandomTests extends BaseDL4JTest {
@Test

View File

@ -50,8 +50,8 @@ import java.lang.reflect.Field;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
public class TestUtils {

View File

@ -27,7 +27,7 @@ import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.dataset.DataSet;

View File

@ -23,7 +23,7 @@ package org.deeplearning4j.datasets;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
import org.junit.Test;
import org.junit.jupiter.api.Test;
public class TestDataSets extends BaseDL4JTest {

View File

@ -19,7 +19,7 @@
*/
package org.deeplearning4j.datasets.datavec;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
@ -47,7 +47,7 @@ import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException
import org.deeplearning4j.datasets.datavec.tools.SpecialImageRecordReader;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.junit.jupiter.api.Disabled;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.buffer.DataType;
@ -74,9 +74,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Record Reader Data Setiterator Test")
class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Override
public DataType getDataType() {
return DataType.FLOAT;

View File

@ -19,7 +19,7 @@
*/
package org.deeplearning4j.datasets.datavec;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.io.Files;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
@ -44,7 +44,7 @@ import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -73,8 +73,7 @@ class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@TempDir
public Path temporaryFolder;
@Rule
public Timeout timeout = Timeout.seconds(300);
@Test
@DisplayName("Tests Basic")

View File

@ -20,9 +20,9 @@
package org.deeplearning4j.datasets.fetchers;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.rules.Timeout;
import java.io.File;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

View File

@ -21,14 +21,14 @@
package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class CombinedPreProcessorTests extends BaseDL4JTest {

View File

@ -23,7 +23,7 @@ package org.deeplearning4j.datasets.iterator;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
@ -32,7 +32,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class DataSetSplitterTests extends BaseDL4JTest {
@Test
@ -54,7 +54,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) {
val data = train.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++;
global++;
}
@ -64,7 +64,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (test.hasNext()) {
val data = test.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++;
global++;
}
@ -94,7 +94,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) {
val data = train.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++;
global++;
}
@ -104,7 +104,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
if (e % 2 == 0)
while (test.hasNext()) {
val data = test.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++;
global++;
}
@ -113,8 +113,9 @@ public class DataSetSplitterTests extends BaseDL4JTest {
assertEquals(700 * numEpochs + (300 * numEpochs / 2), global);
}
@Test(expected = ND4JIllegalStateException.class)
@Test()
public void testSplitter_3() throws Exception {
assertThrows(ND4JIllegalStateException.class, () -> {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, 1000, 0.7);
@ -132,7 +133,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) {
val data = train.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++;
global++;
}
@ -142,7 +143,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (test.hasNext()) {
val data = test.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++;
global++;
}
@ -153,6 +154,9 @@ public class DataSetSplitterTests extends BaseDL4JTest {
}
assertEquals(1000 * numEpochs, global);
});
}
@Test
@ -172,8 +176,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
partIterator.reset();
while (partIterator.hasNext()) {
val data = partIterator.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
(float) perEpoch, data.getFloat(0), 1e-5);
assertEquals((float) perEpoch, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
//gcntTrain++;
global++;
cnt++;
@ -206,8 +209,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
int cnt = 0;
val data = partIterator.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
(float) perEpoch, data.getFloat(0), 1e-5);
assertEquals((float) perEpoch, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
//gcntTrain++;
global++;
cnt++;
@ -247,10 +249,10 @@ public class DataSetSplitterTests extends BaseDL4JTest {
val ds = trainIter.next();
assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", trained);
assertTrue(trained,"Failed at epoch [" + e + "]");
assertEquals(800, globalIter);
@ -262,10 +264,10 @@ public class DataSetSplitterTests extends BaseDL4JTest {
val ds = testIter.next();
assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", tested);
assertTrue(tested,"Failed at epoch [" + e + "]");
assertEquals(900, globalIter);
// validation set is used every 5 epochs
@ -277,10 +279,10 @@ public class DataSetSplitterTests extends BaseDL4JTest {
val ds = validationIter.next();
assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", validated);
assertTrue(validated,"Failed at epoch [" + e + "]");
}
// all 3 iterators have exactly 1000 elements combined
@ -312,7 +314,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
int farCnt = (1000 / 2) * (partNumber) + cnt;
val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data.getFloat(0), 1e-5);
assertEquals((float) farCnt, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
cnt++;
global++;
}
@ -322,7 +324,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(0).hasNext()) {
val data = iteratorList.get(0).next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
global++;
}
}
@ -341,7 +343,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
assertEquals( (float) (500*partNumber + cnt), data.getFloat(0), 1e-5,"Train failed on iteration " + cnt);
cnt++;
}
}
@ -365,7 +367,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
assertEquals( (float) (500*partNumber + cnt), data.getFloat(0), 1e-5,"Train failed on iteration " + cnt);
cnt++;
}
}
@ -390,7 +392,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
val ds = validationIter.next();
assertNotNull(ds);
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5);
assertEquals((float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5,"Validation failed on iteration " + valCnt);
valCnt++;
}
assertEquals(5, valCnt);

View File

@ -25,15 +25,15 @@ import lombok.val;
import lombok.var;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.SimpleVariableGenerator;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.DataSet;
import java.util.ArrayList;
import java.util.Arrays;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class DummyBlockDataSetIteratorTests extends BaseDL4JTest {

View File

@ -21,7 +21,7 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.DataSet;
@ -43,8 +43,7 @@ class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
int numExamples = 105;
@Rule
public final ExpectedException exception = ExpectedException.none();
@Test
@DisplayName("Test Next And Reset")
@ -86,14 +85,16 @@ class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Callsto Next Not Allowed")
@DisplayName("Test calls to Next Not Allowed")
void testCallstoNextNotAllowed() throws IOException {
assertThrows(RuntimeException.class,() -> {
int terminateAfter = 1;
DataSetIterator iter = new MnistDataSetIterator(minibatchSize, numExamples);
EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
iter.reset();
exception.expect(RuntimeException.class);
earlyEndIter.next(10);
});
}
}

View File

@ -21,7 +21,7 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
@ -30,11 +30,12 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Early Termination Multi Data Set Iterator Test")
class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
@ -42,8 +43,7 @@ class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
int numExamples = 105;
@Rule
public final ExpectedException exception = ExpectedException.none();
@Test
@DisplayName("Test Next And Reset")
@ -91,14 +91,16 @@ class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
}
@Test
@DisplayName("Test Callsto Next Not Allowed")
@DisplayName("Test calls to Next Not Allowed")
void testCallstoNextNotAllowed() throws IOException {
assertThrows(RuntimeException.class,() -> {
int terminateAfter = 1;
MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter = new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
iter.reset();
exception.expect(RuntimeException.class);
earlyEndIter.next(10);
});
}
}

View File

@ -24,14 +24,16 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
@Test (timeout = 20000L)
@Test ()
@Timeout(20000L)
public void testJMDSI_1() {
val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2});
val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2});
@ -75,7 +77,8 @@ public class JointMultiDataSetIteratorTests extends BaseDL4JTest {
}
@Test (timeout = 20000L)
@Test ()
@Timeout(20000L)
public void testJMDSI_2() {
val iter0 = new DataSetGenerator(32, new int[]{3, 3}, new int[]{2, 2});
val iter1 = new DataSetGenerator(32, new int[]{3, 3, 3}, new int[]{2, 2, 2});

View File

@ -23,7 +23,7 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.loader.DataSetLoaderIterator;
import org.deeplearning4j.datasets.iterator.loader.MultiDataSetLoaderIterator;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.loader.Loader;
import org.nd4j.common.loader.LocalFileSourceFactory;
import org.nd4j.common.loader.Source;
@ -39,8 +39,8 @@ import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class LoaderIteratorTests extends BaseDL4JTest {

View File

@ -24,7 +24,7 @@ import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -32,7 +32,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class MultiDataSetSplitterTests extends BaseDL4JTest {
@ -55,7 +55,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) {
val data = train.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++;
global++;
}
@ -65,7 +65,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (test.hasNext()) {
val data = test.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++;
global++;
}
@ -96,7 +96,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) {
val data = train.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++;
global++;
}
@ -106,7 +106,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
if (e % 2 == 0)
while (test.hasNext()) {
val data = test.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++;
global++;
}
@ -115,8 +115,9 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertEquals(700 * numEpochs + (300 * numEpochs / 2), global);
}
@Test(expected = ND4JIllegalStateException.class)
@Test()
public void testSplitter_3() throws Exception {
assertThrows(ND4JIllegalStateException.class,() -> {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, 1000, 0.7);
@ -134,7 +135,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (train.hasNext()) {
val data = train.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals((float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTrain++;
global++;
}
@ -144,7 +145,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (test.hasNext()) {
val data = test.next().getFeatures(0);
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
assertEquals( (float) cnt++, data.getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
gcntTest++;
global++;
}
@ -155,6 +156,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
}
assertEquals(1000 * numEpochs, global);
});
}
@Test
@ -185,11 +188,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
assertEquals( (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", trained);
assertTrue(trained,"Failed at epoch [" + e + "]");
assertEquals(800, globalIter);
@ -202,11 +205,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", tested);
assertTrue(tested,"Failed at epoch [" + e + "]");
assertEquals(900, globalIter);
// validation set is used every 5 epochs
@ -219,11 +222,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
assertEquals( (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", validated);
assertTrue(validated,"Failed at epoch [" + e + "]");
}
// all 3 iterators have exactly 1000 elements combined
@ -256,8 +259,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
val data = partIterator.next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
(float) perEpoch, data[i].getFloat(0), 1e-5);
assertEquals((float) perEpoch, data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
}
//gcntTrain++;
global++;
@ -299,12 +301,12 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f);
assertEquals((double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", trained);
assertTrue(trained,"Failed at epoch [" + e + "]");
assertEquals(800, globalIter);
@ -316,11 +318,11 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
val ds = testIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", tested);
assertTrue(tested,"Failed at epoch [" + e + "]");
assertEquals(900, globalIter);
// validation set is used every 5 epochs
@ -333,12 +335,12 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f);
assertEquals((double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f,"Failed at iteration [" + globalIter + "]");
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", validated);
assertTrue(validated,"Failed at epoch [" + e + "]");
}
// all 3 iterators have exactly 1000 elements combined
@ -370,7 +372,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
int farCnt = (1000 / 2) * (partNumber) + cnt;
val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data[i].getFloat(0), 1e-5);
assertEquals( (float) farCnt, data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
}
cnt++;
global++;
@ -381,8 +383,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(0).hasNext()) {
val data = iteratorList.get(0).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++,
data[i].getFloat(0), 1e-5);
assertEquals((float) cnt++,
data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt + "; epoch: " + e);
}
global++;
}
@ -402,7 +404,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5);
assertEquals( (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt);
}
cnt++;
}
@ -427,8 +429,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt),
data[i].getFloat(0), 1e-5);
assertEquals( (float) (500 * partNumber + cnt),
data[i].getFloat(0), 1e-5,"Train failed on iteration " + cnt);
}
cnt++;
}
@ -454,8 +456,8 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
val ds = validationIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90,
ds.getFeatures()[i].getFloat(0), 1e-5);
assertEquals((float) valCnt + 90,
ds.getFeatures()[i].getFloat(0), 1e-5,"Validation failed on iteration " + valCnt);
}
valCnt++;
}

View File

@ -25,9 +25,9 @@ import org.datavec.api.split.FileSplit;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.rules.Timeout;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
@ -42,9 +42,6 @@ import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Multiple Epochs Iterator Test")
class MultipleEpochsIteratorTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Test
@DisplayName("Test Next And Reset")
void testNextAndReset() throws Exception {

View File

@ -22,8 +22,8 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
@ -33,9 +33,9 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
@Ignore
@Disabled
public class TestAsyncIterator extends BaseDL4JTest {
@Test

Some files were not shown because too many files have changed in this diff Show More