Fix rnn parameterized tests

master
agibsonccc 2021-03-18 15:49:27 +09:00
parent 7bd1c5cbaa
commit d1989b8529
35 changed files with 421 additions and 308 deletions

View File

@ -20,6 +20,9 @@
package org.deeplearning4j.datasets.datavec; package org.deeplearning4j.datasets.datavec;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -72,6 +75,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
@Slf4j @Slf4j
@DisplayName("Record Reader Data Setiterator Test") @DisplayName("Record Reader Data Setiterator Test")
@Disabled
class RecordReaderDataSetiteratorTest extends BaseDL4JTest { class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@Override @Override
@ -82,9 +86,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@TempDir @TempDir
public Path temporaryFolder; public Path temporaryFolder;
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Record Reader") @DisplayName("Test Record Reader")
void testRecordReader() throws Exception { void testRecordReader(Nd4jBackend nd4jBackend) throws Exception {
RecordReader recordReader = new CSVRecordReader(); RecordReader recordReader = new CSVRecordReader();
FileSplit csv = new FileSplit(Resources.asFile("csv-example.csv")); FileSplit csv = new FileSplit(Resources.asFile("csv-example.csv"));
recordReader.initialize(csv); recordReader.initialize(csv);
@ -93,9 +98,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(34, next.numExamples()); assertEquals(34, next.numExamples());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Record Reader Max Batch Limit") @DisplayName("Test Record Reader Max Batch Limit")
void testRecordReaderMaxBatchLimit() throws Exception { void testRecordReaderMaxBatchLimit(Nd4jBackend backend) throws Exception {
RecordReader recordReader = new CSVRecordReader(); RecordReader recordReader = new CSVRecordReader();
FileSplit csv = new FileSplit(Resources.asFile("csv-example.csv")); FileSplit csv = new FileSplit(Resources.asFile("csv-example.csv"));
recordReader.initialize(csv); recordReader.initialize(csv);
@ -108,9 +114,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(false, iter.hasNext()); assertEquals(false, iter.hasNext());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Record Reader Multi Regression") @DisplayName("Test Record Reader Multi Regression")
void testRecordReaderMultiRegression() throws Exception { void testRecordReaderMultiRegression(Nd4jBackend backend) throws Exception {
for (boolean builder : new boolean[] { false, true }) { for (boolean builder : new boolean[] { false, true }) {
RecordReader csv = new CSVRecordReader(); RecordReader csv = new CSVRecordReader();
csv.initialize(new FileSplit(Resources.asFile("iris.txt"))); csv.initialize(new FileSplit(Resources.asFile("iris.txt")));
@ -138,9 +145,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader") @DisplayName("Test Sequence Record Reader")
void testSequenceRecordReader() throws Exception { void testSequenceRecordReader(Nd4jBackend backend) throws Exception {
File rootDir = temporaryFolder.toFile(); File rootDir = temporaryFolder.toFile();
// need to manually extract // need to manually extract
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -217,9 +225,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(dsList.get(2).getLabels(), expL2); assertEquals(dsList.get(2).getLabels(), expL2);
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader Meta") @DisplayName("Test Sequence Record Reader Meta")
void testSequenceRecordReaderMeta() throws Exception { void testSequenceRecordReaderMeta(Nd4jBackend backend) throws Exception {
File rootDir = temporaryFolder.toFile(); File rootDir = temporaryFolder.toFile();
// need to manually extract // need to manually extract
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -244,9 +253,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader Regression") @DisplayName("Test Sequence Record Reader Regression")
void testSequenceRecordReaderRegression() throws Exception { void testSequenceRecordReaderRegression(Nd4jBackend backend) throws Exception {
// need to manually extract // need to manually extract
File rootDir = temporaryFolder.toFile(); File rootDir = temporaryFolder.toFile();
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -296,9 +306,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(3, count); assertEquals(3, count);
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader Multi Regression") @DisplayName("Test Sequence Record Reader Multi Regression")
void testSequenceRecordReaderMultiRegression() throws Exception { void testSequenceRecordReaderMultiRegression(Nd4jBackend backend) throws Exception {
File rootDir = temporaryFolder.toFile(); File rootDir = temporaryFolder.toFile();
// need to manually extract // need to manually extract
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -351,9 +362,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(3, count); assertEquals(3, count);
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader Reset") @DisplayName("Test Sequence Record Reader Reset")
void testSequenceRecordReaderReset() throws Exception { void testSequenceRecordReaderReset(Nd4jBackend backend) throws Exception {
File rootDir = temporaryFolder.toFile(); File rootDir = temporaryFolder.toFile();
// need to manually extract // need to manually extract
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -385,9 +397,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test CSV Loading Regression") @DisplayName("Test CSV Loading Regression")
void testCSVLoadingRegression() throws Exception { void testCSVLoadingRegression(Nd4jBackend backend) throws Exception {
int nLines = 30; int nLines = 30;
int nFeatures = 5; int nFeatures = 5;
int miniBatchSize = 10; int miniBatchSize = 10;
@ -447,9 +460,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
return new Pair<>(dArr, temp); return new Pair<>(dArr, temp);
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Variable Length Sequence") @DisplayName("Test Variable Length Sequence")
void testVariableLengthSequence() throws Exception { void testVariableLengthSequence(Nd4jBackend backend) throws Exception {
File rootDir = temporaryFolder.toFile(); File rootDir = temporaryFolder.toFile();
// need to manually extract // need to manually extract
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -582,9 +596,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader Single Reader") @DisplayName("Test Sequence Record Reader Single Reader")
void testSequenceRecordReaderSingleReader() throws Exception { void testSequenceRecordReaderSingleReader(Nd4jBackend backend) throws Exception {
File rootDir = temporaryFolder.toFile(); File rootDir = temporaryFolder.toFile();
// need to manually extract // need to manually extract
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -680,9 +695,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(1, iteratorRegression.totalOutcomes()); assertEquals(1, iteratorRegression.totalOutcomes());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader Single Reader With Empty Sequence Throws") @DisplayName("Test Sequence Record Reader Single Reader With Empty Sequence Throws")
void testSequenceRecordReaderSingleReaderWithEmptySequenceThrows() { void testSequenceRecordReaderSingleReaderWithEmptySequenceThrows(Nd4jBackend backend) {
assertThrows(ZeroLengthSequenceException.class, () -> { assertThrows(ZeroLengthSequenceException.class, () -> {
SequenceRecordReader reader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
reader.initialize(new FileSplit(Resources.asFile("empty.txt"))); reader.initialize(new FileSplit(Resources.asFile("empty.txt")));
@ -690,9 +706,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
}); });
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader Two Readers With Empty Feature Sequence Throws") @DisplayName("Test Sequence Record Reader Two Readers With Empty Feature Sequence Throws")
void testSequenceRecordReaderTwoReadersWithEmptyFeatureSequenceThrows() { void testSequenceRecordReaderTwoReadersWithEmptyFeatureSequenceThrows(Nd4jBackend backend) {
assertThrows(ZeroLengthSequenceException.class, () -> { assertThrows(ZeroLengthSequenceException.class, () -> {
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
@ -702,9 +719,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
}); });
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader Two Readers With Empty Label Sequence Throws") @DisplayName("Test Sequence Record Reader Two Readers With Empty Label Sequence Throws")
void testSequenceRecordReaderTwoReadersWithEmptyLabelSequenceThrows() { void testSequenceRecordReaderTwoReadersWithEmptyLabelSequenceThrows(Nd4jBackend backend) {
assertThrows(ZeroLengthSequenceException.class, () -> { assertThrows(ZeroLengthSequenceException.class, () -> {
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ","); SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
@ -715,9 +733,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
}); });
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader Single Reader Meta Data") @DisplayName("Test Sequence Record Reader Single Reader Meta Data")
void testSequenceRecordReaderSingleReaderMetaData() throws Exception { void testSequenceRecordReaderSingleReaderMetaData(Nd4jBackend backend) throws Exception {
File rootDir = temporaryFolder.toFile(); File rootDir = temporaryFolder.toFile();
// need to manually extract // need to manually extract
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
@ -744,9 +763,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Seq RRDSI Array Writable One Reader") @DisplayName("Test Seq RRDSI Array Writable One Reader")
void testSeqRRDSIArrayWritableOneReader() { void testSeqRRDSIArrayWritableOneReader(Nd4jBackend backend) {
List<List<Writable>> sequence1 = new ArrayList<>(); List<List<Writable>> sequence1 = new ArrayList<>();
sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new IntWritable(0))); sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new IntWritable(0)));
sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new IntWritable(1))); sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new IntWritable(1)));
@ -767,16 +787,17 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(expLabels, ds.getLabels()); assertEquals(expLabels, ds.getLabels());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Seq RRDSI Array Writable One Reader Regression") @DisplayName("Test Seq RRDSI Array Writable One Reader Regression")
void testSeqRRDSIArrayWritableOneReaderRegression() { void testSeqRRDSIArrayWritableOneReaderRegression(Nd4jBackend backend) {
// Regression, where the output is an array writable // Regression, where the output is an array writable
List<List<Writable>> sequence1 = new ArrayList<>(); List<List<Writable>> sequence1 = new ArrayList<>();
sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 100, 200, 300 }, new long[] { 1, 3 })))); sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 100, 200, 300 }, new long[] { 1, 3 }))));
sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 400, 500, 600 }, new long[] { 1, 3 })))); sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 400, 500, 600 }, new long[] { 1, 3 }))));
List<List<Writable>> sequence2 = new ArrayList<>(); List<List<Writable>> sequence2 = new ArrayList<>();
sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 700, 800, 900 }, new long[] { 1, 3 })))); sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 700, 800, 900 }, new long[] { 1, 3 }))));
sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 1000, 1100, 1200 }, new long[] { 1, 3 })))); sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 1000, 1100, 1200 }, new long[] { 1, 3 }))));
SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2));
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr, 2, -1, 1, true); SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr, 2, -1, 1, true);
DataSet ds = iter.next(); DataSet ds = iter.next();
@ -791,16 +812,17 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(expLabels, ds.getLabels()); assertEquals(expLabels, ds.getLabels());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Seq RRDSI Multiple Array Writables One Reader") @DisplayName("Test Seq RRDSI Multiple Array Writables One Reader")
void testSeqRRDSIMultipleArrayWritablesOneReader() { void testSeqRRDSIMultipleArrayWritablesOneReader(Nd4jBackend backend) {
// Input with multiple array writables: // Input with multiple array writables:
List<List<Writable>> sequence1 = new ArrayList<>(); List<List<Writable>> sequence1 = new ArrayList<>();
sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 100, 200, 300 }, new long[] { 1, 3 })), new IntWritable(0))); sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 100, 200, 300 }, new long[] { 1, 3 })), new IntWritable(0)));
sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 400, 500, 600 }, new long[] { 1, 3 })), new IntWritable(1))); sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 400, 500, 600 }, new long[] { 1, 3 })), new IntWritable(1)));
List<List<Writable>> sequence2 = new ArrayList<>(); List<List<Writable>> sequence2 = new ArrayList<>();
sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 700, 800, 900 }, new long[] { 1, 3 })), new IntWritable(2))); sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 700, 800, 900 }, new long[] { 1, 3 })), new IntWritable(2)));
sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 1000, 1100, 1200 }, new long[] { 1, 3 })), new IntWritable(3))); sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 1000, 1100, 1200 }, new long[] { 1, 3 })), new IntWritable(3)));
SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2));
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr, 2, 4, 2, false); SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr, 2, 4, 2, false);
DataSet ds = iter.next(); DataSet ds = iter.next();
@ -815,22 +837,23 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(expLabels, ds.getLabels()); assertEquals(expLabels, ds.getLabels());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Seq RRDSI Array Writable Two Readers") @DisplayName("Test Seq RRDSI Array Writable Two Readers")
void testSeqRRDSIArrayWritableTwoReaders() { void testSeqRRDSIArrayWritableTwoReaders(Nd4jBackend backend) {
List<List<Writable>> sequence1 = new ArrayList<>(); List<List<Writable>> sequence1 = new ArrayList<>();
sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new IntWritable(100))); sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }, new long[] { 1, 3 })), new IntWritable(100)));
sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new IntWritable(200))); sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }, new long[] { 1, 3 })), new IntWritable(200)));
List<List<Writable>> sequence2 = new ArrayList<>(); List<List<Writable>> sequence2 = new ArrayList<>();
sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), new IntWritable(300))); sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 })), new IntWritable(300)));
sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new IntWritable(400))); sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 10, 11, 12 }, new long[] { 1, 3 })), new IntWritable(400)));
SequenceRecordReader rrFeatures = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); SequenceRecordReader rrFeatures = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2));
List<List<Writable>> sequence1L = new ArrayList<>(); List<List<Writable>> sequence1L = new ArrayList<>();
sequence1L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 100, 200, 300 }, new long[] { 1, 3 })), new IntWritable(101))); sequence1L.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 100, 200, 300 }, new long[] { 1, 3 })), new IntWritable(101)));
sequence1L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 400, 500, 600 }, new long[] { 1, 3 })), new IntWritable(201))); sequence1L.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 400, 500, 600 }, new long[] { 1, 3 })), new IntWritable(201)));
List<List<Writable>> sequence2L = new ArrayList<>(); List<List<Writable>> sequence2L = new ArrayList<>();
sequence2L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 700, 800, 900 }, new long[] { 1, 3 })), new IntWritable(301))); sequence2L.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 700, 800, 900 }, new long[] { 1, 3 })), new IntWritable(301)));
sequence2L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] { 1000, 1100, 1200 }, new long[] { 1, 3 })), new IntWritable(401))); sequence2L.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] { 1000, 1100, 1200 }, new long[] { 1, 3 })), new IntWritable(401)));
SequenceRecordReader rrLabels = new CollectionSequenceRecordReader(Arrays.asList(sequence1L, sequence2L)); SequenceRecordReader rrLabels = new CollectionSequenceRecordReader(Arrays.asList(sequence1L, sequence2L));
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rrFeatures, rrLabels, 2, -1, true); SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rrFeatures, rrLabels, 2, -1, true);
// 2 examples, 4 values per time step, 2 time steps // 2 examples, 4 values per time step, 2 time steps
@ -845,7 +868,8 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(expLabels, ds.getLabels()); assertEquals(expLabels, ds.getLabels());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Record Reader Meta Data") @DisplayName("Test Record Reader Meta Data")
void testRecordReaderMetaData() throws Exception { void testRecordReaderMetaData() throws Exception {
RecordReader csv = new CSVRecordReader(); RecordReader csv = new CSVRecordReader();
@ -878,9 +902,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test RRDS Iwith Async") @DisplayName("Test RRDS Iwith Async")
void testRRDSIwithAsync() throws Exception { void testRRDSIwithAsync(Nd4jBackend backend) throws Exception {
RecordReader csv = new CSVRecordReader(); RecordReader csv = new CSVRecordReader();
csv.initialize(new FileSplit(Resources.asFile("iris.txt"))); csv.initialize(new FileSplit(Resources.asFile("iris.txt")));
int batchSize = 10; int batchSize = 10;
@ -893,9 +918,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Record Reader Data Set Iterator ND Array Writable Labels") @DisplayName("Test Record Reader Data Set Iterator ND Array Writable Labels")
void testRecordReaderDataSetIteratorNDArrayWritableLabels() { void testRecordReaderDataSetIteratorNDArrayWritableLabels(Nd4jBackend backend) {
Collection<Collection<Writable>> data = new ArrayList<>(); Collection<Collection<Writable>> data = new ArrayList<>();
data.add(Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }, new long[] { 1, 3 })))); data.add(Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }, new long[] { 1, 3 }))));
data.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(3), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }, new long[] { 1, 3 })))); data.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(3), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }, new long[] { 1, 3 }))));
@ -925,10 +951,11 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(expLabels, ds2.getLabels()); assertEquals(expLabels, ds2.getLabels());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled @Disabled
@DisplayName("Special RR Test 4") @DisplayName("Special RR Test 4")
void specialRRTest4() throws Exception { void specialRRTest4(Nd4jBackend backend) throws Exception {
RecordReader rr = new SpecialImageRecordReader(25000, 10, 3, 224, 224); RecordReader rr = new SpecialImageRecordReader(25000, 10, 3, 224, 224);
RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 128); RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 128);
int cnt = 0; int cnt = 0;
@ -1026,9 +1053,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
*/ */
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Record Reader Data Set Iterator Concat") @DisplayName("Test Record Reader Data Set Iterator Concat")
void testRecordReaderDataSetIteratorConcat() { void testRecordReaderDataSetIteratorConcat(Nd4jBackend backend) {
// [DoubleWritable, DoubleWritable, NDArrayWritable([1,10]), IntWritable] -> concatenate to a [1,13] feature vector automatically. // [DoubleWritable, DoubleWritable, NDArrayWritable([1,10]), IntWritable] -> concatenate to a [1,13] feature vector automatically.
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 })), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 })), new IntWritable(9), new IntWritable(1)); List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 })), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 })), new IntWritable(9), new IntWritable(1));
RecordReader rr = new CollectionRecordReader(Collections.singletonList(l)); RecordReader rr = new CollectionRecordReader(Collections.singletonList(l));
@ -1040,9 +1068,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(expL, ds.getLabels()); assertEquals(expL, ds.getLabels());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Record Reader Data Set Iterator Concat 2") @DisplayName("Test Record Reader Data Set Iterator Concat 2")
void testRecordReaderDataSetIteratorConcat2() { void testRecordReaderDataSetIteratorConcat2(Nd4jBackend backend) {
List<Writable> l = new ArrayList<>(); List<Writable> l = new ArrayList<>();
l.add(new IntWritable(0)); l.add(new IntWritable(0));
l.add(new NDArrayWritable(Nd4j.arange(1, 9))); l.add(new NDArrayWritable(Nd4j.arange(1, 9)));
@ -1054,11 +1083,12 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(expF, ds.getFeatures()); assertEquals(expF, ds.getFeatures());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Record Reader Data Set Iterator Disjoint Features") @DisplayName("Test Record Reader Data Set Iterator Disjoint Features")
void testRecordReaderDataSetIteratorDisjointFeatures() { void testRecordReaderDataSetIteratorDisjointFeatures(Nd4jBackend backend) {
// Idea: input vector is like [f,f,f,f,l,l,f,f] or similar - i.e., label writables aren't start/end // Idea: input vector is like [f,f,f,f,l,l,f,f] or similar - i.e., label writables aren't start/end
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new float[] { 2, 3, 4 }, new long[] { 1, 3 })), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new float[] { 6, 7, 8 }, new long[] { 1, 3 }))); List<Writable> l = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new float[] { 2, 3, 4 }, new long[] { 1, 3 })), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new float[] { 6, 7, 8 }, new long[] { 1, 3 })));
INDArray expF = Nd4j.create(new float[] { 1, 6, 7, 8 }, new long[] { 1, 4 }); INDArray expF = Nd4j.create(new float[] { 1, 6, 7, 8 }, new long[] { 1, 4 });
INDArray expL = Nd4j.create(new float[] { 2, 3, 4, 5 }, new long[] { 1, 4 }); INDArray expL = Nd4j.create(new float[] { 2, 3, 4, 5 }, new long[] { 1, 4 });
RecordReader rr = new CollectionRecordReader(Collections.singletonList(l)); RecordReader rr = new CollectionRecordReader(Collections.singletonList(l));
@ -1068,9 +1098,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertEquals(expL, ds.getLabels()); assertEquals(expL, ds.getLabels());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Normalizer Prefetch Reset") @DisplayName("Test Normalizer Prefetch Reset")
void testNormalizerPrefetchReset() throws Exception { void testNormalizerPrefetchReset(Nd4jBackend backend) throws Exception {
// Check NPE fix for: https://github.com/eclipse/deeplearning4j/issues/4214 // Check NPE fix for: https://github.com/eclipse/deeplearning4j/issues/4214
RecordReader csv = new CSVRecordReader(); RecordReader csv = new CSVRecordReader();
csv.initialize(new FileSplit(Resources.asFile("iris.txt"))); csv.initialize(new FileSplit(Resources.asFile("iris.txt")));
@ -1087,9 +1118,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
iter.next(); iter.next();
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Reading From Stream") @DisplayName("Test Reading From Stream")
void testReadingFromStream() throws Exception { void testReadingFromStream(Nd4jBackend backend) throws Exception {
for (boolean b : new boolean[] { false, true }) { for (boolean b : new boolean[] { false, true }) {
int batchSize = 1; int batchSize = 1;
int labelIndex = 4; int labelIndex = 4;
@ -1121,9 +1153,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Images RRDSI") @DisplayName("Test Images RRDSI")
void testImagesRRDSI() throws Exception { void testImagesRRDSI(Nd4jBackend backend) throws Exception {
File parentDir = temporaryFolder.toFile(); File parentDir = temporaryFolder.toFile();
parentDir.deleteOnExit(); parentDir.deleteOnExit();
String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/"); String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
@ -1150,16 +1183,17 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertArrayEquals(new long[] { 2, 2 }, ds.getLabels().shape()); assertArrayEquals(new long[] { 2, 2 }, ds.getLabels().shape());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Seq RRDSI No Labels") @DisplayName("Test Seq RRDSI No Labels")
void testSeqRRDSINoLabels() { void testSeqRRDSINoLabels(Nd4jBackend backend) {
List<List<Writable>> sequence1 = new ArrayList<>(); List<List<Writable>> sequence1 = new ArrayList<>();
sequence1.add(Arrays.asList((Writable) new DoubleWritable(1), new DoubleWritable(2))); sequence1.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(2)));
sequence1.add(Arrays.asList((Writable) new DoubleWritable(3), new DoubleWritable(4))); sequence1.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(4)));
sequence1.add(Arrays.asList((Writable) new DoubleWritable(5), new DoubleWritable(6))); sequence1.add(Arrays.asList(new DoubleWritable(5), new DoubleWritable(6)));
List<List<Writable>> sequence2 = new ArrayList<>(); List<List<Writable>> sequence2 = new ArrayList<>();
sequence2.add(Arrays.asList((Writable) new DoubleWritable(10), new DoubleWritable(20))); sequence2.add(Arrays.asList(new DoubleWritable(10), new DoubleWritable(20)));
sequence2.add(Arrays.asList((Writable) new DoubleWritable(30), new DoubleWritable(40))); sequence2.add(Arrays.asList(new DoubleWritable(30), new DoubleWritable(40)));
SequenceRecordReader rrFeatures = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); SequenceRecordReader rrFeatures = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2));
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rrFeatures, 2, -1, -1); SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rrFeatures, 2, -1, -1);
DataSet ds = iter.next(); DataSet ds = iter.next();
@ -1167,9 +1201,10 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
assertNull(ds.getLabels()); assertNull(ds.getLabels());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Collect Meta Data") @DisplayName("Test Collect Meta Data")
void testCollectMetaData() { void testCollectMetaData(Nd4jBackend backend) {
RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator.Builder(new CollectionRecordReader(Collections.<List<Writable>>emptyList()), 1).collectMetaData(true).build(); RecordReaderDataSetIterator trainIter = new RecordReaderDataSetIterator.Builder(new CollectionRecordReader(Collections.<List<Writable>>emptyList()), 1).collectMetaData(true).build();
assertTrue(trainIter.isCollectMetaData()); assertTrue(trainIter.isCollectMetaData());
trainIter.setCollectMetaData(false); trainIter.setCollectMetaData(false);

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.file.FileDataSetIterator; import org.deeplearning4j.datasets.iterator.file.FileDataSetIterator;
import org.deeplearning4j.datasets.iterator.file.FileMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.file.FileMultiDataSetIterator;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
@ -40,6 +41,7 @@ import java.util.*;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Disabled
public class TestFileIterators extends BaseDL4JTest { public class TestFileIterators extends BaseDL4JTest {

View File

@ -41,14 +41,19 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.ConvolutionMode.Same; import static org.deeplearning4j.nn.conf.ConvolutionMode.Same;
@ -56,6 +61,7 @@ import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.shade.guava.collect.Lists;
@DisplayName("Cnn Gradient Check Test") @DisplayName("Cnn Gradient Check Test")
class CNNGradientCheckTest extends BaseDL4JTest { class CNNGradientCheckTest extends BaseDL4JTest {
@ -77,7 +83,13 @@ class CNNGradientCheckTest extends BaseDL4JTest {
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of); List<Arguments> args = new ArrayList<>();
for(Nd4jBackend nd4jBackend : BaseNd4jTestWithBackends.BACKENDS) {
for(CNN2DFormat format : CNN2DFormat.values()) {
args.add(Arguments.of(format,nd4jBackend));
}
}
return args.stream();
} }
@Override @Override
@ -85,11 +97,10 @@ class CNNGradientCheckTest extends BaseDL4JTest {
return 999990000L; return 999990000L;
} }
@Test
@DisplayName("Test Gradient CNNMLN") @DisplayName("Test Gradient CNNMLN")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
public void testGradientCNNMLN(CNN2DFormat format) { public void testGradientCNNMLN(CNN2DFormat format,Nd4jBackend backend) {
if (// Only test NCHW due to flat input format... if (// Only test NCHW due to flat input format...
format != CNN2DFormat.NCHW) format != CNN2DFormat.NCHW)
return; return;
@ -144,9 +155,10 @@ class CNNGradientCheckTest extends BaseDL4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
@DisplayName("Test Gradient CNNL 1 L 2 MLN") @DisplayName("Test Gradient CNNL 1 L 2 MLN")
void testGradientCNNL1L2MLN(CNN2DFormat format) { void testGradientCNNL1L2MLN(CNN2DFormat format,Nd4jBackend backend) {
if (// Only test NCHW due to flat input format... if (// Only test NCHW due to flat input format...
format != CNN2DFormat.NCHW) format != CNN2DFormat.NCHW)
return; return;
@ -207,9 +219,10 @@ class CNNGradientCheckTest extends BaseDL4JTest {
} }
@Disabled @Disabled
@Test @ParameterizedTest
@MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
@DisplayName("Test Cnn With Space To Depth") @DisplayName("Test Cnn With Space To Depth")
void testCnnWithSpaceToDepth() { void testCnnWithSpaceToDepth(CNN2DFormat format,Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int minibatchSize = 2; int minibatchSize = 2;
@ -246,8 +259,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Space To Batch") @DisplayName("Test Cnn With Space To Batch")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
public void testCnnWithSpaceToBatch(CNN2DFormat format) { public void testCnnWithSpaceToBatch(CNN2DFormat format,Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 2, 4 }; int[] minibatchSizes = { 2, 4 };
@ -292,8 +305,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Upsampling") @DisplayName("Test Cnn With Upsampling")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCnnWithUpsampling(CNN2DFormat format) { void testCnnWithUpsampling(CNN2DFormat format,Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
@ -328,8 +341,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Subsampling") @DisplayName("Test Cnn With Subsampling")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCnnWithSubsampling(CNN2DFormat format) { void testCnnWithSubsampling(CNN2DFormat format,Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
@ -372,8 +385,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Subsampling V 2") @DisplayName("Test Cnn With Subsampling V 2")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCnnWithSubsamplingV2(CNN2DFormat format) { void testCnnWithSubsamplingV2(CNN2DFormat format,Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
@ -412,8 +425,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Locally Connected 2 D") @DisplayName("Test Cnn Locally Connected 2 D")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCnnLocallyConnected2D(CNN2DFormat format) { void testCnnLocallyConnected2D(CNN2DFormat format,Nd4jBackend backend) {
int nOut = 3; int nOut = 3;
int width = 5; int width = 5;
int height = 5; int height = 5;
@ -444,8 +457,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Multi Layer") @DisplayName("Test Cnn Multi Layer")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCnnMultiLayer(CNN2DFormat format) { void testCnnMultiLayer(CNN2DFormat format,Nd4jBackend backend) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = { 1, 2, 5 }; int[] minibatchSizes = { 1, 2, 5 };
int width = 5; int width = 5;
@ -486,8 +499,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Same Padding Mode") @DisplayName("Test Cnn Same Padding Mode")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCnnSamePaddingMode(CNN2DFormat format) { void testCnnSamePaddingMode(CNN2DFormat format,Nd4jBackend backend) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 }; int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 };
// Same padding mode: insensitive to exact input size... // Same padding mode: insensitive to exact input size...
@ -522,8 +535,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Same Padding Mode Strided") @DisplayName("Test Cnn Same Padding Mode Strided")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCnnSamePaddingModeStrided(CNN2DFormat format) { void testCnnSamePaddingModeStrided(CNN2DFormat format,Nd4jBackend backend) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
int width = 16; int width = 16;
@ -567,8 +580,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Zero Padding Layer") @DisplayName("Test Cnn Zero Padding Layer")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCnnZeroPaddingLayer(CNN2DFormat format) { void testCnnZeroPaddingLayer(CNN2DFormat format,Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int width = 6; int width = 6;
@ -615,8 +628,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Deconvolution 2 D") @DisplayName("Test Deconvolution 2 D")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testDeconvolution2D(CNN2DFormat format) { void testDeconvolution2D(CNN2DFormat format,Nd4jBackend backend) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 }; int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 };
int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 }; int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 };
@ -662,8 +675,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Separable Conv 2 D") @DisplayName("Test Separable Conv 2 D")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testSeparableConv2D(CNN2DFormat format) { void testSeparableConv2D(CNN2DFormat format,Nd4jBackend backend) {
int nOut = 2; int nOut = 2;
int width = 6; int width = 6;
int height = 6; int height = 6;
@ -709,8 +722,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Dilated") @DisplayName("Test Cnn Dilated")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCnnDilated(CNN2DFormat format) { void testCnnDilated(CNN2DFormat format,Nd4jBackend backend) {
int nOut = 2; int nOut = 2;
int minibatchSize = 2; int minibatchSize = 2;
int width = 8; int width = 8;
@ -761,8 +774,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cropping 2 D Layer") @DisplayName("Test Cropping 2 D Layer")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testCropping2DLayer(CNN2DFormat format) { void testCropping2DLayer(CNN2DFormat format,Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 2; int nOut = 2;
int width = 12; int width = 12;
@ -807,8 +820,8 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Depthwise Conv 2 D") @DisplayName("Test Depthwise Conv 2 D")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheck.CNNGradientCheckTest#params")
void testDepthwiseConv2D(CNN2DFormat format) { void testDepthwiseConv2D(CNN2DFormat format,Nd4jBackend backendt) {
int nIn = 3; int nIn = 3;
int depthMultiplier = 2; int depthMultiplier = 2;
int nOut = nIn * depthMultiplier; int nOut = nIn * depthMultiplier;

View File

@ -43,6 +43,7 @@ import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -51,13 +52,16 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import java.io.File; import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.InputStream; import java.io.InputStream;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
@ -70,8 +74,16 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
} }
@TempDir Path testDir;
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of); List<Arguments> args = new ArrayList<>();
for(Nd4jBackend nd4jBackend : BaseNd4jTestWithBackends.BACKENDS) {
for(CNN2DFormat format : CNN2DFormat.values()) {
args.add(Arguments.of(format,nd4jBackend));
}
}
return args.stream();
} }
@Override @Override
@ -80,8 +92,8 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheckYoloGradientCheckTests.#params")
public void testYoloOutputLayer(CNN2DFormat format) { public void testYoloOutputLayer(CNN2DFormat format,Nd4jBackend backend) {
int depthIn = 2; int depthIn = 2;
int c = 3; int c = 3;
int b = 3; int b = 3;
@ -180,8 +192,8 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.gradientcheckYoloGradientCheckTests#params")
public void yoloGradientCheckRealData(@TempDir Path testDir,CNN2DFormat format) throws Exception { public void yoloGradientCheckRealData(CNN2DFormat format,Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream(); InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream();
InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream(); InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream();

View File

@ -1779,7 +1779,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
@Test @Test
public void testCompGraphDropoutOutputLayers(){ public void testCompGraphDropoutOutputLayers(){
//https://github.com/deeplearning4j/deeplearning4j/issues/6326 //https://github.com/eclipse/deeplearning4j/issues/6326
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.dropOut(0.8) .dropOut(0.8)
.graphBuilder() .graphBuilder()
@ -1817,7 +1817,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
@Test @Test
public void testCompGraphDropoutOutputLayers2() { public void testCompGraphDropoutOutputLayers2() {
//https://github.com/deeplearning4j/deeplearning4j/issues/6326 //https://github.com/eclipse/deeplearning4j/issues/6326
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.dropOut(0.8) .dropOut(0.8)
.graphBuilder() .graphBuilder()
@ -1976,7 +1976,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
@Test @Test
public void testVerticesAndMasking7027(){ public void testVerticesAndMasking7027(){
//https://github.com/deeplearning4j/deeplearning4j/issues/7027 //https://github.com/eclipse/deeplearning4j/issues/7027
int inputSize = 300; int inputSize = 300;
int hiddenSize = 100; int hiddenSize = 100;
int dataSize = 10; int dataSize = 10;
@ -2017,7 +2017,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
@Test @Test
public void testCompGraphUpdaterBlocks(){ public void testCompGraphUpdaterBlocks(){
//Check that setting learning rate results in correct rearrangement of updater state within updater blocks //Check that setting learning rate results in correct rearrangement of updater state within updater blocks
//https://github.com/deeplearning4j/deeplearning4j/issues/6809#issuecomment-463892644 //https://github.com/eclipse/deeplearning4j/issues/6809#issuecomment-463892644
double lr = 1e-3; double lr = 1e-3;
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()

View File

@ -43,11 +43,13 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -62,18 +64,24 @@ public class ConvDataFormatTests extends BaseDL4JTest {
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
return Arrays.asList(new DataType[]{DataType.FLOAT, DataType.DOUBLE}).stream().map(Arguments::of); List<Arguments> args = new ArrayList<>();
for(Nd4jBackend nd4jBackend : BaseNd4jTestWithBackends.BACKENDS) {
for(DataType dataType : Arrays.asList(new DataType[]{DataType.FLOAT, DataType.DOUBLE})) {
args.add(Arguments.of(dataType,nd4jBackend));
} }
}
return args.stream();
}
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 999999999L; return 999999999L;
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testConv2d(DataType dataType) { public void testConv2d(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -105,10 +113,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testSubsampling2d(DataType dataType) { public void testSubsampling2d(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -140,10 +147,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testDepthwiseConv2d(DataType dataType) { public void testDepthwiseConv2d(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -175,10 +181,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testSeparableConv2d(DataType dataType) { public void testSeparableConv2d(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -210,10 +215,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testDeconv2d(DataType dataType) { public void testDeconv2d(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -245,10 +249,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testLRN(DataType dataType) { public void testLRN(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -280,10 +283,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testZeroPaddingLayer(DataType dataType) { public void testZeroPaddingLayer(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -313,10 +315,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testCropping2DLayer(DataType dataType) { public void testCropping2DLayer(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -346,10 +347,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testUpsampling2d(DataType dataType) { public void testUpsampling2d(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -379,10 +379,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testBatchNormNet(DataType dataType) { public void testBatchNormNet(DataType dataType,Nd4jBackend backend) {
try { try {
for(boolean useLogStd : new boolean[]{true, false}) { for(boolean useLogStd : new boolean[]{true, false}) {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
@ -414,10 +413,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testCnnLossLayer(DataType dataType) { public void testCnnLossLayer(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -452,10 +450,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testSpaceToDepthNet(DataType dataType) { public void testSpaceToDepthNet(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -485,10 +482,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testSpaceToBatchNet(DataType dataType) { public void testSpaceToBatchNet(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -518,10 +514,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testLocallyConnected(DataType dataType) { public void testLocallyConnected(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -554,10 +549,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.convolution.ConvDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testGlobalPooling(DataType dataType) { public void testGlobalPooling(DataType dataType,Nd4jBackend backend) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) { for (PoolingType pt : PoolingType.values()) {

View File

@ -50,6 +50,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat; import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -57,6 +58,7 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
@ -64,7 +66,9 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.RNNFormat.NCW; import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
@ -79,14 +83,20 @@ class BidirectionalTest extends BaseDL4JTest {
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); List<Arguments> args = new ArrayList<>();
for(Nd4jBackend nd4jBackend : BaseNd4jTestWithBackends.BACKENDS) {
for(RNNFormat rnnFormat : RNNFormat.values()) {
args.add(Arguments.of(rnnFormat,nd4jBackend));
}
}
return args.stream();
} }
@Test
@DisplayName("Compare Implementations") @DisplayName("Compare Implementations")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.BidirectionalTest#params")
void compareImplementations(RNNFormat rnnDataFormat) { void compareImplementations(RNNFormat rnnDataFormat,Nd4jBackend backend) {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
// Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params // Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params
@ -151,8 +161,8 @@ class BidirectionalTest extends BaseDL4JTest {
@DisplayName("Compare Implementations Comp Graph") @DisplayName("Compare Implementations Comp Graph")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.BidirectionalTest#params")
void compareImplementationsCompGraph(RNNFormat rnnFormat) { void compareImplementationsCompGraph(RNNFormat rnnFormat,Nd4jBackend backend) {
// for(WorkspaceMode wsm : WorkspaceMode.values()) { // for(WorkspaceMode wsm : WorkspaceMode.values()) {
for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) { for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
@ -206,11 +216,10 @@ class BidirectionalTest extends BaseDL4JTest {
} }
} }
@Test
@DisplayName("Test Serialization") @DisplayName("Test Serialization")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.BidirectionalTest#params")
void testSerialization(RNNFormat rnnDataFormat) throws Exception { void testSerialization(RNNFormat rnnDataFormat,Nd4jBackend backend) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -245,11 +254,10 @@ class BidirectionalTest extends BaseDL4JTest {
} }
} }
@Test
@DisplayName("Test Serialization Comp Graph") @DisplayName("Test Serialization Comp Graph")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.BidirectionalTest#params")
void testSerializationCompGraph(RNNFormat rnnDataFormat) throws Exception { void testSerializationCompGraph(RNNFormat rnnDataFormat,Nd4jBackend backend) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -282,11 +290,10 @@ class BidirectionalTest extends BaseDL4JTest {
} }
} }
@Test
@DisplayName("Test Simple Bidirectional") @DisplayName("Test Simple Bidirectional")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.BidirectionalTest#params")
public void testSimpleBidirectional(RNNFormat rnnDataFormat) { public void testSimpleBidirectional(RNNFormat rnnDataFormat,Nd4jBackend backend) {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -369,11 +376,10 @@ class BidirectionalTest extends BaseDL4JTest {
} }
} }
@Test
@DisplayName("Test Simple Bidirectional Comp Graph") @DisplayName("Test Simple Bidirectional Comp Graph")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.BidirectionalTest#params")
void testSimpleBidirectionalCompGraph(RNNFormat rnnDataFormat) { void testSimpleBidirectionalCompGraph(RNNFormat rnnDataFormat,Nd4jBackend backend) {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -462,10 +468,11 @@ class BidirectionalTest extends BaseDL4JTest {
} }
} }
@Test
@DisplayName("Test Issue 5472") @DisplayName("Test Issue 5472")
void testIssue5472() { @MethodSource("org.deeplearning4j.nn.layers.recurrent.BidirectionalTest#params")
// https://github.com/deeplearning4j/deeplearning4j/issues/5472 @ParameterizedTest
void testIssue5472(RNNFormat rnnDataFormat,Nd4jBackend backend) {
// https://github.com/eclipse/deeplearning4j/issues/5472
int in = 2; int in = 2;
int out = 2; int out = 2;
ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder().updater(new Adam(0.01)).activation(Activation.RELU).graphBuilder().addInputs("IN").setInputTypes(InputType.recurrent(in)).addLayer("AUTOENCODER", new VariationalAutoencoder.Builder().encoderLayerSizes(64).decoderLayerSizes(64).nOut(7).pzxActivationFunction(Activation.IDENTITY).reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID.getActivationFunction())).build(), "IN").addLayer("RNN", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nOut(128).build()), "AUTOENCODER").addLayer("OUT", new RnnOutputLayer.Builder().nOut(out).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "RNN").setOutputs("OUT"); ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder().updater(new Adam(0.01)).activation(Activation.RELU).graphBuilder().addInputs("IN").setInputTypes(InputType.recurrent(in)).addLayer("AUTOENCODER", new VariationalAutoencoder.Builder().encoderLayerSizes(64).decoderLayerSizes(64).nOut(7).pzxActivationFunction(Activation.IDENTITY).reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID.getActivationFunction())).build(), "IN").addLayer("RNN", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nOut(128).build()), "AUTOENCODER").addLayer("OUT", new RnnOutputLayer.Builder().nOut(out).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "RNN").setOutputs("OUT");

View File

@ -39,15 +39,19 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -60,15 +64,19 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); List<Arguments> args = new ArrayList<>();
for(Nd4jBackend nd4jBackend : BaseNd4jTestWithBackends.BACKENDS) {
for(RNNFormat rnnFormat : RNNFormat.values()) {
args.add(Arguments.of(rnnFormat,nd4jBackend));
}
}
return args.stream();
} }
@Test
@DisplayName("Test Bidirectional LSTM Graves Forward Basic") @DisplayName("Test Bidirectional LSTM Graves Forward Basic")
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTMTest#params")
@ParameterizedTest @ParameterizedTest
void testBidirectionalLSTMGravesForwardBasic(RNNFormat rnnDataFormat) { void testBidirectionalLSTMGravesForwardBasic(RNNFormat rnnDataFormat,Nd4jBackend backend) {
// Very basic test of forward prop. of LSTM layer with a time series. // Very basic test of forward prop. of LSTM layer with a time series.
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
int nIn = 13; int nIn = 13;
@ -108,11 +116,10 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
} }
} }
@Test
@DisplayName("Test Bidirectional LSTM Graves Backward Basic") @DisplayName("Test Bidirectional LSTM Graves Backward Basic")
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTMTest#params")
@ParameterizedTest @ParameterizedTest
void testBidirectionalLSTMGravesBackwardBasic(RNNFormat rnnDataFormat) { void testBidirectionalLSTMGravesBackwardBasic(RNNFormat rnnDataFormat,Nd4jBackend backend) {
// Very basic test of backprop for mini-batch + time series // Very basic test of backprop for mini-batch + time series
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 7); testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 7);
@ -168,9 +175,10 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
} }
} }
@Test
@DisplayName("Test Graves Bidirectional LSTM Forward Pass Helper") @DisplayName("Test Graves Bidirectional LSTM Forward Pass Helper")
void testGravesBidirectionalLSTMForwardPassHelper() throws Exception { @ParameterizedTest
@MethodSource("org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTMTest#params")
void testGravesBidirectionalLSTMForwardPassHelper(RNNFormat rnnDataFormat,Nd4jBackend backend) throws Exception {
// GravesBidirectionalLSTM.activateHelper() has different behaviour (due to optimizations) when forBackprop==true vs false // GravesBidirectionalLSTM.activateHelper() has different behaviour (due to optimizations) when forBackprop==true vs false
// But should otherwise provide identical activations // But should otherwise provide identical activations
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -204,11 +212,10 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
} }
} }
@Test @DisplayName("Test Get Set Params")
@DisplayName("Test Get Set Parmas") @MethodSource("org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTMTest#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
void testGetSetParmas(RNNFormat rnnDataFormat) { void testGetSetParmas(RNNFormat rnnDataFormat,Nd4jBackend backend) {
final int nIn = 2; final int nIn = 2;
final int layerSize = 3; final int layerSize = 3;
final int miniBatchSize = 2; final int miniBatchSize = 2;
@ -226,11 +233,10 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(act2.data().asDouble(), act1.data().asDouble(), 1e-8); assertArrayEquals(act2.data().asDouble(), act1.data().asDouble(), 1e-8);
} }
@Test
@DisplayName("Test Simple Forwards And Backwards Activation") @DisplayName("Test Simple Forwards And Backwards Activation")
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTMTest#params")
@ParameterizedTest @ParameterizedTest
void testSimpleForwardsAndBackwardsActivation(RNNFormat rnnDataFormat) { void testSimpleForwardsAndBackwardsActivation(RNNFormat rnnDataFormat,Nd4jBackend backend) {
final int nIn = 2; final int nIn = 2;
final int layerSize = 3; final int layerSize = 3;
final int miniBatchSize = 1; final int miniBatchSize = 1;
@ -336,9 +342,10 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6); assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6);
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTMTest#params")
@DisplayName("Test Serialization") @DisplayName("Test Serialization")
void testSerialization() { @ParameterizedTest
void testSerialization(RNNFormat rnnDataFormat,Nd4jBackend backend) {
final MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new AdaGrad(0.1)).l2(0.001).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().activation(Activation.TANH).nIn(2).nOut(2).dist(new UniformDistribution(-0.05, 0.05)).build()).layer(1, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().activation(Activation.TANH).nIn(2).nOut(2).dist(new UniformDistribution(-0.05, 0.05)).build()).layer(2, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(2).build()).build(); final MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new AdaGrad(0.1)).l2(0.001).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().activation(Activation.TANH).nIn(2).nOut(2).dist(new UniformDistribution(-0.05, 0.05)).build()).layer(1, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().activation(Activation.TANH).nIn(2).nOut(2).dist(new UniformDistribution(-0.05, 0.05)).build()).layer(2, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(2).build()).build();
final String json1 = conf1.toJson(); final String json1 = conf1.toJson();
final MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json1); final MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json1);
@ -346,11 +353,10 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
assertEquals(json1, json2); assertEquals(json1, json2);
} }
@Test
@DisplayName("Test Gate Activation Fns Sanity Check") @DisplayName("Test Gate Activation Fns Sanity Check")
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTMTest#params")
@ParameterizedTest @ParameterizedTest
void testGateActivationFnsSanityCheck(RNNFormat rnnDataFormat) { void testGateActivationFnsSanityCheck(RNNFormat rnnDataFormat,Nd4jBackend backend) {
for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) { for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);

View File

@ -34,12 +34,17 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@ -51,13 +56,20 @@ class MaskZeroLayerTest extends BaseDL4JTest {
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); List<Arguments> args = new ArrayList<>();
for(Nd4jBackend nd4jBackend : BaseNd4jTestWithBackends.BACKENDS) {
for(RNNFormat rnnFormat : RNNFormat.values()) {
args.add(Arguments.of(rnnFormat,nd4jBackend));
} }
}
return args.stream();
}
@DisplayName("Activate") @DisplayName("Activate")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.MaskZeroLayerTest#params")
void activate(RNNFormat rnnDataFormat) { void activate(RNNFormat rnnDataFormat,Nd4jBackend backend) {
// GIVEN two examples where some of the timesteps are zero. // GIVEN two examples where some of the timesteps are zero.
INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } }); INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } });
INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } }); INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } });
@ -96,8 +108,8 @@ class MaskZeroLayerTest extends BaseDL4JTest {
@DisplayName("Test Serialization") @DisplayName("Test Serialization")
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.MaskZeroLayerTest#params")
void testSerialization(RNNFormat rnnDataFormat) { void testSerialization(RNNFormat rnnDataFormat,Nd4jBackend backend) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();

View File

@ -44,11 +44,13 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -66,18 +68,18 @@ public class RnnDataFormatTests extends BaseDL4JTest {
for (boolean helpers: new boolean[]{true, false}) for (boolean helpers: new boolean[]{true, false})
for (boolean lastTimeStep: new boolean[]{true, false}) for (boolean lastTimeStep: new boolean[]{true, false})
for (boolean maskZero: new boolean[]{true, false}) for (boolean maskZero: new boolean[]{true, false})
ret.add(new Object[]{helpers, lastTimeStep, maskZero}); for(Nd4jBackend backend : BaseNd4jTestWithBackends.BACKENDS)
ret.add(new Object[]{helpers, lastTimeStep, maskZero,backend});
return ret.stream().map(Arguments::of); return ret.stream().map(Arguments::of);
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.recurrent.RnnDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testSimpleRnn(boolean helpers, public void testSimpleRnn(boolean helpers,
boolean lastTimeStep, boolean lastTimeStep,
boolean maskZeros boolean maskZeros,
) { Nd4jBackend backend) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -110,10 +112,10 @@ public class RnnDataFormatTests extends BaseDL4JTest {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.RnnDataFormatTests#params")
public void testLSTM(boolean helpers, public void testLSTM(boolean helpers,
boolean lastTimeStep, boolean lastTimeStep,
boolean maskZeros) { boolean maskZeros,Nd4jBackend backend) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -146,12 +148,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.recurrent.RnnDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testGraveLSTM(boolean helpers, public void testGraveLSTM(boolean helpers,
boolean lastTimeStep, boolean lastTimeStep,
boolean maskZeros) { boolean maskZeros,Nd4jBackend backend) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -184,12 +185,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.recurrent.RnnDataFormatTests#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testGraveBiLSTM(boolean helpers, public void testGraveBiLSTM(boolean helpers,
boolean lastTimeStep, boolean lastTimeStep,
boolean maskZeros) { boolean maskZeros,Nd4jBackend backend) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -41,14 +41,17 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat; import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
@ -62,12 +65,18 @@ public class TestRnnLayers extends BaseDL4JTest {
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); List<Arguments> args = new ArrayList<>();
for(Nd4jBackend nd4jBackend : BaseNd4jTestWithBackends.BACKENDS) {
for(RNNFormat rnnFormat : RNNFormat.values()) {
args.add(Arguments.of(rnnFormat,nd4jBackend));
}
}
return args.stream();
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.TestRnnLayers#params")
public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) { public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat,Nd4jBackend backend) {
int nIn = 12; int nIn = 12;
int nOut = 3; int nOut = 3;
@ -117,8 +126,8 @@ public class TestRnnLayers extends BaseDL4JTest {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.TestRnnLayers#params")
public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){ public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat,Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
String[] layerTypes = new String[]{"graves", "lstm", "simple"}; String[] layerTypes = new String[]{"graves", "lstm", "simple"};
@ -216,8 +225,8 @@ public class TestRnnLayers extends BaseDL4JTest {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.TestRnnLayers#params")
public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){ public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat,Nd4jBackend backend){
for( int i = 0; i < 2; i++) { for( int i = 0; i < 2; i++) {

View File

@ -33,14 +33,18 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@ -52,12 +56,18 @@ public class TestSimpleRnn extends BaseDL4JTest {
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); List<Arguments> args = new ArrayList<>();
for(Nd4jBackend nd4jBackend : BaseNd4jTestWithBackends.BACKENDS) {
for(RNNFormat rnnFormat : RNNFormat.values()) {
args.add(Arguments.of(rnnFormat,nd4jBackend));
}
}
return args.stream();
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.TestRnnLayers#params")
public void testSimpleRnn(RNNFormat rnnDataFormat) { public void testSimpleRnn(RNNFormat rnnDataFormat, Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int m = 3; int m = 3;
@ -126,8 +136,8 @@ public class TestSimpleRnn extends BaseDL4JTest {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.TestRnnLayers#params")
public void testBiasInit(RNNFormat rnnDataFormat) { public void testBiasInit(RNNFormat rnnDataFormat,Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 5; int nIn = 5;
int layerSize = 6; int layerSize = 6;

View File

@ -41,15 +41,19 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@ -58,12 +62,18 @@ public class TestTimeDistributed extends BaseDL4JTest {
public static Stream<Arguments> params() { public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); List<Arguments> args = new ArrayList<>();
for(Nd4jBackend nd4jBackend : BaseNd4jTestWithBackends.BACKENDS) {
for(RNNFormat rnnFormat : RNNFormat.values()) {
args.add(Arguments.of(rnnFormat,nd4jBackend));
}
}
return args.stream();
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("org.deeplearning4j.nn.layers.recurrent.TestTimeDistributed#params")
public void testTimeDistributed(RNNFormat rnnDataFormat){ public void testTimeDistributed(RNNFormat rnnDataFormat,Nd4jBackend backend){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) { for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -133,10 +143,9 @@ public class TestTimeDistributed extends BaseDL4JTest {
} }
@Test @MethodSource("org.deeplearning4j.nn.layers.recurrent.TestTimeDistributed#params")
@MethodSource("#params")
@ParameterizedTest @ParameterizedTest
public void testTimeDistributedDense(RNNFormat rnnDataFormat){ public void testTimeDistributedDense(RNNFormat rnnDataFormat,Nd4jBackend backend) {
for( int rnnType = 0; rnnType < 3; rnnType++ ) { for( int rnnType = 0; rnnType < 3; rnnType++ ) {
for( int ffType = 0; ffType < 3; ffType++ ) { for( int ffType = 0; ffType < 3; ffType++ ) {

View File

@ -261,7 +261,7 @@ public class TestMemoryReports extends BaseDL4JTest {
@Test @Test
public void testPreprocessors() throws Exception { public void testPreprocessors() throws Exception {
//https://github.com/deeplearning4j/deeplearning4j/issues/4223 //https://github.com/eclipse/deeplearning4j/issues/4223
File f = new ClassPathResource("4223/CompGraphConfig.json").getTempFileFromArchive(); File f = new ClassPathResource("4223/CompGraphConfig.json").getTempFileFromArchive();
String s = FileUtils.readFileToString(f, Charset.defaultCharset()); String s = FileUtils.readFileToString(f, Charset.defaultCharset());

View File

@ -88,7 +88,7 @@ public class WorkspaceTests extends BaseDL4JTest {
@Test @Test
public void testWorkspaceIndependence() { public void testWorkspaceIndependence() {
//https://github.com/deeplearning4j/deeplearning4j/issues/4337 //https://github.com/eclipse/deeplearning4j/issues/4337
int depthIn = 2; int depthIn = 2;
int depthOut = 2; int depthOut = 2;
int nOut = 2; int nOut = 2;
@ -143,7 +143,7 @@ public class WorkspaceTests extends BaseDL4JTest {
@Test @Test
public void testWithPreprocessorsCG() { public void testWithPreprocessorsCG() {
//https://github.com/deeplearning4j/deeplearning4j/issues/4347 //https://github.com/eclipse/deeplearning4j/issues/4347
//Cause for the above issue was layerVertex.setInput() applying the preprocessor, with the result //Cause for the above issue was layerVertex.setInput() applying the preprocessor, with the result
// not being detached properly from the workspace... // not being detached properly from the workspace...

View File

@ -195,7 +195,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
} }
} }
@Test @Disabled //https://github.com/deeplearning4j/deeplearning4j/issues/7272 @Test @Disabled //https://github.com/eclipse/deeplearning4j/issues/7272
public void validateLRN() { public void validateLRN() {
//Only run test if using nd4j-native backend //Only run test if using nd4j-native backend

View File

@ -938,7 +938,7 @@ public class MultiLayerTest extends BaseDL4JTest {
@DisplayName("Test MLN Updater Blocks") @DisplayName("Test MLN Updater Blocks")
void testMLNUpdaterBlocks() { void testMLNUpdaterBlocks() {
// Check that setting learning rate results in correct rearrangement of updater state within updater blocks // Check that setting learning rate results in correct rearrangement of updater state within updater blocks
// https://github.com/deeplearning4j/deeplearning4j/issues/6809#issuecomment-463892644 // https://github.com/eclipse/deeplearning4j/issues/6809#issuecomment-463892644
double lr = 1e-3; double lr = 1e-3;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Adam(lr)).list().layer(new DenseLayer.Builder().nIn(5).nOut(3).build()).layer(new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(2).nOut(1).activation(Activation.SIGMOID).build()).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER).updater(new Adam(lr)).list().layer(new DenseLayer.Builder().nIn(5).nOut(3).build()).layer(new DenseLayer.Builder().nIn(3).nOut(2).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(2).nOut(1).activation(Activation.SIGMOID).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);

View File

@ -181,7 +181,7 @@ class TransferLearningCompGraphTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Object Overrides") @DisplayName("Test Object Overrides")
void testObjectOverrides() { void testObjectOverrides() {
// https://github.com/deeplearning4j/deeplearning4j/issues/4368 // https://github.com/eclipse/deeplearning4j/issues/4368
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(0.5).weightNoise(new DropConnect(0.5)).l2(0.5).constrainWeights(new UnitNormConstraint()).graphBuilder().addInputs("in").addLayer("layer", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").setOutputs("layer").build(); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(0.5).weightNoise(new DropConnect(0.5)).l2(0.5).constrainWeights(new UnitNormConstraint()).graphBuilder().addInputs("in").addLayer("layer", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").setOutputs("layer").build();
ComputationGraph orig = new ComputationGraph(conf); ComputationGraph orig = new ComputationGraph(conf);
orig.init(); orig.init();

View File

@ -317,7 +317,7 @@ class TransferLearningMLNTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Object Overrides") @DisplayName("Test Object Overrides")
void testObjectOverrides() { void testObjectOverrides() {
// https://github.com/deeplearning4j/deeplearning4j/issues/4368 // https://github.com/eclipse/deeplearning4j/issues/4368
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(0.5).weightNoise(new DropConnect(0.5)).l2(0.5).constrainWeights(new UnitNormConstraint()).list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(0.5).weightNoise(new DropConnect(0.5)).l2(0.5).constrainWeights(new UnitNormConstraint()).list().layer(new DenseLayer.Builder().nIn(10).nOut(10).build()).build();
MultiLayerNetwork orig = new MultiLayerNetwork(conf); MultiLayerNetwork orig = new MultiLayerNetwork(conf);
orig.init(); orig.init();

View File

@ -200,7 +200,7 @@ public class RegressionTest100a extends BaseDL4JTest {
//Minor bug in 1.0.0-beta and earlier: not adding epsilon value to forward pass for batch norm //Minor bug in 1.0.0-beta and earlier: not adding epsilon value to forward pass for batch norm
//Which means: the record output doesn't have this. To account for this, we'll manually set eps to 0.0 here //Which means: the record output doesn't have this. To account for this, we'll manually set eps to 0.0 here
//https://github.com/deeplearning4j/deeplearning4j/issues/5836#issuecomment-405526228 //https://github.com/eclipse/deeplearning4j/issues/5836#issuecomment-405526228
for(Layer l : net.getLayers()){ for(Layer l : net.getLayers()){
if(l.conf().getLayer() instanceof BatchNormalization){ if(l.conf().getLayer() instanceof BatchNormalization){
BatchNormalization bn = (BatchNormalization) l.conf().getLayer(); BatchNormalization bn = (BatchNormalization) l.conf().getLayer();

View File

@ -97,7 +97,7 @@ public class CustomLayer extends FeedForwardLayer {
//In this case, we can use the DefaultParamInitializer, which is the same one used for DenseLayer //In this case, we can use the DefaultParamInitializer, which is the same one used for DenseLayer
//For more complex layers, you may need to implement a custom parameter initializer //For more complex layers, you may need to implement a custom parameter initializer
//See the various parameter initializers here: //See the various parameter initializers here:
//https://github.com/deeplearning4j/deeplearning4j/tree/master/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/params //https://github.com/eclipse/deeplearning4j/tree/master/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/params
return DefaultParamInitializer.getInstance(); return DefaultParamInitializer.getInstance();
} }

View File

@ -36,6 +36,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
@ -73,6 +74,7 @@ class CrashReportingUtilTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test") @DisplayName("Test")
@Disabled
void test() throws Exception { void test() throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
CrashReportingUtil.crashDumpOutputDirectory(dir); CrashReportingUtil.crashDumpOutputDirectory(dir);

View File

@ -33,6 +33,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
@ -57,6 +58,7 @@ import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Model Serializer Test") @DisplayName("Model Serializer Test")
@Disabled
class ModelSerializerTest extends BaseDL4JTest { class ModelSerializerTest extends BaseDL4JTest {
@TempDir @TempDir

View File

@ -281,7 +281,7 @@ public class TestInstantiation extends BaseDL4JTest {
@Test @Test
public void testYolo4635() throws Exception { public void testYolo4635() throws Exception {
ignoreIfCuda(); ignoreIfCuda();
//https://github.com/deeplearning4j/deeplearning4j/issues/4635 //https://github.com/eclipse/deeplearning4j/issues/4635
int nClasses = 10; int nClasses = 10;
TinyYOLO model = TinyYOLO.builder().numClasses(nClasses).build(); TinyYOLO model = TinyYOLO.builder().numClasses(nClasses).build();
@ -292,7 +292,7 @@ public class TestInstantiation extends BaseDL4JTest {
@Test @Test
public void testTransferLearning() throws Exception { public void testTransferLearning() throws Exception {
ignoreIfCuda(); ignoreIfCuda();
//https://github.com/deeplearning4j/deeplearning4j/issues/7193 //https://github.com/eclipse/deeplearning4j/issues/7193
ComputationGraph cg = (ComputationGraph) ResNet50.builder().build().initPretrained(); ComputationGraph cg = (ComputationGraph) ResNet50.builder().build().initPretrained();

View File

@ -36,7 +36,7 @@ public class DropOut extends BaseRandomOp {
public DropOut(SameDiff sameDiff, SDVariable input, double p) { public DropOut(SameDiff sameDiff, SDVariable input, double p) {
super(sameDiff, input); super(sameDiff, input);
this.p = p; this.p = p;
//https://github.com/deeplearning4j/deeplearning4j/issues/5650 //https://github.com/eclipse/deeplearning4j/issues/5650
throw new UnsupportedOperationException("Dropout SameDiff support disabled pending backprop support"); throw new UnsupportedOperationException("Dropout SameDiff support disabled pending backprop support");
} }

View File

@ -250,7 +250,7 @@ public class VersionCheck {
} }
} catch (NoClassDefFoundError e){ } catch (NoClassDefFoundError e){
//Should only happen on Android 7.0 or earlier - silently ignore //Should only happen on Android 7.0 or earlier - silently ignore
//https://github.com/deeplearning4j/deeplearning4j/issues/6609 //https://github.com/eclipse/deeplearning4j/issues/6609
} catch (Throwable e){ } catch (Throwable e){
//log and skip //log and skip
log.debug("Error finding/loading version check resources", e); log.debug("Error finding/loading version check resources", e);

View File

@ -383,7 +383,7 @@ public class LossOpValidation extends BaseOpValidation {
.build(); .build();
Nd4j.getExecutioner().exec(op); Nd4j.getExecutioner().exec(op);
INDArray exp = Nd4j.scalar(0.6); //https://github.com/deeplearning4j/deeplearning4j/issues/6532 INDArray exp = Nd4j.scalar(0.6); //https://github.com/eclipse/deeplearning4j/issues/6532
assertEquals(exp, out); assertEquals(exp, out);
} }

View File

@ -141,7 +141,7 @@ public class MiscOpValidation extends BaseOpValidation {
bcOp = new FloorModOp(sd, in3, in2).outputVariable(); bcOp = new FloorModOp(sd, in3, in2).outputVariable();
name = "floormod"; name = "floormod";
if(OpValidationSuite.IGNORE_FAILING){ if(OpValidationSuite.IGNORE_FAILING){
//https://github.com/deeplearning4j/deeplearning4j/issues/5976 //https://github.com/eclipse/deeplearning4j/issues/5976
continue; continue;
} }
break; break;
@ -232,7 +232,7 @@ public class MiscOpValidation extends BaseOpValidation {
bcOp = new FloorModOp(sd, in3, in2).outputVariable(); bcOp = new FloorModOp(sd, in3, in2).outputVariable();
name = "floormod"; name = "floormod";
if(OpValidationSuite.IGNORE_FAILING){ if(OpValidationSuite.IGNORE_FAILING){
//https://github.com/deeplearning4j/deeplearning4j/issues/5976 //https://github.com/eclipse/deeplearning4j/issues/5976
continue; continue;
} }
break; break;
@ -334,7 +334,7 @@ public class MiscOpValidation extends BaseOpValidation {
bcOp = new FloorModOp(sd, in3, in2).outputVariable(); bcOp = new FloorModOp(sd, in3, in2).outputVariable();
name = "floormod"; name = "floormod";
if(OpValidationSuite.IGNORE_FAILING){ if(OpValidationSuite.IGNORE_FAILING){
//https://github.com/deeplearning4j/deeplearning4j/issues/5976 //https://github.com/eclipse/deeplearning4j/issues/5976
continue; continue;
} }
break; break;
@ -717,7 +717,7 @@ public class MiscOpValidation extends BaseOpValidation {
for (char bOrder : new char[]{'c', 'f'}) { for (char bOrder : new char[]{'c', 'f'}) {
for (boolean transposeA : new boolean[]{false, true}) { for (boolean transposeA : new boolean[]{false, true}) {
for (boolean transposeB : new boolean[]{false, true}) { for (boolean transposeB : new boolean[]{false, true}) {
for (boolean transposeResult : new boolean[]{false, true}) { //https://github.com/deeplearning4j/deeplearning4j/issues/5648 for (boolean transposeResult : new boolean[]{false, true}) { //https://github.com/eclipse/deeplearning4j/issues/5648
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray aArr = Nd4j.rand(DataType.DOUBLE, t(transposeA, aShape)).dup(aOrder); INDArray aArr = Nd4j.rand(DataType.DOUBLE, t(transposeA, aShape)).dup(aOrder);
@ -761,7 +761,7 @@ public class MiscOpValidation extends BaseOpValidation {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBatchMmulBasic(Nd4jBackend backend) { public void testBatchMmulBasic(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873 OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/6873
int M = 5; int M = 5;
int N = 3; int N = 3;
int K = 4; int K = 4;
@ -1188,7 +1188,7 @@ public class MiscOpValidation extends BaseOpValidation {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOneHotOp(){ public void testOneHotOp(){
//https://www.tensorflow.org/api_docs/python/tf/one_hot //https://www.tensorflow.org/api_docs/python/tf/one_hot
//https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp //https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp
for( int axis=-1; axis<=0; axis++ ) { for( int axis=-1; axis<=0; axis++ ) {
String err = OpValidation.validate(new OpTestCase(new OneHot(Nd4j.create(new double[]{0, 1, 2}), String err = OpValidation.validate(new OpTestCase(new OneHot(Nd4j.create(new double[]{0, 1, 2}),
@ -1244,7 +1244,7 @@ public class MiscOpValidation extends BaseOpValidation {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOneHot3(Nd4jBackend backend) { public void testOneHot3(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6872 //https://github.com/eclipse/deeplearning4j/issues/6872
//https://www.tensorflow.org/api_docs/python/tf/one_hot //https://www.tensorflow.org/api_docs/python/tf/one_hot
//indices = [[0, 2], [1, -1]] //indices = [[0, 2], [1, -1]]

View File

@ -227,7 +227,7 @@ public class RandomOpValidation extends BaseOpValidation {
break; break;
case 4: case 4:
if(OpValidationSuite.IGNORE_FAILING){ if(OpValidationSuite.IGNORE_FAILING){
//https://github.com/deeplearning4j/deeplearning4j/issues/6036 //https://github.com/eclipse/deeplearning4j/issues/6036
continue; continue;
} }
name = "truncatednormal"; name = "truncatednormal";

View File

@ -721,7 +721,7 @@ public class ReductionOpValidation extends BaseOpValidation {
break; break;
case 6: case 6:
if (OpValidationSuite.IGNORE_FAILING) { if (OpValidationSuite.IGNORE_FAILING) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6069 //https://github.com/eclipse/deeplearning4j/issues/6069
continue; continue;
} }
name = "dot"; name = "dot";

View File

@ -126,7 +126,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReshapeGradient(Nd4jBackend backend) { public void testReshapeGradient(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6873 //https://github.com/eclipse/deeplearning4j/issues/6873
int[] origShape = new int[]{3, 4, 5}; int[] origShape = new int[]{3, 4, 5};
@ -1305,7 +1305,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentOps(){ public void testSegmentOps(){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
//https://github.com/deeplearning4j/deeplearning4j/issues/6952 //https://github.com/eclipse/deeplearning4j/issues/6952
INDArray s = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT); INDArray s = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
INDArray d = Nd4j.create(new double[]{5,1,7,2,3,4,1,3}, new long[]{8}); INDArray d = Nd4j.create(new double[]{5,1,7,2,3,4,1,3}, new long[]{8});
int numSegments = 4; int numSegments = 4;
@ -1910,7 +1910,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDistancesExec(){ public void testDistancesExec(){
//https://github.com/deeplearning4j/deeplearning4j/issues/7001 //https://github.com/eclipse/deeplearning4j/issues/7001
for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) { for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) {
log.info("Starting: {}", s); log.info("Starting: {}", s);
INDArray defaultTestCase = Nd4j.create(4, 4); INDArray defaultTestCase = Nd4j.create(4, 4);

View File

@ -1745,7 +1745,7 @@ public class TransformOpValidation extends BaseOpValidation {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testZeta(Nd4jBackend backend) { public void testZeta(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182 OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/6182
INDArray x = Nd4j.rand(3, 4).addi(1.0); INDArray x = Nd4j.rand(3, 4).addi(1.0);
INDArray q = Nd4j.rand(3, 4); INDArray q = Nd4j.rand(3, 4);

View File

@ -7429,7 +7429,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGet(){ public void testGet(){
//https://github.com/deeplearning4j/deeplearning4j/issues/6133 //https://github.com/eclipse/deeplearning4j/issues/6133
INDArray m = Nd4j.linspace(0,99,100, DataType.DOUBLE).reshape('c', 10,10); INDArray m = Nd4j.linspace(0,99,100, DataType.DOUBLE).reshape('c', 10,10);
INDArray exp = Nd4j.create(new double[]{5, 15, 25, 35, 45, 55, 65, 75, 85, 95}, new int[]{10}); INDArray exp = Nd4j.create(new double[]{5, 15, 25, 35, 45, 55, 65, 75, 85, 95}, new int[]{10});
INDArray col = m.getColumn(5); INDArray col = m.getColumn(5);

View File

@ -40,7 +40,7 @@ import java.util.stream.Stream;
@Slf4j @Slf4j
public abstract class BaseNd4jTestWithBackends extends BaseND4JTest { public abstract class BaseNd4jTestWithBackends extends BaseND4JTest {
private static List<Nd4jBackend> BACKENDS = new ArrayList<>(); public static List<Nd4jBackend> BACKENDS = new ArrayList<>();
static { static {
List<String> backendsToRun = Nd4jTestSuite.backendsToRun(); List<String> backendsToRun = Nd4jTestSuite.backendsToRun();

View File

@ -36,7 +36,7 @@ public class ClassPathResourceTest {
@Test @Test
public void testDirExtractingIntelliJ(@TempDir Path testDir) throws Exception { public void testDirExtractingIntelliJ(@TempDir Path testDir) throws Exception {
//https://github.com/deeplearning4j/deeplearning4j/issues/6483 //https://github.com/eclipse/deeplearning4j/issues/6483
ClassPathResource cpr = new ClassPathResource("somedir"); ClassPathResource cpr = new ClassPathResource("somedir");